[AOTI] Support non auto-tuned triton kernels in aoti #113090
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113090. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 036cbc8 with merge base b7acd37. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
for arg_name in self.ordered_kwargs_for_cpp_kernel:
    v = self.get_kwargs_value(arg_name)
    kwargs.append(V.graph.wrapper_code.val_to_arg_str(v))
    if isinstance(v, sympy.Expr):
```
I don't like this, but I'm not sure there's a better solution.
@desertfire @chenyang78 mind reviewing this one?
test/inductor/test_aot_inductor.py (outdated)
```python
@requires_cuda()
@skipIfRocm
```
Given that we have the `if self.device != "cuda"` check below, are these two redundant?
Triton is not installed on all machines, so in the past, if I didn't include both of these, I ran into some symbol problems. I can try deleting them and seeing if anything fails.
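
A minimal self-contained sketch of the layering being discussed; the decorator here is a hand-rolled stand-in for the suite's `requires_cuda`, and the test class and method names are illustrative:

```python
import unittest

import torch

# Stand-in for requires_cuda(): skipping at collection time also avoids
# importing triton-dependent symbols on machines without triton/CUDA.
requires_cuda_sketch = unittest.skipUnless(
    torch.cuda.is_available(), "requires CUDA"
)

class AOTInductorTestSketch(unittest.TestCase):
    device = "cuda"

    @requires_cuda_sketch
    def test_user_defined_triton_kernel(self):
        # The in-test check guards parameterizations that run on other devices.
        if self.device != "cuda":
            raise unittest.SkipTest("user-defined triton kernels need cuda")
        # ... actual test body elided ...
```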
torch/_inductor/ir.py (outdated)
```python
if self.kwargs and not self.ordered_kwargs_for_cpp_kernel:
    assert (
        self.ordered_kwargs_for_cpp_kernel
    ), "ordered_kwargs_for_cpp_kernel is missing"
```
It looks like the assert condition is always False? Probably we could just throw an exception here instead.
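
A minimal sketch of the suggested fix, raising explicitly instead of asserting a condition that can never hold on this branch; the class and method names are illustrative stand-ins for the IR node in the hunk above:

```python
class KernelIRNodeSketch:
    """Illustrative stand-in for the IR node in torch/_inductor/ir.py."""

    def __init__(self, kwargs, ordered_kwargs_for_cpp_kernel):
        self.kwargs = kwargs
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel

    def check_kwargs_sketch(self):
        # Inside this branch ordered_kwargs_for_cpp_kernel is falsy, so
        # `assert self.ordered_kwargs_for_cpp_kernel` can never pass:
        # raise a real exception with a useful message instead.
        if self.kwargs and not self.ordered_kwargs_for_cpp_kernel:
            raise RuntimeError(
                "ordered_kwargs_for_cpp_kernel is missing but kwargs were given"
            )
```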
```python
for arg_name in self.ordered_kwargs_for_cpp_kernel:
    v = self.get_kwargs_value(arg_name)
    kwargs.append(V.graph.wrapper_code.val_to_arg_str(v))
    if isinstance(v, sympy.Expr):
```
Could we add a clause to handle sympy.Expr in CppWrapperCodeGen's val_to_arg_str, so that kwargs holds only strings?
But I think the point here is to use sympy.Expr as a kwarg so it can be codegen'd at a later stage? cc @oulgen.
Yes, if we converted it to a string, then codegen would not be able to tell it was a sympy expr and would emit it like a tensor. @chenyang78 The problem is that codegen tries to figure things out based on their string representation, but for sympy exprs it keeps them as-is, which is not great.
I see. Thanks for the clarification. This is quite unfortunate...
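
A self-contained sketch of the problem being discussed: once an argument is stringified, the wrapper codegen can no longer tell a symbolic expression apart from a tensor variable name. Here `val_to_arg_str_sketch` is a hypothetical stand-in for V.graph.wrapper_code.val_to_arg_str:

```python
import sympy

def val_to_arg_str_sketch(v):
    # Hypothetical stand-in: converts any kwarg value to its arg string.
    return str(v)

v = sympy.Symbol("s0") * 2        # symbolic kwarg, e.g. a dynamic size
arg = val_to_arg_str_sketch(v)    # "2*s0" -- indistinguishable from a name

print(isinstance(v, sympy.Expr))    # True: codegen can emit it symbolically
print(isinstance(arg, sympy.Expr))  # False: now looks like any other arg
```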
```python
self.grid = grid

kernel, _ = self.get_kernel_and_configs()
self.ordered_kwargs_for_cpp_kernel = kernel.arg_names
```
It seems pretty hacky to rely on ordered_kwargs_for_cpp_kernel to generate kernel args indirectly. I wonder if we could implement a codegen_args method for UserDefinedTritonKernel that returns the arg strings.
I'm not following. It is not possible to return all strings, because then a sympy expr can no longer be distinguished. Look at pytorch/torch/_inductor/codegen/wrapper.py, lines 2468 to 2501 in 562c4ae:
```python
def generate_args_decl(self, call_args):
    dynamic_symbols = V.graph.sizevars.free_symbols()
    # TODO: only works for constant now, need type info
    new_args = []
    for arg in call_args:
        var_name = f"var_{next(self.arg_var_id)}"
        if isinstance(
            arg,
            (
                sympy.Integer,
                sympy.Symbol,
                SymbolicCallArg,
            ),
        ):
            self.writeline(f"auto {var_name} = {arg};")
        elif is_int(arg):
            self.writeline(f"int {var_name} = {arg};")
        elif is_float(arg):
            self.writeline(f"float {var_name} = {arg};")
        elif any(str(arg) == s.name for s in dynamic_symbols):
            self.writeline(f"auto {var_name} = {arg};")
        else:
            if config.aot_inductor.abi_compatible:
                self.writeline(f"CUdeviceptr {var_name};")
                self.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
                )
            else:
                self.writeline(
                    f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
                )
        new_args.append(f"&{var_name}")
    return ", ".join(new_args)
```
```python
arg,
(
    sympy.Integer,
    sympy.Expr,
```
I'm wondering if we can simply use {arg} in the f-string below when the arg is a sympy.Expr. From what I remember, in C++ codegen we need to use self.expr_printer for expressions?
Good point!
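
A small sketch of why an expression printer matters on the C++ path; sympy's own cxxcode stands in here for Inductor's self.expr_printer, which may differ in detail:

```python
import sympy
from sympy import cxxcode

s0 = sympy.Symbol("s0")
expr = s0**2 + 1

print(str(expr))      # "s0**2 + 1"           -- Python syntax, not valid C++
print(cxxcode(expr))  # "std::pow(s0, 2) + 1" -- rendered as compilable C++
```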
LGTM, with a few nits. Thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler