[dtensor] Rework partial propagation in pointwise op and support mul #157340
Conversation
I am trying to see if I can easily add linearity support for aten.mul so that Partial placements can propagate through it. It turns out that this requires a complete rework of the current linearity propagation.

In short, before this PR linearity mainly supported aten.add and a few trivial ops. It works by allowing Partial inputs to propagate through, while redistributing Replicate inputs to Partial to preserve single-device semantics. For example, suppose we run aten.add(lhs, rhs) on 2 ranks:

* lhs is partial, with value r0 on rank 0 and r1 on rank 1
* rhs is replicate, with value a

To preserve single-device semantics (the result should be `a + r0 + r1`), we first compute rhs / world_size and then add rhs to lhs. In other words, every operand is first turned into a partial, and then the partials are added together.

This no longer holds for multiplicative operations such as aten.mul: for the same `aten.mul(lhs, rhs)` with the same values, we do not need to divide the replicated rhs by world_size to preserve single-device semantics, because `a * (r0 + r1) = a * r0 + a * r1` (see the sketch below).

To accommodate the difference between add and mul, this PR:

* changes linearity to an int to support different linearity types; additive and multiplicative linearity are kept separate
* adds checks to ensure only a subset of partial types support linearity (namely partial-sum/avg)
* plumbs the linearity type through the pointwise ops
* registers mul.Tensor/Scalar as multiplicative linearity
* adds tests showing that Partial placements can be propagated through aten.mul
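To make the single-device-semantics argument above concrete, here is a minimal sketch (not part of this PR) that simulates the two ranks' values as plain tensors in a single process. The names `r0`, `r1`, `a`, and `world_size` follow the description above; no DTensor APIs are used.

```python
import torch

world_size = 2

# Simulated Partial(sum) operand: the "true" value is r0 + r1,
# with each rank holding one piece of the sum.
r0 = torch.tensor([1.0, 2.0])   # lhs partial value on rank 0
r1 = torch.tensor([3.0, 4.0])   # lhs partial value on rank 1
a = torch.tensor([10.0, 20.0])  # rhs replicated value (same on both ranks)

# Single-device reference results.
add_ref = (r0 + r1) + a
mul_ref = (r0 + r1) * a

# aten.add: the replicated operand must be pre-divided by world_size so
# that summing the per-rank outputs reproduces the reference exactly once.
add_rank0 = r0 + a / world_size
add_rank1 = r1 + a / world_size
assert torch.allclose(add_rank0 + add_rank1, add_ref)

# aten.mul: no division is needed, because a * (r0 + r1) == a*r0 + a*r1,
# so the outputs can stay Partial(sum) as-is.
mul_rank0 = r0 * a
mul_rank1 = r1 * a
assert torch.allclose(mul_rank0 + mul_rank1, mul_ref)

print("both linearity types preserve single-device semantics")
```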
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157340
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 326eeea with merge base 81759af.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
cc @H-Huang @awgu @fegin @fduwjj @wz337 @wconstab @d4l3k