Initial vmap + NT support with unbind fallback #106786
Conversation
PoC demonstrating vmap + NT based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:

* Allows `BatchedTensorImpl`s to contain NTs
* Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
* Provides a batching rule fallback that unbinds the NTs -> performs the computation on each constituent -> rebinds the results into an NT
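The semantics of the unbind-based fallback can be sketched in a few lines of Python. This is a hypothetical illustration only, not the PR's actual kernel, and `unbind_fallback` is a made-up name:

```python
import torch

# Sketch of the fallback's semantics: unbind the NT into its constituent
# tensors, run the op on each one, then re-nest the results into a new NT.
def unbind_fallback(op, nt, *args, **kwargs):
    results = [op(t, *args, **kwargs) for t in nt.unbind()]
    return torch.nested.as_nested_tensor(results)

# A ragged "batch" of two 2D tensors with differing first dims
nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
out = unbind_fallback(torch.nn.functional.relu, nt)
```

The output is again a nested tensor whose constituents keep their original ragged shapes.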
Is this enough to make SAM work?
This looks pretty good. Next steps sound like we should try to get a batching rule working?
If we want to turn this PoC into a non-PoC, then at some point we should:
- figure out what limitations we're putting on vmap + nestedtensor, and then raise error messages when we go beyond them. E.g. maybe we only support a single vmap on a nestedtensor that has a single batch dimension.
- figure out if we're banning size/strides/etc. correctly on BatchedTensor(NestedTensor)

Almost, wrapping this up. Edit: no (internal link)
PoC demonstrating vmap + NT based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:

* Allows `BatchedTensorImpl`s to contain NTs
* Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
* Provides a batching rule fallback that unbinds the NTs -> performs the computation on each constituent -> rebinds the results into an NT

Restrictions:

* Only supports one level of vmap
* Only supports vmapping over dim=0 for NTs
* For operations with mixed NT / dense inputs, support is also limited to dim=0 for the dense inputs
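The mixed NT / dense restriction can be illustrated with a small sketch, assuming the fallback pairs constituent `i` of the NT with slice `i` of the dense input along dim 0 (`mixed_fallback` is a hypothetical name for illustration, not the PR's implementation):

```python
import torch

# Hypothetical sketch: for mixed NT / dense inputs, constituent i of the NT
# is paired with slice i of the dense input taken along dim 0, which is why
# only dim=0 batching is supported for the dense side as well.
def mixed_fallback(op, nt, dense):
    results = [op(t, d) for t, d in zip(nt.unbind(), dense.unbind(0))]
    return torch.nested.as_nested_tensor(results)

nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])
bias = torch.randn(2, 4)  # dense input, batched along dim 0 like the NT
out = mixed_fallback(torch.add, nt, bias)
```

Each ragged constituent broadcasts against its matching dense slice, so the output keeps the NT's ragged shapes.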
@zou3519 I think there's just one thing outstanding wrt error messages: #106786 (comment)
AFAICT these are the same. After my changes, the error message for this example rightfully mentions 2D bounds:

```python
def f(x):
    x.size(5)
    return x

x = torch.randn(2, 3, 4)
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 5)
output = vmap(f)(x)
```
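The reason the bounds are 2D: vmap hides the batch dim, so each example of a `(2, 3, 4)` input is seen inside `f` as a plain `(3, 4)` tensor, making `[-2, 1]` the valid dim range. A small sanity-check sketch:

```python
import torch

def f(x):
    # vmap hides the batch dim, so x is logically 2D here
    assert x.dim() == 2
    return x.sum(dim=-1)

x = torch.randn(2, 3, 4)
out = torch.vmap(f)(x)  # per-example (3,) results stacked back to (2, 3)
```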
Thanks for checking. I think we're good here then; let me give this a quick re-read.
LGTM, thank you!
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge -f "ignore spurious failure"
Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Merge failed. Reason: Command. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).