[DTensor] dispatch to sharding prop over decomps #159324
Conversation
Stack from ghstack (oldest at bottom):
Fixes #159110
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta
[ghstack-poisoned]
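In short: when a fused op such as rms_norm has a registered sharding rule, DTensor should dispatch to sharding propagation for the fused op instead of decomposing it first. A minimal, purely illustrative sketch of that preference (plain Python; the registries and the `dispatch` function are hypothetical, not DTensor's actual code):

```python
from typing import Any, Callable, Dict

# Hypothetical registries, for illustration only.
sharding_rules: Dict[str, Callable[..., Any]] = {}   # fused op name -> sharding rule
decompositions: Dict[str, Callable[..., Any]] = {}   # fused op name -> decomposed impl

def dispatch(op_name: str, *args: Any) -> Any:
    """Prefer a registered sharding rule over decomposing the op."""
    rule = sharding_rules.get(op_name)
    if rule is not None:
        # Sharding propagation handles the fused op directly.
        return rule(*args)
    # No rule registered: fall back to the composite (decomposed) path.
    return decompositions[op_name](*args)
```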
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159324
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit d175ad9 with merge base 1abff80. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Does that mean this needs to be updated if rms_norm will now go to the fused path? CC @AaronWang04
@eqy I didn't merge in the sharding rule for the forward pass of rms_norm since it never got triggered. I will add a PR for that and update the test after this gets merged.
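To illustrate what "going to the fused path" looks like from the user side, here is a minimal sketch of calling rms_norm on DTensor inputs. It assumes PyTorch >= 2.5 (public torch.distributed.tensor API), two CUDA devices launched via torchrun, and illustrative shapes and placements; whether the fused or decomposed path is taken depends on the build and on this PR having landed.

```python
# Run with e.g.: torchrun --nproc-per-node=2 rms_norm_dtensor.py
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (2,))

batch, seq, hidden = 4, 128, 256
x = torch.randn(batch, seq, hidden, device="cuda")
weight = torch.ones(hidden, device="cuda")

# Shard activations on the sequence dim (as in sequence parallelism) and
# replicate the norm weight.
x_d = distribute_tensor(x, mesh, [Shard(1)])
w_d = distribute_tensor(weight, mesh, [Replicate()])

# With sharding propagation handling the fused op, this call should be served
# by the fused rms_norm path rather than its decomposition.
out = F.rms_norm(x_d, (hidden,), weight=w_d, eps=1e-6)
print(type(out), out.placements)
```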
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@AaronWang04
@tianyu-l I planned to add the forward op sharding strategy once this PR is stable. Why would this change break torchtitan? I expect this change to be able to fall back on the composite (decomposed) path.
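For context on that fallback: on regular tensors the composite (decomposed) path computes roughly the following. This is a hand-written reference, not the kernel DTensor actually dispatches to, and it assumes F.rms_norm is available (recent PyTorch releases).

```python
import torch
import torch.nn.functional as F

def rms_norm_decomposed(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Roughly what the decomposed path computes: normalize by the
    # root-mean-square over the last dimension, then apply the learned scale.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

x = torch.randn(8, 256)
w = torch.ones(256)
ref = F.rms_norm(x, (256,), weight=w, eps=1e-6)
print((rms_norm_decomposed(x, w) - ref).abs().max())  # expected to be ~0
```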
Fixes #159110
Pull Request resolved: #159324
Approved by: https://github.com/ezyang
Reduces collective calls in the forward pass from 2 to 1. In #158716 I added the sharding rule for the backward pass but didn't add the forward pass, as it didn't get dispatched. After #159324 this should get properly dispatched, hence I am adding it now.
Pull Request resolved: #159692
Approved by: https://github.com/tianyu-l
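One way to sanity-check the 2-to-1 collective claim is DTensor's comm-counting debug mode. A sketch reusing the sharded `x_d` / `w_d` from the earlier example; the CommDebugMode import path and the exact counts observed are assumptions that may vary with placements and PyTorch version.

```python
import torch.nn.functional as F
from torch.distributed.tensor.debug import CommDebugMode

comm_mode = CommDebugMode()
with comm_mode:
    out = F.rms_norm(x_d, (hidden,), weight=w_d, eps=1e-6)

# With the fused-op sharding rule in place, the forward pass should issue a
# single collective instead of two.
print(comm_mode.get_total_counts())
print(comm_mode.get_comm_counts())
```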