[PP] Allow intermediate nodes in ZB to have multiple grads by H-Huang · Pull Request #159084 · pytorch/pytorch · GitHub

Conversation

@H-Huang
Member

@H-Huang H-Huang commented Jul 24, 2025

Stack from ghstack (oldest at bottom):

Fixes a ZB regression (https://github.com/pytorch/torchtitan/actions/runs/16478292562/job/46585646792)

Previously we only allowed an intermediate node to have 1 gradient. Recently a torchtitan ZB test started failing, and I tracked it back to the FusedRMSNorm grad_fn having two values (grad, None) (see #153666), which started breaking our ZB tests.

This PR allows stage_backward_weight intermediate nodes to have multiple grads: it sums them together, and if a grad value is None it ignores it. Here is an example where the backward would have two grad values (gI1, gI2):

```python
class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x, 2

    @staticmethod
    def backward(ctx, gI1, gI2):
        assert gI2 is None
        return gI1
```
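Below is a minimal sketch of the accumulation behavior described above; `accumulate_intermediate_grads` is a hypothetical helper for illustration, not the actual `stage_backward_weight` implementation:

```python
import torch

def accumulate_intermediate_grads(grads):
    # Sum the gradients an intermediate node received across its outputs,
    # skipping None entries (e.g. the None produced for Func's second output).
    total = None
    for g in grads:
        if g is None:
            continue
        total = g.clone() if total is None else total + g
    return total

# With Func above, the backward sees (gI1, gI2) = (grad, None):
gI1, gI2 = torch.ones(3), None
print(accumulate_intermediate_grads([gI1, gI2]))  # tensor([1., 1., 1.])
```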

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

@pytorch-bot

pytorch-bot bot commented Jul 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159084

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit cc48dfd with merge base 70b4a88:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jul 24, 2025
H-Huang added a commit that referenced this pull request Jul 24, 2025
```python
        weight_grads.append(weight.grad)

    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
```

@H-Huang (Member Author) commented on this diff:

This PR fixes this TODO

@H-Huang H-Huang added pipeline parallelism Issues related to https://pytorch.org/docs/master/pipeline.html release notes: distributed (pipeline) release notes category labels Jul 24, 2025
H-Huang added a commit that referenced this pull request Jul 24, 2025
Contributor

@tianyu-l tianyu-l left a comment


The regression came from recent support of fused RMSNorm kernels #153666

I know that in the past it went through some decomposition at the aten level, but I didn't look closely at why we are now seeing more fields passed around.

If this fix is a general improvement, rather than one specifically tailored to this regression, then it sounds good to me.

@H-Huang
Member Author

H-Huang commented Jul 25, 2025

The regression came from recent support of fused RMSNorm kernels #153666

Thanks for providing this context.

Yep, this fix is a general improvement; that change just happened to surface it!

Contributor

@tianyu-l tianyu-l left a comment


stamp to unblock

@H-Huang H-Huang added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 25, 2025
@AaronWang04
Contributor

The regression came from recent support of fused RMSNorm kernels #153666

I know that in the past it went through some decomposition at the aten level, but I didn't look closely at why we are now seeing more fields passed around.

If this fix is a general improvement, rather than one specifically tailored to this regression, then it sounds good to me.

Just to add some context: an unfused composite implementation relies on autograd to save intermediate values for backprop.

In a fused implementation, autograd has no visibility into what happens inside the fused kernel, so we have to return some intermediate values from the forward so they can be reused rather than recomputed during the backward pass.
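As a toy illustration of this, here is a hypothetical fused-style `autograd.Function` (not the actual FusedRMSNorm kernel) whose forward returns an extra intermediate tensor; since nothing downstream differentiates through it, its gradient slot arrives as None, which is exactly the (grad, None) shape discussed above:

```python
import torch

class FusedLikeOp(torch.autograd.Function):
    # Toy stand-in for a fused kernel: forward returns the main result plus an
    # intermediate statistic so a later kernel could reuse it during backward.
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)  # unused output grads arrive as None
        out = x * 2.0
        stat = x.mean()  # extra intermediate returned alongside the result
        return out, stat

    @staticmethod
    def backward(ctx, grad_out, grad_stat):
        # Nothing differentiates through `stat`, so grad_stat is None here --
        # the (grad, None) pattern this PR makes the ZB backward tolerate.
        assert grad_stat is None
        return grad_out * 2.0

x = torch.randn(4, requires_grad=True)
out, stat = FusedLikeOp.apply(x)
out.sum().backward()
assert torch.equal(x.grad, torch.full_like(x, 2.0))
```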

@H-Huang
Member Author

H-Huang commented Jul 27, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

H-Huang added a commit to pytorch/torchtitan that referenced this pull request Jul 29, 2025
These should be fixed now that pytorch/pytorch#159084 has landed.
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
Pull Request resolved: #159084
Approved by: https://github.com/tianyu-l
bentherien pushed a commit to bentherien/torchtitan_ that referenced this pull request Aug 5, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
@github-actions github-actions bot deleted the gh/H-Huang/198/head branch August 27, 2025 02:11

Labels

ciflow/trunk Trigger trunk jobs on your pull request
Merged
oncall: distributed Add this issue/PR to distributed oncall triage queue
pipeline parallelism Issues related to https://pytorch.org/docs/master/pipeline.html
release notes: distributed (pipeline) release notes category
