[Inductor] Support user defined triton kernels in inductor #111434
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111434
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4752953 with merge base bf01a7b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/dynamo/test_functions.py (Outdated)

    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @patch.object(torch._inductor.config, "implicit_fallbacks", False)
    def test_triton_kernel_native(self, grad, backend):
        if backend == "inductor" and grad is True:
torch/_inductor/codecache.py (Outdated)

    device_interface.Worker.set_device(device.index)
    kernel = TritonCodeCache.load(kernel_name, source_code)
    kernel.precompile(warm_cache_only_with_cc=cc)
    if hasattr(kernel, "precompile"):
@jansel It looks like all the existing kernels are CachingAutotuner, but in my case I end up with a JITFunction from Triton. I assume that in order to support @triton.autotune I'm going to need to make this work with CachingAutotuner, but for the time being is this fine?
We should not use JITFunction from Triton; we want parallel ahead-of-time compiles.
You should be able to do something similar to:
pytorch/torch/_inductor/triton_heuristics.py, lines 1115 to 1125 in c84c86f:

    def template(num_stages, num_warps, meta, filename=None):
        """
        Compile a triton template
        """
        return cached_autotune(
            None,
            [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
            meta=meta,
            heuristic_type=HeuristicType.TEMPLATE,
            filename=filename,
        )
And put a @template above the generated Triton kernel. You will need to generate proper meta.
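For readers less familiar with the distinction being drawn here: `@triton.jit` returns a `JITFunction` that compiles lazily on first launch, while the kernels Inductor generates are wrapped in objects exposing `precompile()` so binaries can be built ahead of time, possibly in parallel worker processes. Below is a minimal sketch of that difference using a hypothetical wrapper class, not Inductor's actual `CachingAutotuner`; it only illustrates why the codecache change above checks `hasattr(kernel, "precompile")`.

```python
# Sketch only: EagerlyCompiledKernel is a hypothetical stand-in for a
# CachingAutotuner-style wrapper; the real one would drive Triton's
# ahead-of-time compile path per autotune config.
import triton
import triton.language as tl

@triton.jit
def scale_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, tl.load(in_ptr0 + offsets, mask=mask) * 2.0, mask=mask)

class EagerlyCompiledKernel:
    """Hypothetical wrapper around a JITFunction that supports eager compilation."""

    def __init__(self, jit_fn):
        self.jit_fn = jit_fn
        self.precompiled = False

    def precompile(self, warm_cache_only_with_cc=None):
        # The real wrapper compiles each config here, possibly in a worker
        # process keyed by compute capability; this sketch only records it.
        self.precompiled = True

    def run(self, *args, grid, **kwargs):
        if not self.precompiled:
            self.precompile()
        return self.jit_fn[grid](*args, **kwargs)

kernel = EagerlyCompiledKernel(scale_kernel)
assert hasattr(kernel, "precompile")  # a bare JITFunction has no such hook
```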
    patterns.apply(gm.graph)
    if is_inference:
        inference_patterns.apply(gm.graph)
    triton_patterns.apply(gm.graph)
Clones are removed as part of remove_noop_ops. Looking at the output code, I think the clones are still there, so I want to check whether to run remove_noop_ops again or move this check above.
I think it's better to do this as part of reinplace_scatters.
- The clone removal pass doesn't handle in-place mutations.
- reinplace_scatters is exactly for this purpose - it converts scatter (i.e. clone + scatter_) into just scatter_. It also converts scatter + copy_ into just scatter_ as well.
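For anyone skimming the thread, here is a tiny eager-mode illustration of the reinplacement idea (this is not the FX-level pass itself, just the pattern it rewrites): functionalization produces a clone followed by a mutation of the clone, and the pass may drop the clone when the original tensor is not read afterwards.

```python
# Illustration only: the real reinplace_scatters pass rewrites FX graph
# nodes; this shows the eager equivalent of the clone + scatter_ pattern.
import torch

def functionalized(x, index, src):
    out = x.clone()              # clone inserted so x itself is not mutated
    out.scatter_(0, index, src)  # mutation applied to the clone
    return out

def reinplaced(x, index, src):
    # Legal rewrite only when x is not read again after this point.
    x.scatter_(0, index, src)
    return x

idx = torch.tensor([0, 2])
src = torch.tensor([1.0, 2.0])
assert torch.equal(functionalized(torch.zeros(4), idx, src),
                   reinplaced(torch.zeros(4), idx, src))
```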
Putting up mostly as an RFC for now; in the next PR I will implement multiple kernels in the same function as well as kernels calling each other. Example output: P857575342
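Since the paste referenced above is not reproduced here, a minimal sketch of the user-facing pattern this PR targets (kernel and function names are made up for illustration; actually executing it needs a CUDA device with Triton installed):

```python
# Illustrative user code: a hand-written Triton kernel launched from inside
# a function that is then compiled with the inductor backend.
import torch
import triton
import triton.language as tl

@triton.jit
def add_one_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1.0, mask=mask)

def call_kernel(x):
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 16),)
    add_one_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)
    return output

compiled = torch.compile(call_kernel, backend="inductor")
# out = compiled(torch.randn(5, device="cuda"))  # needs a GPU to run
```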
    # It might be better to move this decomposition into lowering after some
    # sort of clone removal pass at the IR level is implemented. For the time
    # being, decomposition is done at this level in order to take advantage of
    # existing clone removal.
oh I see - is inductor's clone removal pass done during lowering time, so if we desugar the functional triton_kernel op into mutable op + clones, it will be "too late"?
Chatting with @Chillee, we decided to move the decomposition to the reinplace_scatters pass. I'm moving that out of this PR and will do it as a follow-up. This PR should be ready for review now.
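For context on what is being deferred, a rough sketch of the functional-to-mutable decomposition under discussion; `run_mutating_kernel` is a hypothetical stand-in for the mutation variant of the wrapper HOP, and the real implementation lives alongside `torch._higher_order_ops.triton_kernel_wrap`:

```python
# Sketch only: functional semantics are recovered by cloning every tensor
# argument before handing it to the mutating kernel, which is exactly why a
# later clone-removal / reinplacing pass matters for performance.
import torch

def run_mutating_kernel(kernel, grid, kwargs):
    # Hypothetical stand-in: launch the user kernel, mutating its tensor args.
    kernel[grid](**kwargs)

def functional_wrapper_decomp(kernel, grid, kwargs):
    cloned = {
        name: (arg.clone() if isinstance(arg, torch.Tensor) else arg)
        for name, arg in kwargs.items()
    }
    run_mutating_kernel(kernel, grid, cloned)
    # Callers read results out of the returned dict (compare the getitem
    # calls in the FX graph shown further down), leaving inputs untouched.
    return cloned
```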
    def codegen(self, wrapper):
        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        kernel = kernel_side_table.get_kernel(self.kernel_idx)
hmm... maybe a little late, but @Chillee I don't think that inductor's fx graph caching TORCHINDUCTOR_FX_GRAPH_CACHE=1 will play well with this kernel_side_table.
What do you think - should we just make sure not to add to the cache if our graph uses any user triton kernels, to unblock, or fully revisit the side table?
It seems to me like the FX graph that we get (both from dynamo and AOTAutograd) will burn in a random index that's supposed to map to a user triton kernel that dynamo saw:
    def forward(self, x_1, output_1):
        triton_kernel_wrapper_functional_proxy = torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional(kernel_idx = 0, grid = (5,), kwargs = {'in_ptr0': x_1, 'out_ptr': output_1, 'n_elements': 5, 'BLOCK_SIZE': 16}); x_1 = #
        getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']
        getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']
        getitem_2 = triton_kernel_wrapper_functional_proxy['n_elements']
        getitem_3 = triton_kernel_wrapper_functional_proxy['BLOCK_SIZE']; triton_kernel_wrapper_functional_proxy = None
        return getitem_1
But if we were to map this graph to a cached, compiled inductor graph, we have no guarantee that that index will map to the same triton kernel.
Maybe we can have the cache key also include the kernel_side_table's mapping, from index to triton kernel (or a hash of its source code)?
Yeah this is why I was suggesting that it might be better to just include the Triton source code in the operator definition.
We would not only need to store the source code of one function but also every single dependency. Another downside of that is that the FX graph would be impossible to read.
If this is the preferred solution, I'm happy to change it but IMO caching kernel_side_table's mapping seems like a cleaner solution.
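To make that concrete, a toy sketch of what folding the side table's mapping into the cache key could look like; all names here are hypothetical, and plain Python callables stand in for real JITFunctions so that `inspect.getsource` works without Triton:

```python
# Sketch only: the real table lives in torch/_higher_order_ops/triton_kernel_wrap.py.
import hashlib
import inspect

class ToyKernelSideTable:
    def __init__(self):
        self.kernels = []

    def add_kernel(self, kernel):
        self.kernels.append(kernel)
        return len(self.kernels) - 1  # this index is what gets burned into the FX graph

    def get_kernel(self, idx):
        return self.kernels[idx]

def side_table_cache_key(table, kernel_indices):
    # Hash the source of every user kernel the graph references so a cached
    # compiled graph is only reused when idx -> kernel source is unchanged.
    parts = []
    for idx in sorted(kernel_indices):
        src = inspect.getsource(table.get_kernel(idx))
        parts.append(f"{idx}:{hashlib.sha256(src.encode()).hexdigest()}")
    return "|".join(parts)
```

A digest like this could be appended to whatever key the FX graph cache already computes, so a stale burned-in index can never hit an entry built from different kernel source.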
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
… reinplace_scatters (#111627) Pull Request resolved: #111627 Approved by: https://github.com/jansel ghstack dependencies: #111434
…11434) Pull Request resolved: pytorch#111434 Approved by: https://github.com/jansel
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler