Initial vmap + NT support with unbind fallback by jbschlosser · Pull Request #106786 · pytorch/pytorch · GitHub

Conversation

@jbschlosser
Contributor

@jbschlosser jbschlosser commented Aug 8, 2023

Stack from ghstack (oldest at bottom):

PoC demonstrating vmap + NT based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:

  • Allows `BatchedTensorImpl`s to contain NTs
  • Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
  • Provides a batching rule fallback that unbinds the NTs -> performs the computation on the constituents -> rebinds the results into an NT (a minimal sketch follows this description)

Restrictions:

  • Only supports one level of vmap
  • Only supports vmapping over dim=0 for NTs
    • For operations with mixed NT / dense inputs, support is also limited to dim=0 for the dense inputs

[ghstack-poisoned]
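To make the fallback concrete, here is a minimal Python-level sketch of the unbind -> compute -> rebind strategy described above. This is not the PR's implementation (which lives in the C++ batching machinery); `unbind_fallback` is a hypothetical name used only for illustration:

import torch
from torch.func import vmap

# Hypothetical illustration of the fallback: unbind the NT into its
# constituents, run the op on each one, then rebind the results into a new NT.
def unbind_fallback(op, nt, *args, **kwargs):
    results = [op(t, *args, **kwargs) for t in nt.unbind()]
    return torch.nested.nested_tensor(results)

nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
out = unbind_fallback(torch.nn.functional.relu, nt)

# With this PR applied, vmapping over dim=0 of an NT is intended to behave
# equivalently (subject to the restrictions listed above):
out_vmap = vmap(torch.nn.functional.relu)(nt)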
@pytorch-bot

pytorch-bot bot commented Aug 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106786

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures, 1 Unrelated Failure

As of commit c53b848 with merge base e58d3ed:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jbschlosser added a commit that referenced this pull request Aug 8, 2023
ghstack-source-id: a18e4b6
Pull Request resolved: #106786
@ezyang
Contributor

ezyang commented Aug 8, 2023

Is this enough to make SAM work?

@ezyang ezyang requested a review from zou3519 August 8, 2023 15:52
Contributor

@zou3519 zou3519 left a comment


This looks pretty good. Next steps sound like we should try to get a batching rule working?

If we want to turn this PoC into more than a PoC, then at some point we should:

  • figure out what limitations we're putting on vmap + nestedtensor and then raise error messages when we go beyond them. E.g. maybe we only support a single vmap on a nestedtensor that has a single batch dimension (see the sketch below).
  • figure out if we're banning size/strides/etc. correctly on BatchedTensor(NestedTensor)
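A hedged sketch of the kind of restriction check being suggested (the function name and signature below are made up for illustration; any real checks would live in the PR's C++ batching logic):

import torch

# Hypothetical guard, not part of this PR: reject cases outside the
# supported vmap + nested tensor envelope.
def check_nt_vmap_supported(t: torch.Tensor, in_dim: int, vmap_level: int) -> None:
    if not t.is_nested:
        return
    if vmap_level > 1:
        raise RuntimeError("nested tensors currently support only a single level of vmap")
    if in_dim != 0:
        raise RuntimeError("vmap over nested tensors is only supported for in_dims=0")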

@jbschlosser
Contributor Author

jbschlosser commented Aug 8, 2023

Is this enough to make SAM work?

Almost, wrapping this up

Edit: no (internal link)

@zou3519 zou3519 requested a review from kshitij12345 August 10, 2023 15:05
@jbschlosser
Contributor Author

@zou3519 I think there's just one thing outstanding wrt error messages: #106786 (comment)

the error messages after we change to sizes_custom/strides_custom for the non-nestedtensor case. Ideally these would be the same before/after this PR

AFAICT these are the same. After my changes, the error message for this example rightfully mentions 2D bounds (inside vmap, the batch dim is hidden, so the (2, 3, 4) input presents as 2D):

import torch
from torch.func import vmap

def f(x):
    x.size(5)
    return x

x = torch.randn(2, 3, 4)

# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 5)
output = vmap(f)(x)
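For comparison (not part of the PR, just regular eager behavior): the same call outside of vmap, on the full 3D tensor, reports 3D bounds, since no batch dim is hidden:

import torch

x = torch.randn(2, 3, 4)

# IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 5)
x.size(5)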

@zou3519
Contributor

zou3519 commented Aug 30, 2023

AFAICT these are the same. After my changes, the error message for this example rightfully mentions 2D bounds:

Thanks for checking. I think we're good here then, let me give this a quick re-read

Contributor

@zou3519 zou3519 left a comment


LGTM, thank you!

@jbschlosser
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 31, 2023
@jbschlosser jbschlosser added topic: not user facing topic category release notes: nested tensor Changes that have a direct impact on nested tensors labels Aug 31, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@jbschlosser
Contributor Author

@pytorchbot merge -f "ignore spurious failure"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x b9c5dfbea029058de432735fa801da06c081ca41 returned non-zero exit code 1

Auto-merging aten/src/ATen/functorch/BatchedTensorImpl.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/functorch/BatchedTensorImpl.cpp
Auto-merging aten/src/ATen/functorch/BatchedTensorImpl.h
Auto-merging test/functorch/test_vmap.py
Auto-merging torch/csrc/functorch/init.cpp
Auto-merging torchgen/model.py
error: could not apply b9c5dfbea02... Initial vmap + NT support with unbind fallback
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
Details for Dev Infra team (raised by workflow job)

@jbschlosser
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/jbschlosser/87/head branch September 10, 2023 14:22