Making batching rule for F.embedding DTensor-aware by zou3519 · Pull Request #162117 · pytorch/pytorch · GitHub

Conversation


@zou3519 zou3519 commented Sep 4, 2025

Stack from ghstack (oldest at bottom):

`vmap(F.embedding)(DTensor, DTensor)` was failing because F.embedding's
batching rule creates a new tensor via at::arange, at::arange produces a
regular Tensor, and DTensor rightfully errors on mixed DTensor/regular-Tensor
operations.
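
A minimal repro sketch of the failing pattern (assumes a 2-rank run with a 1-D device mesh; the shapes and mapping over dim 0 of both arguments are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate

# Assumes torch.distributed is initialized across 2 ranks.
mesh = init_device_mesh("cpu", (2,))
indices = distribute_tensor(torch.randint(0, 10, (4, 5)), mesh, [Replicate()])  # (B, S)
weight = distribute_tensor(torch.randn(4, 10, 3), mesh, [Replicate()])          # (B, N, D)

# Before this PR: the batching rule's internal at::arange produced a plain
# Tensor, so a subsequent DTensor + Tensor op raised a mixed-tensor error.
out = torch.vmap(F.embedding)(indices, weight)
```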

This PR fixes the problem by activating DTensor implicit replication on
just the at::arange and the subsequent add operation.

In order to accomplish this I move the DTensor implicit replication flag
to C++ (most batching rules are in C++).
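
For reference, the Python-level counterpart of this flag is the `implicit_replication` context manager; below is a sketch of the scoping pattern applied to just those two ops (variable names are illustrative; the actual fix lives in the C++ batching rule):

```python
import torch
from torch.distributed.tensor.experimental import implicit_replication

def shift_indices(indices, batch_size, num_embeddings):
    # `indices` is a DTensor; the arange below is a plain Tensor.
    with implicit_replication():
        # Inside the guard, DTensor treats the plain Tensor as replicated,
        # so the mixed DTensor/Tensor add no longer raises.
        offsets = torch.arange(batch_size) * num_embeddings
        return indices + offsets.view(-1, 1)
```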

Test Plan:

  • new test

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim

`vmap(F.embedding)(DTensor, DTensor)` was failing because F.embedding's
batching rule generates a new tensor via at::arange, at::arange
generates a regular tensor, and DTensor rightfully errors on mixed
DTensor-regular Tensor operations.

This PR fixes the problem by activating DTensor implicit replication on
every non-randomness batching rule. This is safe to do because batching
rules must return tensors of the same shape, and the factory functions they
call produce the same values everywhere, so the tensors they create can
safely be treated as replicated.

In order to accomplish this I move the DTensor implicit replication flag
to C++ (most batching rules are in C++).

Test Plan:
- new test

[ghstack-poisoned]

pytorch-bot bot commented Sep 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162117

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0bb02bb with merge base f4c33cd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Sep 4, 2025
zou3519 added a commit that referenced this pull request Sep 4, 2025
ghstack-source-id: 911c4af
Pull Request resolved: #162117
zou3519 added a commit that referenced this pull request Sep 4, 2025
ghstack-source-id: 2ad3045
Pull Request resolved: #162117
@zou3519 zou3519 requested review from bdhirsh and tianyu-l September 4, 2025 13:29
@zou3519 zou3519 changed the title from "Batching rules assume DTensor implicit replication" to "Making batching rule for F.embedding DTensor-aware" Sep 4, 2025
zou3519 added a commit that referenced this pull request Sep 4, 2025
ghstack-source-id: 58b403f
Pull Request resolved: #162117
@zou3519 zou3519 requested review from bdhirsh and ezyang September 5, 2025 13:27

zou3519 commented Sep 5, 2025

Updated to a more local workaround since the previous approach didn't work

"aten/src/ATen/DeviceAccelerator.cpp",
"aten/src/ATen/Context.cpp",
"aten/src/ATen/DLConvertor.cpp",
"aten/src/ATen/DTensorState.cpp",
oh god non-globbed ATen file sources

@bdhirsh bdhirsh left a comment

sgtm!

@zou3519 zou3519 added the ciflow/trunk and release notes: distributed (dtensor) labels Sep 5, 2025

zou3519 commented Sep 5, 2025

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

daisyden pushed a commit to daisyden/pytorch that referenced this pull request Sep 8, 2025
Pull Request resolved: pytorch#162117
Approved by: https://github.com/bdhirsh
pytorchmergebot pushed a commit that referenced this pull request Sep 8, 2025
F.one_hot(dtensor) used to run into a mixed DTensor-Tensor operation due
to an arange call creating a new Tensor (not DTensor). This PR fixes it
by allowing implicit replication of Tensors for the arange call and the
one consumer of the arange call (the at::eq call).

Test Plan:
- new test. Also, F.one_hot(num_classes=-1) is broken so we skip that.

Pull Request resolved: #162307
Approved by: https://github.com/ezyang
ghstack dependencies: #162117
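
Under the same assumptions as the sketch above (2-rank run, 1-D mesh), the now-working call might look like this; note the explicit num_classes, since num_classes=-1 is skipped as broken:

```python
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate

mesh = init_device_mesh("cpu", (2,))
labels = distribute_tensor(torch.tensor([0, 2, 1, 3]), mesh, [Replicate()])

# one_hot compares labels against an arange(num_classes) Tensor via at::eq;
# with implicit replication scoped to those two calls, the mixed
# DTensor/Tensor comparison no longer errors.
onehot = F.one_hot(labels, num_classes=4)
```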
pytorchmergebot pushed a commit that referenced this pull request Sep 18, 2025
…or operations (#162651)

Also updates the error message to point to the guide.

Pull Request resolved: #162651
Approved by: https://github.com/ezyang
ghstack dependencies: #162117, #162307
pytorchmergebot pushed a commit that referenced this pull request Sep 19, 2025
This PR adds an experimental way to register a custom rule for whether
inductor should partition the graph around an operator.

Test Plan:
- new test

Pull Request resolved: #163310
Approved by: https://github.com/ProExpertProg, https://github.com/BoyuanFeng, https://github.com/eellison
ghstack dependencies: #162117, #162307, #162651
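
A hedged usage sketch; the import path, helper name, and rule signature below are assumptions based on this summary, not a confirmed API (see #163310 for the real one):

```python
import torch
# Assumed location/name of the experimental registration helper.
from torch._inductor.scheduler import register_should_partition_rule

def never_partition(*args, **kwargs) -> bool:
    # Returning False asks inductor not to split the compiled graph
    # around this operator (signature assumed).
    return False

register_should_partition_rule(torch.ops.aten.mm.default, never_partition)
```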