[PP] Allow intermediate nodes in ZB to have multiple grads #159084
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159084
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit cc48dfd with merge base 70b4a88. UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
weight_grads.append(weight.grad)

for param_group in param_groups:
    # TODO: Handle case where intermediate can have multiple outputs
```
This PR fixes this TODO
Fixes a ZB regression (https://github.com/pytorch/torchtitan/actions/runs/16478292562/job/46585646792).

Previously we only allowed an intermediate node to have one gradient. Recently a torchtitan ZB test started failing, and I tracked it back to the FusedRMSNorm grad_fn having two values `(grad, None)`. I am not sure why this was introduced, but it started breaking our ZB tests.

This PR allows `stage_backward_weight` intermediate nodes to have multiple grads: it sums them together, and ignores any grad value that is None. Here is an example where the backward would have two grad values (gI1, gI2):

```python
class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x, 2

    @staticmethod
    def backward(ctx, gI1, gI2):
        assert gI2 is None
        return gI1
```

cc awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta
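For illustration, here is a minimal sketch (a hypothetical helper, not the actual `stage_backward_weight` code) of the coalescing behavior described above: non-None grads for one intermediate are summed and None entries are skipped.

```python
import torch

def coalesce_grads(grads):
    # Sum the non-None gradients an intermediate grad_fn produced; skip Nones.
    total = None
    for g in grads:
        if g is None:
            continue
        total = g if total is None else total + g
    return total

# Example: only the first output carries a gradient, as in the (grad, None) case.
print(coalesce_grads((torch.ones(3), None)))  # tensor([1., 1., 1.])
```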
The regression came from the recent support of fused RMSNorm kernels (#153666).
I know that in the past it was going through some decomposition at the aten level, but I didn't pay attention to why we are seeing more fields passed around.
If this fix is a general improvement, rather than one specifically tailored to this regression, then it sounds good to me.
Thanks for providing this context. Yep, this fix is a general improvement; that change just happened to surface it!
stamp to unblock
Just to add some context: an unfused composite implementation relies on autograd to save intermediate values for backprop. In a fused implementation, autograd has no visibility into what happens within the fused kernel, so we have to return some intermediate values from the forward so that they can be reused and do not have to be recomputed during the backward pass.
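To make this concrete, here is a hedged sketch (a hypothetical `FusedNormLike` function, not the actual FusedRMSNorm kernel) of a custom autograd Function whose forward returns an extra intermediate; when that extra output is not used downstream, its backward slot arrives as None, which is the `(grad, None)` pattern this PR handles.

```python
import torch

class FusedNormLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, eps=1e-6):
        ctx.set_materialize_grads(False)  # unused output grads arrive as None
        # Return the intermediate rstd alongside the result so the backward can
        # reuse it instead of recomputing it, mimicking a fused kernel's extra output.
        rstd = (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt()
        ctx.save_for_backward(rstd)
        return x * rstd, rstd

    @staticmethod
    def backward(ctx, grad_out, grad_rstd):
        # grad_rstd is None when the rstd output is not used downstream.
        (rstd,) = ctx.saved_tensors
        # Simplified placeholder gradient w.r.t. x (not the true RMSNorm derivative);
        # None for eps, which is a non-tensor input.
        return grad_out * rstd, None

x = torch.randn(4, 8, requires_grad=True)
y, _ = FusedNormLike.apply(x)
y.sum().backward()  # the grad_fn here produces two grad slots: (grad, None)
```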
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
These should be fixed now that pytorch/pytorch#159084 has landed
Pull Request resolved: #159084
Approved by: https://github.com/tianyu-l
Stack from ghstack (oldest at bottom):
cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta