NJT <-> padded dense conversions by jbschlosser · Pull Request #125947 · pytorch/pytorch · GitHub

Conversation

@jbschlosser (Contributor) commented May 10, 2024

Stack from ghstack (oldest at bottom):

This PR:

  • Implements the pre-existing nt.to_padded_tensor(padding_val) ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
  • Introduces a new _nested_from_padded_tensor op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    • Note: there is currently no public API for this; the design is deferred to a future PR (a usage sketch of the forward conversion follows the TODO list below)

TODO:

  • Propagate min / max sequence length via the new factory function _nested_from_padded_tensor
  • Verify that Inductor does computation fusion via test logic
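
As a rough illustration of the forward direction, here is a minimal sketch using the public `torch.nested.to_padded_tensor()` API (shapes and values are illustrative; the reverse padded-to-NJT op introduced here is private, so it is not shown):

```python
import torch

# Build a jagged-layout NJT with a ragged dim, then densify it with a
# padding value.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 8), torch.randn(5, 8)],  # ragged lengths 2 and 5
    layout=torch.jagged,
)
padded = torch.nested.to_padded_tensor(nt, 0.0)  # dense, shape (2, 5, 8)
```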

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

@pytorch-bot (bot) commented May 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125947

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6b9f037 with merge base bc1b8f0:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jbschlosser marked this pull request as draft May 10, 2024 18:54
@vadimkantorov (Contributor) commented May 10, 2024

Related old discussion on this being a useful primitive (back in the day, for collate_fn in data loading) that deserves more fame :)

One useful thing here would also be to support "padding multiples" per dimension (a sketch of approximating this with today's API follows).
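
A hedged sketch of what per-dimension padding multiples could look like today, via the `output_size` argument of `torch.nested.to_padded_tensor()` (the multiple-of-8 policy is illustrative, and `output_size` support for the jagged layout is an assumption here):

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 16), torch.randn(10, 16)],
    layout=torch.jagged,
)
# Round the padded sequence length up to a multiple of 8 (10 -> 16).
max_len = max(t.shape[0] for t in nt.unbind())
padded_len = -(-max_len // 8) * 8
padded = torch.nested.to_padded_tensor(
    nt, 0.0, output_size=(nt.size(0), padded_len, 16)
)
```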

jbschlosser added the topic: not user facing label May 14, 2024
@vadimkantorov (Contributor)

Maybe one way to auto-construct an NJT from the torch.stack([...]) call in the default collate could be:

  1. return dense tensors from dataset.__getitem__, but wrapped in NJT (the internal representation would be just a regular dense tensor)
  2. support torch.stack([...]) returning an NJT if the elements of the input list are NJTs (even if inside they are just dense tensors)

Like this, the collate_fn code could be kept unchanged, but if the inputs are wrapped as NJTs, it would start to produce an NJT... (see the sketch below)
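
A hedged sketch of the idea with today's API, using a custom collate_fn as a stand-in for the proposed NJT-aware torch.stack (names like `VarLenDataset` and `njt_collate` are hypothetical):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class VarLenDataset(Dataset):
    """Hypothetical dataset returning variable-length dense tensors."""

    def __len__(self):
        return 4

    def __getitem__(self, i):
        return torch.randn(i + 1, 8)  # ragged first dim

def njt_collate(batch):
    # Stand-in for the proposed behavior: stacking variable-length
    # samples yields an NJT directly instead of erroring on ragged shapes.
    return torch.nested.nested_tensor(batch, layout=torch.jagged)

loader = DataLoader(VarLenDataset(), batch_size=4, collate_fn=njt_collate)
batch = next(iter(loader))  # an NJT with shape (4, j1, 8)
```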

@huydhn (Contributor) commented Sep 9, 2024

@pytorchbot revert -m 'Sorry for reverting your change but it is failing dynamo test https://hud.pytorch.org/pytorch/pytorch/commit/09a5e88bef04d5485b70d8f65f46a675aaa52942, maybe a landrace' -c landrace

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

@jbschlosser your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Sep 9, 2024
This reverts commit 09a5e88.

Reverted #125947 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing dynamo test https://hud.pytorch.org/pytorch/pytorch/commit/09a5e88bef04d5485b70d8f65f46a675aaa52942, maybe a landrace
@github-actions (bot) commented Sep 9, 2024

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


@jbschlosser (Contributor, Author)

@pytorchbot merge

@jbschlosser (Contributor, Author)

Changing the meta registration for _padded_dense_to_jagged_forward() to be a fake tensor impl fixes the failing dynamo test. It has to be the latter because the op must create an unbacked SymInt when sum_S is not specified. (A sketch of this pattern follows.)
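
As a hedged illustration of the pattern (not the actual ATen registration), a hypothetical custom op with a fake impl that allocates an unbacked SymInt for the data-dependent output length might look like this:

```python
import torch

# Hypothetical custom op standing in for the real ATen
# _padded_dense_to_jagged_forward(); all names here are made up.
@torch.library.custom_op("mylib::padded_to_jagged", mutates_args=())
def padded_to_jagged(padded: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Eager impl: pack each batch element's valid prefix into one jagged dim.
    lengths = offsets.diff().tolist()
    return torch.cat([padded[i, :l] for i, l in enumerate(lengths)], dim=0)

@padded_to_jagged.register_fake
def _(padded, offsets):
    # The packed length depends on the *values* in offsets, which a fake
    # tensor impl cannot read, so it allocates an unbacked SymInt instead.
    ctx = torch.library.get_ctx()
    total_len = ctx.new_dynamic_size()
    return padded.new_empty(total_len, padded.shape[-1])
```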

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Sep 17, 2024
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific constraints (no normalization supported over the ragged dim for now) on the Python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
Pull Request resolved: #135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947
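
A hedged usage sketch of what that commit describes (assuming `F.rms_norm()` accepts an NJT and normalizes over a non-ragged trailing dim):

```python
import torch
import torch.nn.functional as F

nt = torch.nested.nested_tensor(
    [torch.randn(3, 16), torch.randn(7, 16)],
    layout=torch.jagged,
)
# RMS-normalize over the last (non-ragged) feature dim of each element;
# normalizing over the ragged dim remains unsupported.
out = F.rms_norm(nt, normalized_shape=(16,))
```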
Chao1Han pushed commits to Chao1Han/pytorch that referenced this pull request Sep 20, 2024.
github-actions bot deleted the gh/jbschlosser/140/head branch October 13, 2024 02:09