Fix sum() forward for NJT by jbschlosser · Pull Request #131945 · pytorch/pytorch · GitHub

@jbschlosser jbschlosser commented Jul 26, 2024

Stack from ghstack (oldest at bottom):

This PR solves two problems with sum() support in NJT:

  • `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug dating back to "[NT] Backward support for broadcasting binary ops" (#112519).
  • Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only; backward still fails).
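
To make the two cases above concrete, here is a minimal plain-Python sketch of the intended shape semantics. It is not the PR's implementation and does not use PyTorch; the helper names `sum_last_dim` and `sum_all` and the sample data `nt` are hypothetical, chosen only for illustration. The jagged tensor is modeled as a list of components with a variable number of rows each.

```python
# Plain-Python stand-in for a jagged tensor of logical shape (B, j, D):
# a list of B components, each holding a variable number of rows of length D.
# This only models shapes and values; it is not PyTorch code.

def sum_last_dim(components, keepdim=False):
    """Sum over the trailing dense dim D, optionally keeping it as size 1."""
    if keepdim:
        return [[[sum(row)] for row in comp] for comp in components]
    return [[sum(row) for row in comp] for comp in components]


def sum_all(components):
    """Full reduction: collapse every element into a single scalar."""
    return sum(x for comp in components for row in comp for x in row)


nt = [
    [[1, 2], [3, 4]],  # component 0: 2 rows
    [[5, 6]],          # component 1: 1 row
]

kept = sum_last_dim(nt, keepdim=True)  # logical shape (B, j, 1)
dropped = sum_last_dim(nt)             # logical shape (B, j)
total = sum_all(nt)                    # scalar
```

With `keepdim=True` the reduced dim should survive as size 1 in its original position, so the ragged structure of the other dims is unchanged; the full reduction returns a single value regardless of raggedness.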

cc @cpuhrsch @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

[ghstack-poisoned]

pytorch-bot bot commented Jul 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131945

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b9ae9c5 with merge base b6d6aa4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jbschlosser added a commit that referenced this pull request Jul 26, 2024
ghstack-source-id: 6496092
Pull Request resolved: #131945

@davidberard98 davidberard98 left a comment

lgtm!

@jananisriram jananisriram left a comment

looks great!! thanks for adding this support :)

jbschlosser added a commit that referenced this pull request Jul 30, 2024
ghstack-source-id: c8ee29d
Pull Request resolved: #131945
jbschlosser added a commit that referenced this pull request Jul 31, 2024
ghstack-source-id: cfdd3f3
Pull Request resolved: #131945

jbschlosser commented Aug 6, 2024

I haven't landed this yet because it uncovered `torch.compile` limitations with calculating the max seqlen metadata on the fly, a calculation that can be required for any op that uses the jagged -> padded, compute, padded -> jagged fallback approach.
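
As a rough illustration of the fallback pattern mentioned above, the following plain-Python sketch pads a jagged buffer, applies a computation, and converts back. It is an assumption-laden model, not PyTorch's implementation: the `values`/`offsets` representation and the helper names (`max_seqlen`, `jagged_to_padded`, `padded_to_jagged`, `fallback_apply`) are hypothetical. The point it shows is that the padded width depends on the max sequence length, which has to be derived from the offsets data when it is not already cached.

```python
def max_seqlen(offsets):
    """Longest sequence length, derived from consecutive offset differences."""
    return max(end - start for start, end in zip(offsets, offsets[1:]))


def jagged_to_padded(values, offsets, pad_value=0.0):
    """Expand a flat jagged buffer into equal-width rows, padding the tail."""
    width = max_seqlen(offsets)  # needed before the padded buffer can be sized
    padded = []
    for start, end in zip(offsets, offsets[1:]):
        row = values[start:end]
        padded.append(row + [pad_value] * (width - len(row)))
    return padded


def padded_to_jagged(padded, offsets):
    """Drop the padding and flatten back into a jagged buffer."""
    values = []
    for row, (start, end) in zip(padded, zip(offsets, offsets[1:])):
        values.extend(row[: end - start])
    return values


def fallback_apply(values, offsets, fn):
    """jagged -> padded, elementwise computation, padded -> jagged."""
    padded = jagged_to_padded(values, offsets)
    result = [[fn(x) for x in row] for row in padded]
    return padded_to_jagged(result, offsets)


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
offsets = [0, 2, 3, 6]  # three sequences of lengths 2, 1, 3
doubled = fallback_apply(values, offsets, lambda x: 2 * x)
```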

jbschlosser added a commit that referenced this pull request Aug 8, 2024
ghstack-source-id: 654a73a
Pull Request resolved: #131945
jbschlosser added a commit that referenced this pull request Sep 13, 2024
ghstack-source-id: 1ade418
Pull Request resolved: #131945
@jbschlosser jbschlosser added module: nestedtensor NestedTensor tag see issue #25032 topic: bug fixes topic category release notes: nested tensor Changes that have a direct impact on nested tensors labels Sep 13, 2024
@jbschlosser

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 13, 2024
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Pull Request resolved: pytorch#131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
@github-actions github-actions bot deleted the gh/jbschlosser/168/head branch October 14, 2024 06:25