[NT] Backward support for broadcasting binary ops #112519

Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112519
Note: Links to docs will display an error until the doc builds have completed.
✅ No Failures as of commit 08d2e23 with merge base 1855153. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Thanks for the PR! Got some minor comments and questions.
torch/nested/_internal/ops.py (outdated)

    # sum_dim_IntList can produce a NT or a T depending on whether the ragged dim
    # is summed over
Hm, I'm a little wary of this, although it does make sense. I believe this is the first op operating on NTs that can produce a dense T. In general, we should have consistent semantics for ops that reduce out the raggedness, whichever way we decide. @cpuhrsch any opinion on whether conditionally returning a dense T here is the way to go?
If we return a T here and the user wants to go back to NT land, a non-copying as_nested_tensor(t) seems useful.
I think returning a plain T has more known benefits, and it's not clear that returning NT today would give us more flexibility with respect to BC either. If we return NT today and make users grab the values explicitly, that would no longer work if we choose to return T later.
Pros:
- avoids having additional state on NT to track whether we are a dense NT or not
- less additional logic within autograd to handle the conversion from dense NT to dense
- users would not have to manually convert the dense NT back to dense to avoid subclass overhead
Cons:
- some ops on NT unexpectedly(?) return non-NTs. I'm not sure how problematic that actually is; on the NT front, as long as a dense T can freely promote back to NT, doing any NT-related ops with this dense T should not be problematic (see the sketch below).
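A minimal sketch of the semantics under discussion, assuming the jagged-layout NT constructor (`torch.nested.nested_tensor(..., layout=torch.jagged)`); exact op coverage may differ across versions:

```python
import torch

# Two variable-length sequences with a common feature dim of 5.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)],
    layout=torch.jagged,
)  # logical shape: (B=2, j0, 5)

# Summing over the ragged dim reduces the raggedness away, so under the
# proposal discussed here the result is a plain dense tensor of shape (2, 5).
dense = nt.sum(dim=1)
print(type(dense), dense.shape)

# Summing over a non-ragged dim keeps the raggedness, so an NT comes back.
still_nt = nt.sum(dim=2)
print(type(still_nt))
```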
cc @albanD
From offline discussion @albanD thinks this seems fine as long as there are no silent correctness issues.
Cool, I'm good with it then :) I'm finishing up an impl of as_nested_tensor(t), so that should help as well.
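For the "go back to NT land" direction, a small sketch assuming the `torch.nested.as_nested_tensor` API referenced above (with the jagged layout, this conversion is intended to be a non-copying view, per the discussion):

```python
import torch

# A dense tensor where dim 0 plays the role of the batch dim.
t = torch.randn(4, 3, 5)

# View it as a nested tensor again (uniform lengths; no copy is expected
# for the jagged layout, per the discussion above).
nt = torch.nested.as_nested_tensor(t, layout=torch.jagged)
print(nt.is_nested, nt.size(0))  # True 4
```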
torch/nested/_internal/ops.py (outdated)

    t_shape = t.shape
    extra = 0
    for s in t_shape:
        if s == 1:
            if t.dim() > nt.dim():
                extra += 1
            t = t.squeeze(0)
    return t, extra
does this all work within torch.compile without unwanted guards?
Hm let me check
I do think we do want the extra guards here.
hm okay, which guards show up?
> I do think we do want the extra guards here.

Whoops, this is actually wrong. What is happening is that there ARE always guards testing whether the given inputs are zero/one (due to zero-one specialization), which is indeed what we want; and as a result, there are also no "extra" guards from this s == 1 test here.
So there are actually two cases: (1) the case where these inputs are also inputs to the entire program, and (2) the case where they are not.
In (1), what I wrote about zero-one specialization holds. Otherwise, the "one" would have needed to be created somewhere, and likely not as a symint, so there are no extra guards either way.
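As an illustration of the zero/one specialization being described (a hedged sketch, not the PR's code): under `torch.compile` with dynamic shapes, sizes 0 and 1 are specialized, so a `size == 1` test is resolved at compile time per size class (at the cost of a recompile when the size class changes) rather than adding a new symbolic guard.

```python
import torch

def maybe_squeeze_leading_one(x):
    # The `== 1` check mirrors the `s == 1` test discussed above; with
    # 0/1 specialization it is decided at compile time per size class.
    if x.shape[0] == 1:
        return x.squeeze(0)
    return x

compiled = torch.compile(maybe_squeeze_leading_one, dynamic=True)
print(compiled(torch.randn(1, 4)).shape)  # torch.Size([4])
print(compiled(torch.randn(3, 4)).shape)  # torch.Size([3, 4])
```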
torch/nested/_internal/ops.py (outdated)

    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    nt, t = (a, b) if a_is_nt else (b, a)
    # See Note: [ Squeezing leading ones ]
    t_squeezed, extra = squeeze_leading_ones_get_extra(t, nt)
dumb Q: can you explain a bit more the purpose behind this change? I'm unclear as to what it's solving in the context of this PR
Previously we were too relaxed when checking whether the pre-existing broadcasting logic is valid for the given NTs. E.g. someone could try JT: [B, *, D] + T: [sum(*), D] and still go through the easy path and produce some output, even though broadcasting didn't really make sense in the first place.
What we want to do is make sure that the JT has a dim that is two greater than that of the other tensor, so that we have something like (B, j0, a0, ..., an) + (b0, ..., bm) where m < n, and B and j0 are always broadcast over uniformly (as in, every value in a given batch acquires the same values) during broadcasting. This is important because otherwise we would fall into the unbind case below.
We want this squeeze_leading_ones_get_extra helper because naively checking nt.dim() > t.dim() + 2 doesn't quite work: t might have leading ones in its shape that shouldn't disqualify it from the easy case.
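To make the shape rule concrete, here is a hypothetical standalone mirror of the check (plain Python over shape tuples; an illustration only, not the PR's actual `squeeze_leading_ones_get_extra` helper):

```python
def qualifies_for_easy_broadcast(nt_shape, t_shape):
    # Drop leading ones on the dense operand; they broadcast trivially and
    # should not disqualify the easy path.
    while len(t_shape) > 0 and t_shape[0] == 1:
        t_shape = t_shape[1:]
    # Easy path: the NJT must have at least two more dims than the squeezed
    # dense operand, so that (B, j0) are only ever broadcast over uniformly.
    return len(nt_shape) >= len(t_shape) + 2

# (B, j0, D) + (1, 1, D): leading ones squeeze away -> easy path
print(qualifies_for_easy_broadcast(("B", "j0", "D"), (1, 1, 8)))  # True
# (B, j0, D) + (sum(j0), D): would need per-batch handling -> not the easy path
print(qualifies_for_easy_broadcast(("B", "j0", "D"), (12, 8)))    # False
```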
cool, thanks for the explanation :)
Note that if we see t.dim() > nt.dim(), we know we're in an unsupported case, even if t is all ones. If we bail out early for this, it might be a little more readable.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: trunk / win-vs2019-cuda11.8-py3 / build. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#112519 Approved by: https://github.com/jbschlosser ghstack dependencies: pytorch#113031
Fixes #112845 Pull Request resolved: #113091 Approved by: https://github.com/jbschlosser ghstack dependencies: #113031, #112519
…ded_tensor msg (pytorch#113162) Improvements: improves to_padded_tensor error message when passed a NT with zero numel Pull Request resolved: pytorch#113162 Approved by: https://github.com/jbschlosser ghstack dependencies: pytorch#113031, pytorch#112519, pytorch#113091
This PR solves two problems with `sum()` support in NJT:
* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519.
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails).

Pull Request resolved: #131945 Approved by: https://github.com/davidberard98, https://github.com/jananisriram
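A small sketch of the intended `keepdim` behavior described here (assuming the jagged-layout NT constructor; op coverage varies by version):

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)],
    layout=torch.jagged,
)  # logical shape: (B=2, j0, 5)

# Reducing over the last (non-ragged) dim with keepdim=True should keep
# that dim as size 1: (B, j0, 5) -> (B, j0, 1).
out = nt.sum(dim=2, keepdim=True)
print(out.is_nested, out.size(-1))  # True 1
```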