[Inductor] Support user defined triton kernels in inductor by oulgen · Pull Request #111434 · pytorch/pytorch · GitHub

Conversation

@pytorch-bot

pytorch-bot bot commented Oct 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111434

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4752953 with merge base bf01a7b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@patch.object(torch._inductor.config, "implicit_fallbacks", False)
def test_triton_kernel_native(self, grad, backend):
    if backend == "inductor" and grad is True:
Contributor Author

@zou3519 @bdhirsh

This is surprising. It fails with

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

in P857566445.

Looking at it with a debugger, we are calling copy_ on a FakeTensor that requires grad. It looks like aot_eager does not execute this code path.
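
For context, the check itself is easy to trigger outside of the compile stack; a minimal standalone repro of the same autograd error (not this PR's code path, just the plain eager behavior):

import torch

# Any in-place op (copy_, add_, ...) on a leaf tensor that requires grad trips
# this check while grad mode is enabled.
x = torch.zeros(4, requires_grad=True)  # leaf variable
y = torch.ones(4)
x.copy_(y)  # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.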

device_interface.Worker.set_device(device.index)
kernel = TritonCodeCache.load(kernel_name, source_code)
kernel.precompile(warm_cache_only_with_cc=cc)
if hasattr(kernel, "precompile"):
Contributor Author

@jansel It looks like all the existing kernels are CachingAutotuner, but in my case I end up with a JITFunction from Triton. I assume that in order to support @triton.autotune I'm going to need to make this work with CachingAutotuner, but for the time being is this fine?

Contributor

We should not use JITFunction from Triton; we want parallel ahead-of-time compiles.

You should be able to do something similar to:

def template(num_stages, num_warps, meta, filename=None):
    """
    Compile a triton template
    """
    return cached_autotune(
        None,
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
        meta=meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )

And put a @template above the generated Triton kernel. You will need to generate proper meta.

patterns.apply(gm.graph)
if is_inference:
    inference_patterns.apply(gm.graph)
triton_patterns.apply(gm.graph)
Contributor Author

oulgen commented Oct 17, 2023

Clones are removed as part of remove_noop_ops. Looking at the output code, I think the clones are still there, so I want to check whether to run remove_noop_ops again or move this check above.

Collaborator

I think it's better to do this as part of reinplace_scatters:

  1. The clone removal pass doesn't handle in-place mutations.
  2. reinplace_scatters exists exactly for this purpose - it converts scatter (i.e. clone + scatter_) into just scatter_, and it also converts scatter + copy_ into just scatter_ (see the sketch below).
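
For concreteness, a tiny standalone sketch of the equivalence the reinplacing relies on (the tensors below are illustrative, not inductor code):

import torch

dst = torch.zeros(5)
src = torch.ones(3)
index = torch.tensor([0, 2, 4])

# Functional scatter decomposes into clone + in-place scatter_ + copy_ back.
tmp = dst.clone()
tmp.scatter_(0, index, src)
dst.copy_(tmp)

# Reinplaced form: a single scatter_ directly on dst gives the same result.
dst2 = torch.zeros(5)
dst2.scatter_(0, index, src)
assert torch.equal(dst, dst2)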

@oulgen
Contributor Author

oulgen commented Oct 17, 2023

Putting this up mostly as an RFC for now; in the next PR I will implement multiple kernels in the same function as well as kernels calling each other.

Example output: P857575342
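
For readers following along, the user-facing pattern this PR is aimed at looks roughly like this (the kernel and wrapper below are an illustration in the spirit of the tests, not code lifted from the PR):

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


@torch.compile(backend="inductor")
def add(x, y):
    # Hand-written Triton kernel launched from a torch.compile'd function;
    # inductor needs to trace through and code-generate this call.
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 16),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=16)
    return out


# Usage (requires a CUDA device):
# x = torch.randn(1024, device="cuda")
# add(x, x)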

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Oct 17, 2023
oulgen added a commit that referenced this pull request Oct 19, 2023
# It might be better to move this decomposition into lowering after some
# sort of clone removal pass at IR level is implemented. For the time being,
# decomposition is done at this level in order to take advantage of the
# existing clone removal.
Contributor

Oh I see - is inductor's clone removal pass done during lowering, so that if we desugar the functional triton_kernel op into a mutable op + clones, it will be "too late"?

oulgen added a commit that referenced this pull request Oct 19, 2023
@oulgen
Contributor Author

oulgen commented Oct 19, 2023

After chatting with @Chillee, we decided to move the decomposition to the reinplace_scatters pass. I'm moving that out of this PR and will do it as a follow-up.

This PR should be ready for review now.

def codegen(self, wrapper):
    from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

    kernel = kernel_side_table.get_kernel(self.kernel_idx)
Contributor

hmm... maybe a little late, but @Chillee I don't think that inductor's fx graph caching TORCHINDUCTOR_FX_GRAPH_CACHE=1 will play well with this kernel_side_table.

What do you think - should we just make sure not to add to the cache if our graph uses any user triton kernels, to unblock, or fully revisit the side table?

It seems to me like the FX graph that we get (both from dynamo and AOTAutograd) will burn in a random index that's supposed to map to a user triton kernel that dynamo saw:

def forward(self, x_1, output_1):
    triton_kernel_wrapper_functional_proxy = torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional(kernel_idx = 0, grid = (5,), kwargs = {'in_ptr0': x_1, 'out_ptr': output_1, 'n_elements': 5, 'BLOCK_SIZE': 16});  x_1 = #
    getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']
    getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']
    getitem_2 = triton_kernel_wrapper_functional_proxy['n_elements']
    getitem_3 = triton_kernel_wrapper_functional_proxy['BLOCK_SIZE'];  triton_kernel_wrapper_functional_proxy = None
    return getitem_1

But if we were to map this graph to a cached, compiled inductor graph, we have no guarantee that that index will map to the same triton kernel.

Maybe we can have the cache key also include the kernel_side_table's mapping, from index to triton kernel (or a hash of its source code)?
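
As a rough sketch of that idea (the helper name and the .fn attribute access below are assumptions, not the actual FX graph cache API):

import hashlib
import inspect


def user_kernel_fingerprint(kernel_side_table, kernel_indices):
    # Hypothetical helper: hash the source of every user kernel the graph refers
    # to, so the cache key changes whenever an index resolves to different code.
    hasher = hashlib.sha256()
    for idx in sorted(kernel_indices):
        kernel = kernel_side_table.get_kernel(idx)
        # Assumes the stored kernel is a triton JITFunction that keeps the
        # original Python function on .fn.
        source = inspect.getsource(kernel.fn)
        hasher.update(f"{idx}:{source}".encode())
    return hasher.hexdigest()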

Collaborator

Yeah this is why I was suggesting that it might be better to just include the Triton source code in the operator definition.

Contributor Author

We would not only need to store the source code of one function but also every single dependency. Another downside of that is that the FX graph would be impossible to read.

If this is the preferred solution, I'm happy to change it, but IMO caching kernel_side_table's mapping seems like a cleaner solution.

oulgen requested a review from jansel October 22, 2023 03:39
@oulgen
Contributor Author

oulgen commented Oct 22, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Oct 22, 2023
… reinplace_scatters (#111627)

Pull Request resolved: #111627
Approved by: https://github.com/jansel
ghstack dependencies: #111434
facebook-github-bot deleted the gh/oulgen/7/head branch October 26, 2023 14:24
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023