Allow any single non-batch dim to be ragged for NJT #137125
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137125. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 29d5faa with merge base a2bc2e3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Relaxes the restriction that the ragged dim is immediately next to the batch dim, e.g. `(B, *, D_0, ..., D_N)`. This allows for constructing NJTs of shape e.g. `(B, D, j0)` directly, which is useful at the very least for testing on transposed NJTs.
stamp
Relaxes the restriction that the ragged dim is immediately next to the batch dim, e.g. `(B, *, D_0, ..., D_N)`. This allows for constructing NJTs of shape e.g. `(B, D, j0)` directly. Before this PR, it was possible to get an NJT of e.g. shape `(B, D, j0)` by constructing an NJT of shape `(B, j0, D)` and transposing it; this PR allows a user to go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this. At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.
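For illustration, a minimal sketch of the direct construction this enables, assuming the updated `torch.nested.nested_tensor(list)` constructor infers the ragged dim from whichever component dim varies in length (the sizes and the printed symbolic shape are made up for this example):

```python
import torch

# Components of shape (D, L_i): dim 0 is a constant D=3, dim 1 varies per component.
D = 3
components = [torch.randn(D, L) for L in (2, 5, 4)]

# With the restriction relaxed, this can become an NJT of shape (B, D, j0)
# directly, instead of requiring a (B, j0, D) construction plus a transpose.
njt = torch.nested.nested_tensor(components, layout=torch.jagged)
print(njt.shape)        # expected: something like torch.Size([3, 3, j1])
print(njt._ragged_idx)  # internal attribute; expected 2, i.e. the last dim is ragged
```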
@pytorchbot merge
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim.
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly.
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125 (see the sketch below).

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.

Pull Request resolved: #141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #140736, #140161
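A hedged sketch of the `unsqueeze()` behavior the three fixes above target; the shapes in the comments describe the intended semantics rather than captured output:

```python
import torch

# NJT of shape (B, j1, D), built from components of shape (L_i, D).
njt = torch.nested.nested_tensor(
    [torch.randn(L, 4) for L in (2, 5, 3)], layout=torch.jagged
)

# Fixes 1 and 3: inserting a dim at/before the ragged dim is allowed, and the
# output's _ragged_idx is shifted accordingly.
a = njt.unsqueeze(1)   # (B, j1, D) -> (B, 1, j1, D); ragged dim moves from 1 to 2

# Fix 2: dim=-1 is canonicalized wrt inp.dim() + 1, so it appends a trailing dim.
b = njt.unsqueeze(-1)  # (B, j1, D) -> (B, j1, D, 1)

# Unsqueezing the batch dim (dim=0) remains unsupported.
```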
**Background:** conversion from outer dim -> inner dim makes the (previously valid) assumption that the ragged dim is immediately next to the batch dim. This is no longer the case after #137125.

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual `ragged_idx`. Since `ragged_idx` tells us where the packed ragged / batch dim is, both ragged and batch outer dims should map to this inner dim (see the sketch below). The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass this.
* Fixes outputs across keepdim settings when reducing over ragged / batch dims.

Pull Request resolved: #142173
Approved by: https://github.com/drisspg
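A hedged sketch of the outer -> inner dim mapping described above; `outer_to_inner_dim` is an illustrative name rather than the exact internal helper, and it simply encodes the rule from the PR that both the batch dim and the ragged dim map to the packed ragged/batch dim of the values tensor:

```python
def outer_to_inner_dim(outer_dim: int, ragged_idx: int, ndim: int) -> int:
    """Map an NJT (outer) dim to a dim of the underlying values tensor.

    The values tensor drops the batch dim and packs the ragged dim with it,
    so the packed ragged/batch dim sits at inner position ragged_idx - 1.
    """
    outer_dim = outer_dim % ndim  # canonicalize negative dims
    if outer_dim == 0 or outer_dim == ragged_idx:
        # Batch and ragged dims both map to the packed ragged/batch dim.
        return ragged_idx - 1
    # Every other dim shifts down by one to account for the dropped batch dim.
    return outer_dim - 1

# Example: an NJT of shape (B, D, j0) has ragged_idx = 2 and values of shape
# (D, sum(L_i)), so outer dims (0, 1, 2) map to inner dims (1, 0, 1).
```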
Stack from ghstack (oldest at bottom):
Fixes #137512
Relaxes the restriction that the ragged dim is immediately next to the batch dim e.g. `(B, *, D_0, ..., D_N)`. This allows for constructing NJTs of shape e.g. `(B, D, j0)` directly. It's possible before this PR to get an NJT of e.g. shape `(B, D, j0)` by constructing an NJT of shape `(B, j0, D)` and transposing it. This PR allows a user to go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this. At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.
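For comparison, a minimal sketch of the transpose route described above versus the direct construction this PR adds (component sizes are made up; the constructor's ragged-dim inference for the direct route is an assumption based on the description):

```python
import torch

# Old route: construct (B, j0, D) the standard way, then transpose to (B, D, j0).
base = torch.nested.nested_tensor(
    [torch.randn(L, 3) for L in (2, 5, 4)], layout=torch.jagged
)                                   # shape (B, j0, D)
transposed = base.transpose(1, 2)   # shape (B, D, j0)

# New route: pass components of shape (D, L_i) and construct (B, D, j0) directly.
direct = torch.nested.nested_tensor(
    [torch.randn(3, L) for L in (2, 5, 4)], layout=torch.jagged
)                                   # shape (B, D, j0), no transpose needed
```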