Support rms_norm() for NJT by jbschlosser · Pull Request #135872 · pytorch/pytorch · GitHub

Conversation

@jbschlosser
Contributor

@jbschlosser jbschlosser commented Sep 12, 2024

Stack from ghstack (oldest at bottom):

rms_norm() is a nice-to-have for ViT :)

This PR:

  • SymInt-ifies rms_norm(), allowing NJT to use the same decomp.
  • Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
  • Adds multi-dim support (on non-ragged, non-batch dims) to mean() for NJT.

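For context, a minimal usage sketch of what this enables (illustrative only; the shapes and values are made up, and normalization here is assumed to cover just the last dim):

```python
import torch
import torch.nn.functional as F

# Nested jagged tensor (NJT): a batch of 2 sequences with different lengths,
# each with a trailing embedding dim of size 8.
njt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)],
    layout=torch.jagged,
)

# RMS-normalize over the last (non-ragged) dim. Normalizing over the ragged
# dim is not supported for now and is expected to be rejected by the
# torch_function-based validation.
weight = torch.randn(8)
out = F.rms_norm(njt, normalized_shape=[8], weight=weight)

# mean() also accepts multiple dims now, as long as they are non-ragged,
# non-batch dims.
njt_4d = torch.nested.nested_tensor(
    [torch.randn(3, 4, 8), torch.randn(5, 4, 8)],
    layout=torch.jagged,
)
m = njt_4d.mean(dim=(2, 3))
```
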
@pytorch-bot

pytorch-bot bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135872

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2202eb5 with merge base bc1b8f0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jbschlosser jbschlosser added the topic: improvements and release notes: nested tensor labels Sep 12, 2024
@vadimkantorov
Contributor

vadimkantorov commented Sep 12, 2024

Hopefully, PyTorch can eventually include a torch.compile'd version of RMSNorm by default (e.g., in nn.py: RMSNorm = torch.compile(RMSNorm)); the dynamo/inductor options would just need to be kept fixed rather than propagated from outside, or the generated Triton/C++ kernels could be pre-generated at PyTorch build time. As demonstrated by Liger, many HF users still run in eager mode, so adding torch.compile'd shortcuts like RMSNorm to core would let these legacy models and users benefit from the newly generated fused kernels.
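
For what it's worth, a rough sketch of that idea (purely illustrative, not part of this PR; the hidden size and dynamic=True are arbitrary choices):

```python
import torch

# Hypothetical: ship the module pre-wrapped with torch.compile so eager-mode
# users still get a fused kernel. Compile options are pinned here rather than
# inherited from any surrounding compile context.
rms_norm_compiled = torch.compile(torch.nn.RMSNorm(512), dynamic=True)
y = rms_norm_compiled(torch.randn(4, 16, 512))
```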

Contributor

@mikaylagawarecki mikaylagawarecki left a comment

changes to skip calculating M and N in rms_norm input validation lgtm

@jbschlosser
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Sep 13, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / win-vs2019-cpu-py3 / test (default, 1, 3, windows.4xlarge.nonephemeral)

Details for Dev Infra team: raised by workflow job.

enter-ctrl9 pushed a commit to enter-ctrl9/pytorch11 that referenced this pull request Sep 15, 2024
ghstack-source-id: 5eb19b3
Pull Request resolved: pytorch/pytorch#135872
@jbschlosser
Contributor Author

jbschlosser commented Sep 16, 2024

@mikaylagawarecki sorry I need another pass - turns out my original approach to reuse the decomp for NJT had problems. New approach SymInt-ifies rms_norm().
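
Roughly, the idea is that the shape validation now works on symbolic sizes rather than computing concrete flattened M/N products. A hypothetical Python-level sketch of that idea (the actual change lives in the composite implementation, so the helper name below is made up):

```python
import torch

def _validate_rms_norm_inputs(inp: torch.Tensor, normalized_shape) -> None:
    # Hypothetical helper: compare trailing sizes element-wise instead of
    # materializing flattened M and N, so SymInt sizes (dynamic shapes, NJT's
    # ragged size) can flow through without being forced to concrete values.
    ndim = len(normalized_shape)
    torch._check(inp.dim() >= ndim, lambda: "rms_norm: input has too few dims")
    for i, expected in enumerate(normalized_shape):
        torch._check(
            inp.size(inp.dim() - ndim + i) == expected,
            lambda: "rms_norm: input shape does not match normalized_shape",
        )
```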

Contributor

@mikaylagawarecki mikaylagawarecki left a comment

Thanks a lot for the extra effort symintifying the impl and removing the dependency on check_layer_norm_inputs

@jbschlosser
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Pull Request resolved: pytorch#135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: pytorch#125947
@github-actions github-actions bot deleted the gh/jbschlosser/175/head branch October 18, 2024 02:05

Labels

ciflow/trunk, release notes: nested tensor, topic: improvements