KEMBAR78
Add utility to get computed kernel in torch.library by mikaylagawarecki · Pull Request #158393 · pytorch/pytorch · GitHub
Skip to content

Conversation

@mikaylagawarecki
Copy link
Contributor

@mikaylagawarecki mikaylagawarecki commented Jul 15, 2025

Adds OperatorEntry::getComputedKernelForDispatchKey which returns the KernelFunction corresponding to OperatorEntry.dispatchTable_[dispatch_ix] for a given dispatch key

  • Specifically it returns a SafeKernelFunction that holds a KernelToken. This KernelToken is registered to the KernelFunction in OperatorEntry.kernels_ and will be invalidated when the KernelFunction is destructed (i.e. when the AnnotatedKernel that holds this KernelFunction is removed from kernels_, which happens when the corresponding impl is deregistered).
  • SafeKernelFunction can be called via callBoxed, the validity of the token will be checked before this happens
  • SafeKernelFunction is pybinded and getComputedKernelForDispatchKey is exposed to the frontend ia torch.library.get_kernel

Related to #155330

Stack from ghstack (oldest at bottom):

cc @albanD

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158393

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 172857e with merge base 34ec5ed (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

mikaylagawarecki added a commit that referenced this pull request Jul 15, 2025
ghstack-source-id: c7793b0
Pull Request resolved: #158393
@mikaylagawarecki mikaylagawarecki marked this pull request as draft July 16, 2025 14:44
mikaylagawarecki added a commit that referenced this pull request Jul 16, 2025
ghstack-source-id: 439450f
Pull Request resolved: #158393
mikaylagawarecki added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: e876743
Pull Request resolved: #158393
mikaylagawarecki added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: 762f9b6
Pull Request resolved: #158393
@mikaylagawarecki mikaylagawarecki added module: python frontend For issues relating to PyTorch's Python frontend topic: new features topic category labels Jul 24, 2025
@mikaylagawarecki mikaylagawarecki requested a review from albanD July 24, 2025 21:35
@mikaylagawarecki mikaylagawarecki added release notes: python_frontend python frontend release notes category and removed module: python frontend For issues relating to PyTorch's Python frontend labels Jul 24, 2025
@mikaylagawarecki mikaylagawarecki marked this pull request as ready for review July 24, 2025 21:37
mikaylagawarecki added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: 5925282
Pull Request resolved: #158393
@mikaylagawarecki mikaylagawarecki requested a review from zou3519 July 25, 2025 14:48
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds pretty good!
Only small questions!

torch/library.py Outdated
op = op._name

if isinstance(dispatch_key, str):
dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the error you get when passing a wrong dispatch key here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, added proper error handling here

auto [annotatedKernel, _] =
computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);

return SafeKernelFunction(&annotatedKernel.kernel);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be nice to grab the debug string here and add that to the __repr__ we get from python?

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gives something like

SafeKernelFunction(debug='registered at /data/users/mg1998/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp:2309')

Do you think that is meaningful enough or should we add more info

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is 100% super useful. These error messages saved me a few times for multiple-registration errors!
And for python users, it should point to their code directly. Which is even better so they know which function this is!


// List of tokens that need to be invalidated when this KernelFunction is
// destroyed
mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why mutable?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also why weak_ptr and not shared_ptr?

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why mutable

I think this is necessary in order to make registerToken const, which was in turn needed to allow SafeKernelFunction to take in const KernelFunction*, removing this would necessitate const_cast-ing the annotatedKernel.kernel in getComputedKernelForDispatchKey, wdyt

why weak_ptr and not shared_ptr

What's the benefit of shared_ptr over weak_ptr here? if we use weak_ptr, the KernelToken dies with the SafeKernelFunction, which I think is what we want to achieve here (?)

Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to #155330




cc albanD

[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: 1c15180
Pull Request resolved: #158393
@mikaylagawarecki mikaylagawarecki requested a review from albanD August 5, 2025 19:32
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to #155330




cc albanD

[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: ec2351a
Pull Request resolved: #158393
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to #155330




cc albanD

[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: 72c37fb
Pull Request resolved: #158393
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to #155330




cc albanD

[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 8, 2025
ghstack-source-id: a4e17da
Pull Request resolved: #158393
@mikaylagawarecki mikaylagawarecki added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 12, 2025
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!
Thanks!

@mikaylagawarecki
Copy link
Contributor Author

@pytorchbot merge -r

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Successfully rebased gh/mikaylagawarecki/320/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/158393)

pytorchmergebot pushed a commit that referenced this pull request Aug 13, 2025
ghstack-source-id: 8f833e2
Pull Request resolved: #158393
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@mikaylagawarecki
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to #155330

Pull Request resolved: #158393
Approved by: https://github.com/albanD
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to #155330

Pull Request resolved: #158393
Approved by: https://github.com/albanD
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to pytorch#155330

Pull Request resolved: pytorch#158393
Approved by: https://github.com/albanD
@github-actions github-actions bot deleted the gh/mikaylagawarecki/320/head branch September 13, 2025 02:06
pragupta added a commit to ROCm/pytorch that referenced this pull request Sep 16, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key
- Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel`

Related to pytorch#155330

Pull Request resolved: pytorch#158393
Approved by: https://github.com/albanD
pragupta added a commit to ROCm/pytorch that referenced this pull request Sep 17, 2025
pragupta added a commit to ROCm/pytorch that referenced this pull request Sep 24, 2025
jataylo added a commit to jataylo/pytorch that referenced this pull request Oct 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: python_frontend python frontend release notes category topic: new features topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants