KEMBAR78
Forward / backward NJT support for several activation functions by jbschlosser · Pull Request #140736 · pytorch/pytorch · GitHub
Skip to content

Conversation

@jbschlosser
Copy link
Contributor

@jbschlosser jbschlosser commented Nov 14, 2024

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140736

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit a9aed7d with merge base efec302 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jbschlosser jbschlosser added topic: improvements topic category release notes: nested tensor Changes that have a direct impact on nested tensors labels Nov 14, 2024
@github-actions
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

python_module: nn
tags: pointwise

- func: relu6_(Tensor(a!) self) -> Tensor(a!)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inplace versions should also be marked as pointwise, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started with this but I realized it's not quite right since ideally the same NestedTensor object should be returned for in-place ops. We might need to update jagged_unary_pointwise() to do this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but lots of other inplace ops are already tagged as pointwise unfortunately.

Copy link
Contributor Author

@jbschlosser jbschlosser Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm true, I think I'll open a follow-up to handle this properly within NJT, since we can't remove pointwise from those inplace ops

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, this actually already works today:

def test_inplace_unary_op_returns_same_instance(self, device, dtype):
    nt = random_nt_from_dims(
        [3, None, 5],
        device=device,
        dtype=dtype,
        layout=torch.jagged,
    )
    out = nt.relu_()
    # passes!
    self.assertIs(out, nt)

What I believe happens is that the ADInplaceOrView key handles aliasing.

…tions"


Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.

[ghstack-poisoned]
…tions"


Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.

[ghstack-poisoned]
…tions"


Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.

[ghstack-poisoned]
…tions"


Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.

[ghstack-poisoned]
@jbschlosser
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 22, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@huydhn
Copy link
Contributor

huydhn commented Nov 22, 2024

@pytorchbot revert -m 'Sorry for reverting your change but its tests are failing in trunk' -c nosignal

test_nestedtensor.py::TestNestedTensorOpInfoCUDA::test_compile_backward_nn_functional_softshrink_cuda_float32 GH job link HUD commit link

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot pushed a commit that referenced this pull request Nov 26, 2024
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: #140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736
pytorchmergebot pushed a commit that referenced this pull request Nov 26, 2024
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: #141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736, #140161
cyyever pushed a commit to cyyever/pytorch that referenced this pull request Nov 27, 2024
)

This fixes some bugs when performing reductions / select() on dims before the ragged dim. In this case, the output NJT has a smaller number of dims, and its ragged_idx should reflect that correctly.
Pull Request resolved: pytorch#141506
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: pytorch#141500, pytorch#140736, pytorch#140161, pytorch#141392
pytorchmergebot pushed a commit that referenced this pull request Nov 27, 2024
…141604)

Old logic was completely wrong, returning `chunk_size` chunks instead of the intended number. The original test didn't catch this because `chunk_size == num_chunks` :p New OpInfo-based testing covers it though.
Pull Request resolved: #141604
Approved by: https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392, #141506
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…rch#140736)

Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: pytorch#140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ns (pytorch#140736)"

This reverts commit af70f5e.

Reverted pytorch#140736 on behalf of https://github.com/huydhn due to Sorry for reverting your change but its tests are failing in trunk ([comment](pytorch#140736 (comment)))
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…rch#140736)

Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: pytorch#140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: pytorch#140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: pytorch#140736
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after pytorch#137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: pytorch#141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: pytorch#140736, pytorch#140161
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This reverts commit 48409a5.

Reverted pytorch#141392 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](pytorch#140736 (comment)))
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This reverts commit 730caf0.

Reverted pytorch#140161 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](pytorch#140736 (comment)))
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ns (pytorch#140736)"

This reverts commit daaecb9.

Reverted pytorch#140736 on behalf of https://github.com/malfet due to Take 2, of stack revert your change but its tests are failing in trunk ([comment](pytorch#140736 (comment)))
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](https://github.com/pytorch/pytorch/blob/78491d6afc163d1d84e81c015fad695caa8ec98a/tools/autograd/derivatives.yaml#L432-L434)). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of pytorch#140736 due to softshrink's derivative formula).

**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
    * Uses limited `broadcast_tensors()` / `broadcast_to()` support
    * Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to uses `scalar_tensor()` instead of e.g. `zeros({})`

**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This is an extensive change and includes `torch.where()` logic, a bunch of derivative formulas, and likely other places not yet discovered.
Pull Request resolved: pytorch#141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…rch#140736)

Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: pytorch#140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: pytorch#141500
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: pytorch#140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: pytorch#141500, pytorch#140736
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after pytorch#137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: pytorch#141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: pytorch#141500, pytorch#140736, pytorch#140161
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
)

This fixes some bugs when performing reductions / select() on dims before the ragged dim. In this case, the output NJT has a smaller number of dims, and its ragged_idx should reflect that correctly.
Pull Request resolved: pytorch#141506
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: pytorch#141500, pytorch#140736, pytorch#140161, pytorch#141392
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ytorch#141604)

Old logic was completely wrong, returning `chunk_size` chunks instead of the intended number. The original test didn't catch this because `chunk_size == num_chunks` :p New OpInfo-based testing covers it though.
Pull Request resolved: pytorch#141604
Approved by: https://github.com/soulitzer
ghstack dependencies: pytorch#141500, pytorch#140736, pytorch#140161, pytorch#141392, pytorch#141506
pytorchmergebot pushed a commit that referenced this pull request Dec 6, 2024
#140736 fixed some xfails, but these were not properly failing in CI due to #142157. This PR removes the xfails so we can land a fix to that issue.
Pull Request resolved: #142243
Approved by: https://github.com/huydhn
pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
#140736 fixed some xfails, but these were not properly failing in CI due to #142157. This PR removes the xfails so we can land a fix to that issue.
Pull Request resolved: #142243
Approved by: https://github.com/huydhn
@github-actions github-actions bot deleted the gh/jbschlosser/201/head branch December 27, 2024 02:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: nested tensor Changes that have a direct impact on nested tensors Reverted topic: improvements topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants