-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Add utility to get computed kernel in torch.library #158393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add utility to get computed kernel in torch.library #158393
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158393
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 172857e with merge base 34ec5ed ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
[ghstack-poisoned]
Related to #155330 [ghstack-poisoned]
Related to #155330 [ghstack-poisoned]
Related to #155330 cc albanD [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds pretty good!
Only small questions!
torch/library.py
Outdated
| op = op._name | ||
|
|
||
| if isinstance(dispatch_key, str): | ||
| dispatch_key = torch._C.DispatchKey.__members__[dispatch_key] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the error you get when passing a wrong dispatch key here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops, added proper error handling here
| auto [annotatedKernel, _] = | ||
| computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k); | ||
|
|
||
| return SafeKernelFunction(&annotatedKernel.kernel); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be nice to grab the debug string here and add that to the __repr__ we get from python?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gives something like
SafeKernelFunction(debug='registered at /data/users/mg1998/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp:2309')
Do you think that is meaningful enough or should we add more info
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is 100% super useful. These error messages saved me a few times for multiple-registration errors!
And for python users, it should point to their code directly. Which is even better so they know which function this is!
|
|
||
| // List of tokens that need to be invalidated when this KernelFunction is | ||
| // destroyed | ||
| mutable std::vector<std::weak_ptr<KernelToken>> tokens_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why mutable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also why weak_ptr and not shared_ptr?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why mutable
I think this is necessary in order to make registerToken const, which was in turn needed to allow SafeKernelFunction to take in const KernelFunction*, removing this would necessitate const_cast-ing the annotatedKernel.kernel in getComputedKernelForDispatchKey, wdyt
why weak_ptr and not shared_ptr
What's the benefit of shared_ptr over weak_ptr here? if we use weak_ptr, the KernelToken dies with the SafeKernelFunction, which I think is what we want to achieve here (?)
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome!
Thanks!
|
@pytorchbot merge -r |
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
|
Successfully rebased |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 Pull Request resolved: #158393 Approved by: https://github.com/albanD
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 Pull Request resolved: #158393 Approved by: https://github.com/albanD
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to pytorch#155330 Pull Request resolved: pytorch#158393 Approved by: https://github.com/albanD
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to pytorch#155330 Pull Request resolved: pytorch#158393 Approved by: https://github.com/albanD
Adds
OperatorEntry::getComputedKernelForDispatchKeywhich returns the KernelFunction corresponding toOperatorEntry.dispatchTable_[dispatch_ix]for a given dispatch keySafeKernelFunctionthat holds aKernelToken. ThisKernelTokenis registered to theKernelFunctioninOperatorEntry.kernels_and will be invalidated when theKernelFunctionis destructed (i.e. when theAnnotatedKernelthat holds thisKernelFunctionis removed fromkernels_, which happens when the corresponding impl is deregistered).SafeKernelFunctioncan be called viacallBoxed, the validity of the token will be checked before this happensSafeKernelFunctionis pybinded andgetComputedKernelForDispatchKeyis exposed to the frontend iatorch.library.get_kernelRelated to #155330
Stack from ghstack (oldest at bottom):
cc @albanD