@chsasank opened triton-lang/triton#150 and I thought it was pretty interesting. It's true that to actually ship this in PyTorch mainline, we would indeed want an offline Triton compiler, and we'd also need to work out all of the distribution shenanigans that come with a project as big as PyTorch.
However, I think it is feasible to experimentally prototype Triton versions of PyTorch kernels out of tree, with only a few extra hooks in PyTorch core. In #62660 I give a prototype for how to directly register Python implementations of kernels, allowing us to override preexisting CUDA implementations. So here's how it would work:
- You write a Triton kernel to replace some kernel in PyTorch
- You call the Python operator registration API, and override the existing CUDA kernel with your new Triton kernel
- Profit!
There's a little example of how to override existing kernels in https://github.com/pytorch/pytorch/pull/62660/files#diff-415017bcad4fa6cd6d3dfe5f6ea1caffcd7122b46b8c1e4825f7d889efc80a62 (the API needs improvement). To make it work with Triton, you would call into Triton the same way the Triton tutorial does; a rough sketch is below. I could probably produce a working end-to-end example, but I'll let whatever intrepid soul embarks on this project have the fun of figuring that out.
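To make the shape of the thing concrete, here is a minimal sketch under some loud assumptions: it uses the vector-add kernel from the Triton tutorial and the `torch.library` Python API (which is roughly where this line of work eventually landed; the prototype API in #62660 looks different), and `triton_add` is a toy I wrote for illustration that punts on broadcasting, type promotion, and `alpha`. Treat it as a picture of the workflow, not the real implementation.

```python
import torch
import triton
import triton.language as tl
from torch.library import Library


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # The elementwise-add kernel from the Triton "vector add" tutorial.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def triton_add(x, y, alpha=1):
    # Toy replacement for aten::add.Tensor on CUDA. It deliberately
    # ignores broadcasting, type promotion, and alpha != 1; a real
    # override would have to handle all of those.
    assert alpha == 1 and x.shape == y.shape
    x, y = x.contiguous(), y.contiguous()
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


# Override the preexisting CUDA kernel for aten::add.Tensor with the
# Triton one; the registration stays live for the lifetime of `lib`.
lib = Library("aten", "IMPL")
lib.impl("add.Tensor", triton_add, "CUDA")

a = torch.randn(4096, device="cuda")
b = torch.randn(4096, device="cuda")
print(a + b)  # now runs the Triton kernel
```

From here, "profit" is a matter of benchmarking the override against the stock CUDA kernel and iterating on the Triton side.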
If someone is actually interested in setting up the repo to do this and charging ahead, I'll prioritize getting #62660 into PyTorch proper.