[NT] Backward support for broadcasting binary ops by soulitzer · Pull Request #112519 · pytorch/pytorch · GitHub

Conversation

soulitzer
Contributor

@soulitzer soulitzer commented Oct 31, 2023

@soulitzer soulitzer requested a review from albanD as a code owner October 31, 2023 17:32
@pytorch-bot

pytorch-bot bot commented Oct 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112519

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 08d2e23 with merge base 1855153:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@jbschlosser jbschlosser left a comment


Thanks for the PR! Got some minor comments and questions.

Comment on lines 578 to 579
# sum_dim_IntList can produce a NT or a T depending on whether the ragged dim
# is summed over
Contributor

Hm, I'm a little wary of this, although it does make sense. I believe this is the first op operating on NTs that can produce a dense T. In general, we should have consistent semantics for ops that reduce out the raggedness, whichever way we decide. @cpuhrsch any opinion on whether conditionally returning a dense T here is the way to go?

If we return a T here and the user wants to go back to NT land, a non-copying as_nested_tensor(t) seems useful.

Contributor Author

I think returning a plain T has more known benefits, and it's not clear that returning NT today would give us more flexibility wrt BC either. If we return NT today and make users grab values explicitly, that would no longer work if we choose to return T later.

Pros:

  • avoid having additional state on NT to track whether we are a dense NT or not
  • less additional logic within autograd to handle the conversion from dense NT to dense
  • users would not have to worry about manually converting the dense NT back to dense to avoid subclass overhead

Cons:

  • some ops on NT unexpectedly(?) return non-NTs. I'm not actually sure how problematic that is. On the NT front, as long as the dense T can freely promote back to NT, doing any NT-related ops with it should not be problematic.

Contributor Author

cc @albanD

Contributor Author

From offline discussion @albanD thinks this seems fine as long as there are no silent correctness issues.

Contributor

@jbschlosser jbschlosser Nov 6, 2023

Cool, I'm good with it then :) I'm finishing up an impl of as_nested_tensor(t), so that should help as well.
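For illustration, here is a rough sketch of the semantics being settled in this thread. The jagged-layout constructor and sum-over-dim support are assumptions about post-merge behavior rather than a quote of the PR's code, and exact support may depend on the PyTorch version:

```python
import torch

# Two variable-length sequences batched as a jagged NT of logical shape
# (B=2, j0, 4), where j0 is the ragged dim.
nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)

# Summing over the ragged dim reduces the raggedness away, so a plain dense
# tensor of shape (2, 4) is the natural result discussed above.
dense = nt.sum(dim=1)

# Summing over a non-ragged dim keeps the raggedness, so an NT comes back.
still_ragged = nt.sum(dim=2)
```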

Comment on lines 102 to 109
t_shape = t.shape
extra = 0
for s in t_shape:
    if s == 1:
        if t.dim() > nt.dim():
            extra += 1
        t = t.squeeze(0)
return t, extra
Contributor

does this all work within torch.compile without unwanted guards?

Contributor Author

Hm let me check

Contributor Author

I do think we do want the extra guards here.

Contributor

hm okay, which guards show up?

Contributor Author

@soulitzer soulitzer Nov 2, 2023

I do think we do want the extra guards here.

Whoops, this is actually wrong. What is happening is that there ARE always guards testing whether the given inputs are zero/one (due to zero-one specialization), which is indeed what we want, and as a result of that there are also no "extra" guards from this s == 1 test here.

So there are actually two cases: (1) the case where the inputs are also the inputs to the entire program, and (2) the case where the inputs are not.

In (1), what I wrote about zero-one specialization holds. But otherwise, the "one" would've needed to be created somewhere, and likely not as a symint, so no extra guards either way.
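For context, a minimal sketch of the zero-one specialization behavior being described (a generic torch.compile illustration, not the PR's code):

```python
import torch

@torch.compile(dynamic=True)
def maybe_squeeze(t):
    # Sizes of 0 and 1 are specialized by the compiler regardless, so this
    # comparison on a graph-input size adds no guards beyond the existing ones.
    if t.shape[0] == 1:
        return t.squeeze(0)
    return t

maybe_squeeze(torch.randn(1, 4))  # size-1 leading dim: specialized case
maybe_squeeze(torch.randn(8, 4))  # sizes > 1: handled with dynamic shapes
```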

# ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
nt, t = (a, b) if a_is_nt else (b, a)
# See Note: [ Squeezing leading ones ]
t_squeezed, extra = squeeze_leading_ones_get_extra(t, nt)
Contributor

dumb Q: can you explain a bit more the purpose behind this change? I'm unclear as to what it's solving in the context of this PR

Contributor Author

@soulitzer soulitzer Nov 1, 2023

Previously we were too relaxed when it came to checking whether the pre-existing broadcasting logic is valid for the given NTs. E.g. someone could try JT: [B, *, D] + T: [sum(*), D] and still go through the easy path and produce something as output, even though broadcasting didn't really make sense in the first place.

What we kind of want to do is make sure that the JT has a dim that is two greater than that of the other tensor, so that we would have something like (B, j0, a0, ..., an) + (b0, ..., bm) where m < n, where B and j0 are always broadcast over uniformly (as in, any value in a given batch acquires the same values) during the broadcasting logic. This is important because otherwise we would fall into the unbind case below.

We want this squeeze_leading_ones_get_extra helper because naively checking nt.dim() > t.dim() + 2 doesn't quite work: t might have leading ones in its shape that shouldn't disqualify it from the easy case.
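For illustration, a rough sketch of that check (hypothetical function name and an approximate condition, not the PR's exact code):

```python
def can_take_easy_path(nt, t):
    # Strip leading size-1 dims from the dense tensor; they broadcast
    # trivially and shouldn't disqualify the easy case.
    while t.dim() > 0 and t.shape[0] == 1:
        t = t.squeeze(0)
    # With the leading ones gone, the NT should have at least two more dims
    # than t, so that both the batch dim B and the ragged dim j0 are
    # broadcast over uniformly: (B, j0, a0, ..., an) + (b0, ..., bm).
    # Otherwise we fall back to the unbind-based path.
    return nt.dim() >= t.dim() + 2
```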

Contributor

cool, thanks for the explanation :)

Contributor

Note that if we see t.dim() > nt.dim(), we know we're in an unsupported case, even if t is all ones. If we bail out early for this, it might be a little more readable.

soulitzer added a commit that referenced this pull request Nov 3, 2023
ghstack-source-id: 487d5e4
Pull Request resolved: #112519
soulitzer added a commit that referenced this pull request Nov 6, 2023
ghstack-source-id: 9391c09
Pull Request resolved: #112519
@soulitzer
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / win-vs2019-cuda11.8-py3 / build

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
pytorchmergebot pushed a commit that referenced this pull request Nov 7, 2023
…ded_tensor msg (#113162)

Improvements: improves to_padded_tensor error message when passed a NT with zero numel

Pull Request resolved: #113162
Approved by: https://github.com/jbschlosser
ghstack dependencies: #113031, #112519, #113091
@facebook-github-bot facebook-github-bot deleted the gh/soulitzer/247/head branch November 10, 2023 15:24
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
jbschlosser added a commit that referenced this pull request Jul 30, 2024
pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2024
This PR solves two problems with `sum()` support in NJT:
* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519.
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails).

Pull Request resolved: #131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
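For reference, the expected keepdim semantics that commit message describes, sketched under the assumption that jagged-layout sum over a non-ragged dim is supported in the installed PyTorch version:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)  # logical shape (B=2, j0, 4)

# keepdim=True should keep the reduced dim as size 1, i.e. shape (2, j0, 1);
# the long-standing bug referenced above kept a different dim instead.
out = nt.sum(dim=2, keepdim=True)
```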
tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: nested tensor (Changes that have a direct impact on nested tensors)
