[dtensor] Rework partial propagation in pointwise op and support mul by wanchaol · Pull Request #157340 · pytorch/pytorch

Conversation

@wanchaol
Collaborator

@wanchaol wanchaol commented Jul 1, 2025

I am trying to see if I can easily add linearity support for aten.mul so that Partial placements can propagate through it. It turns out that I have to completely rework the current linearity propagation.

In short, before this PR linearity mainly supported aten.add and a few trivial ops. It worked by allowing input Partial placements to propagate and, at the same time, redistributing Replicate inputs to Partial to preserve single-device semantics. For example, suppose we want to execute aten.add(lhs, rhs) on 2 ranks:

  • lhs is Partial; its value on rank 0 is r0 and its value on rank 1 is r1
  • rhs is Replicate; its value is a

Then, in order to preserve single-device semantics (the result should be a + r0 + r1), we first compute rhs / world_size and then add rhs to lhs. In other words, every operand first needs to become Partial before we can add them together.
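To make this concrete, here is a minimal sketch with plain torch tensors simulating the two ranks (the values and variable names are illustrative, not taken from the PR's code):

```python
# Minimal sketch (plain torch, no process group) of the additive
# linearity rule on 2 simulated "ranks"; values are illustrative only.
import torch

world_size = 2
r0, r1 = torch.tensor(3.0), torch.tensor(5.0)  # lhs is Partial(sum): one shard per rank
a = torch.tensor(4.0)                          # rhs is Replicate: same value on every rank

# Single-device reference: lhs would already be reduced to r0 + r1.
reference = (r0 + r1) + a

# Additive linearity: turn the replicated rhs into a Partial by dividing it
# by world_size, then add shard-wise; reducing (summing) the per-rank
# outputs recovers the single-device result.
out_rank0 = r0 + a / world_size
out_rank1 = r1 + a / world_size
assert torch.allclose(out_rank0 + out_rank1, reference)
```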

But this no longer holds for multiplicative operations like aten.mul. Assuming the same aten.mul(lhs, rhs) and the same values, we do not need to divide the replicated operand by world_size to preserve single-device semantics, because a * (r0 + r1) = a * r0 + a * r1.
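The same style of sketch for the multiplicative case (again with made-up values) shows why the replicated operand can be used as-is:

```python
# Minimal sketch (plain torch, illustrative values): for aten.mul the
# replicated operand is NOT divided by world_size, because multiplication
# distributes over the pending sum: a * (r0 + r1) == a * r0 + a * r1.
import torch

r0, r1 = torch.tensor(3.0), torch.tensor(5.0)  # lhs is Partial(sum): one shard per rank
a = torch.tensor(4.0)                          # rhs is Replicate

reference = (r0 + r1) * a   # single-device result
out_rank0 = r0 * a          # rhs used unchanged on each rank
out_rank1 = r1 * a
assert torch.allclose(out_rank0 + out_rank1, reference)
```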

So to accommodate the difference between add and mul, in this PR I:

  • change linearity to be an int so it can encode different linearity types; additive and multiplicative linearity are kept separate (see the first sketch after this list)
  • add checks to ensure that only a subset of Partial types supports linearity (namely partial sum/avg)
  • handle the linearity-type plumbing through the pointwise ops
  • add mul.Tensor/Scalar as multiplicative linearity
  • add tests showing that Partial placements can be propagated through aten.mul (see the usage sketch after this list)
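A hypothetical sketch of the int-based linearity encoding described in the first bullet (the actual constant names, values, and helpers in pointwise_ops.py may differ):

```python
# Hypothetical encoding of linearity kinds as ints, per the description
# above; the real constants and helpers in DTensor's pointwise_ops.py
# may be named and structured differently.
NO_LINEARITY = -1
ADDITIVE_LINEARITY = 0        # aten.add-style: replicated operands must become Partial
MULTIPLICATIVE_LINEARITY = 1  # aten.mul-style: replicated operands are used as-is

def replicate_rescale_factor(linearity: int, world_size: int) -> float:
    """Factor applied to a Replicate operand so that reducing the op's
    Partial output still matches single-device semantics."""
    if linearity == ADDITIVE_LINEARITY:
        return 1.0 / world_size  # split the replicated value across ranks
    if linearity == MULTIPLICATIVE_LINEARITY:
        return 1.0               # a * (r0 + r1) == a * r0 + a * r1
    raise ValueError("op does not support Partial propagation")
```

And a hedged end-to-end usage sketch of the behavior the new tests cover (run under torchrun with 2 processes; with this PR the result of the mul is expected to keep its Partial placement):

```python
# Usage sketch: run with `torchrun --nproc_per_node=2 <this_file>.py`.
# With this PR, multiplying a Partial DTensor by a Replicate one is
# expected to keep the result Partial instead of forcing a reduction.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

lhs = DTensor.from_local(torch.rand(4), mesh, [Partial()])    # pending sum across ranks
rhs = DTensor.from_local(torch.ones(4), mesh, [Replicate()])  # same value on every rank

out = lhs * rhs
print(out.placements)  # expected with this PR: (Partial(sum),)
dist.destroy_process_group()
```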

Fixes #ISSUE_NUMBER

cc @H-Huang @awgu @fegin @fduwjj @wz337 @wconstab @d4l3k

@pytorch-bot

pytorch-bot bot commented Jul 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157340

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 326eeea with merge base 81759af:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Jul 1, 2025
@wanchaol wanchaol added the release notes: distributed (dtensor) label Jul 1, 2025
@wanchaol wanchaol added the ciflow/trunk label Jul 2, 2025
Member

@zpcore zpcore left a comment

LGTM!

@wanchaol
Collaborator Author

wanchaol commented Jul 3, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions github-actions bot deleted the mul_op branch August 3, 2025 02:20
