Fix sum() forward for NJT #131945
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131945

Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit b9ae9c5 with merge base b6d6aa4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR solves two problems with `sum()` support in NJT:

* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it keeps the wrong dim). This is a long-standing bug from way back in #112519 ([NT] Backward support for broadcasting binary ops).
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only; backward still fails).

[ghstack-poisoned]
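To make the `keepdim` fix concrete, here is a minimal sketch of the intended semantics using plain Python lists in place of a real NJT. The function name `jagged_sum` and the list-of-lists representation are illustrative assumptions, not the PyTorch API: the point is that with `keepdim=True`, the reduced (ragged) dim must stay in its original position.

```python
# Hedged sketch: intended keepdim semantics for summing a jagged
# (NJT-style) structure over its ragged dimension. Plain Python lists
# stand in for torch.nested; jagged_sum is a hypothetical helper.

def jagged_sum(rows, keepdim=False):
    """Sum each variable-length row of a jagged 2D structure.

    rows: list of lists, e.g. [[1, 2, 3], [4, 5]] (dim 1 is ragged).
    With keepdim=True, each row's sum stays wrapped in a length-1 list,
    so the reduced dim is kept in the correct position (dim 1) -- the
    shape behavior the bug fixed by this PR got wrong.
    """
    sums = [sum(r) for r in rows]
    return [[s] for s in sums] if keepdim else sums

print(jagged_sum([[1, 2, 3], [4, 5]]))                # [6, 9]
print(jagged_sum([[1, 2, 3], [4, 5]], keepdim=True))  # [[6], [9]]
```

The real fix operates on the NJT's values/offsets buffers, but the shape contract is the same: reducing over the ragged dim with `keepdim=True` should yield one length-1 entry per sequence, not keep some other dim.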
lgtm!
looks great!! thanks for adding this support :)
The reason I haven't landed this yet is that it uncovered torch.compile limitations with calculating the max seqlen metadata on the fly, a calculation that can be required for any ops that utilize the …
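For context, the max seqlen metadata can be derived from the NJT's offsets buffer; a minimal sketch of that derivation in plain Python (the real code operates on torch tensors, and this data-dependent computation is what trips up torch.compile when it must happen on the fly):

```python
# Hedged sketch: deriving max_seqlen from an NJT-style offsets buffer.
# offsets[i+1] - offsets[i] is the length of sequence i; the max of
# those differences is the max sequence length. Function name is
# illustrative, not the PyTorch internal API.

def max_seqlen_from_offsets(offsets):
    return max(b - a for a, b in zip(offsets, offsets[1:]))

print(max_seqlen_from_offsets([0, 3, 5, 9]))  # lengths 3, 2, 4 -> 4
```

Because the result depends on runtime data rather than static shapes, computing it inside a compiled region requires the compiler to handle a data-dependent value, which is the limitation referenced above.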
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
Stack from ghstack (oldest at bottom):
cc @cpuhrsch @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ