Support rms_norm() for NJT #135872
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135872
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 2202eb5 with merge base bc1b8f0.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
`rms_norm()` is a nice-to-have for ViT :)

`rms_norm()` is a nice-to-have for ViT :)

This PR:
* Tweaks the input validation logic for `rms_norm()` slightly to avoid errors for NJT. This way, we can use the same decomp.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT
Hopefully PyTorch can include a torch.compile'd version of RMSNorm by default in the future... (e.g. could be in nn.py:
changes to skip calculating M and N in rms_norm input validation lgtm
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).

Merge failed
Reason: 1 jobs have failed, first few of them are: trunk / win-vs2019-cpu-py3 / test (default, 1, 3, windows.4xlarge.nonephemeral)
ghstack-source-id: 5eb19b3 Pull Request resolved: pytorch/pytorch#135872
@mikaylagawarecki sorry, I need another pass. It turns out my original approach to reuse the decomp for NJT had problems. New approach SymInt-ifies
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* Tweaks the input validation logic for `rms_norm()` slightly to avoid errors for NJT. This way, we can use the same decomp.
* Adds torch_function-based input validation logic on the python NJT side
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()` + does input shape validation in symbolic land. This allows NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
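For readers skimming the thread, here is a rough pure-Python sketch of what the description above amounts to: RMS normalization applied per row over the trailing feature dim of a jagged batch, with a check that rejects normalizing over the ragged dim. This is not the PyTorch implementation. The function name `rms_norm_jagged`, the nested-list layout, and the `eps` default are all assumptions made for illustration.

```python
import math
from typing import List, Optional, Sequence

# Hypothetical stand-in for a jagged batch: (batch, ragged length, feature dim).
Jagged = List[List[List[float]]]


def rms_norm_jagged(
    x: Jagged,
    normalized_shape: Sequence[int],
    weight: Optional[Sequence[float]] = None,
    eps: float = 1e-6,  # assumed default, chosen only for this sketch
) -> Jagged:
    # With a (batch, ragged, feature) layout, more than one normalized dim
    # would include the ragged dim, which this sketch rejects.
    if len(normalized_shape) != 1:
        raise ValueError("normalizing over the ragged dim is not supported")
    (dim,) = normalized_shape
    if weight is not None and len(weight) != dim:
        raise ValueError("weight must match normalized_shape")

    out: Jagged = []
    for seq in x:
        out_seq = []
        for row in seq:
            if len(row) != dim:
                raise ValueError("trailing dim does not match normalized_shape")
            # y = x / sqrt(mean(x^2) + eps), optionally scaled by weight
            rms = math.sqrt(sum(v * v for v in row) / dim + eps)
            normed = [v / rms for v in row]
            if weight is not None:
                normed = [n * w for n, w in zip(normed, weight)]
            out_seq.append(normed)
        out.append(out_seq)
    return out


batch = [[[3.0, 4.0]], [[1.0, 1.0], [0.0, 2.0]]]  # sequence lengths 1 and 2
result = rms_norm_jagged(batch, (2,))
print([len(seq) for seq in result])  # ragged lengths are preserved: [1, 2]
```

Each row is normalized independently, so the ragged structure passes through untouched, which is presumably why restricting normalization to non-ragged dims keeps the shared decomp simple.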
Thanks a lot for the extra effort symintifying the impl and removing the dependency on check_layer_norm_inputs
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.

Pull Request resolved: pytorch#135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: pytorch#125947
Stack from ghstack (oldest at bottom):
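`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.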
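To make the `mean()` part concrete, a minimal sketch of a multi-dim mean restricted to non-ragged, non-batch dims might look like the following. It uses plain nested lists rather than NJT. The helper name `mean_inner_dims` and the fixed (batch, ragged, d1, d2) layout are assumptions for illustration only, and nothing here reflects how the PR implements the reduction internally.

```python
from typing import List, Sequence

# Hypothetical stand-in for a jagged batch of shape (batch, ragged, d1, d2).
Jagged4D = List[List[List[List[float]]]]


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def mean_inner_dims(x: Jagged4D, dims: Sequence[int]) -> list:
    wanted = sorted(set(dims))
    # Dim 0 is the batch dim and dim 1 is the ragged dim; both are out of scope.
    if any(d in (0, 1) for d in wanted):
        raise ValueError("only non-ragged, non-batch dims are handled here")
    if wanted == [2, 3]:
        # Reduce each (d1, d2) block to one scalar -> (batch, ragged)
        return [[_mean([v for r in blk for v in r]) for blk in seq] for seq in x]
    if wanted == [3]:
        # Reduce each row -> (batch, ragged, d1)
        return [[[_mean(r) for r in blk] for blk in seq] for seq in x]
    if wanted == [2]:
        # Reduce each column -> (batch, ragged, d2)
        return [[[_mean(c) for c in zip(*blk)] for blk in seq] for seq in x]
    raise ValueError(f"unsupported dims: {list(dims)}")


batch = [
    [[[1.0, 2.0], [3.0, 4.0]]],                            # ragged length 1
    [[[0.0, 0.0], [2.0, 2.0]], [[1.0, 3.0], [5.0, 7.0]]],  # ragged length 2
]
print(mean_inner_dims(batch, (2, 3)))  # [[2.5], [1.0, 4.0]]
```

Because the reduced dims are dense, every element of the ragged dim is reduced the same way and the per-sequence lengths carry over to the output unchanged.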