[aoti][mps] Dynamic reductions #159355
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159355
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (4 Unrelated Failures)
As of commit ecaeb9a with merge base bb62e1f:
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM, but see some comments
```python
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
    """Codegen a load from an InputBuffer"""
    var = self.args.input(name)
    index = self.prepare_indexing(index)
```
Why rename this one to index_str? It's supposed to return a CSEVariable, isn't it?
Oh yes, you're right.
```diff
         kernel_name = f"{mps_lib_name}_func"
     else:
-        kernel_name = f"{mps_lib_name}.generated_kernel"
+        kernel_name = f"{mps_lib_name}"
```
Hmm, how is this related to dynamic shapes?
Ah yeah! When generating the dynamic input variables, there's some logic in the wrapper that does f"{kernel_name}_xnumel". Previously kernel_name would be something like "mps_lib_0.generated_kernel", which would produce the variable name "mps_lib_0.generated_kernel_xnumel"; that's an invalid name because of the period. So I changed kernel_name to just "mps_lib_0", so the variable name becomes "mps_lib_0_xnumel".
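A minimal sketch of the naming issue described here; the helper below is illustrative, not the actual Inductor wrapper code:

```python
# Illustrative only: the wrapper derives a size-variable name from the kernel name.
def numel_var_name(kernel_name: str, dim: str = "xnumel") -> str:
    return f"{kernel_name}_{dim}"

old = numel_var_name("mps_lib_0.generated_kernel")  # "mps_lib_0.generated_kernel_xnumel"
new = numel_var_name("mps_lib_0")                   # "mps_lib_0_xnumel"

assert not old.isidentifier()  # the period makes this an invalid variable name
assert new.isidentifier()
```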
```python
if not triton:
    if device.type == "cpu":
        self.writeline(self.wrap_kernel_call(kernel_name, call_args))
    elif device.type == "mps":
        # TODO: Fix me, MPS does not expose streams now
        self.writeline(
            self.wrap_kernel_call(f"{kernel_name}.generated_kernel", call_args)
        )
    else:
```
This seems unrelated to dynamic shapes, isn't it?
In the change above, since I changed the kernel name from "mps_lib_0.generated_kernel" to just "mps_lib_0", we now need to append ".generated_kernel" when calling the kernel so that the call site is correct.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Dynamic kernel:
```cpp
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
device float* out_ptr0,
constant float* in_ptr0,
constant long& r0_numel,
uint2 thread_pos [[thread_position_in_grid]],
uint2 group_pos [[thread_position_in_threadgroup]]
) {
auto xindex = thread_pos.x;
auto r0_index = thread_pos.y;
int x0 = xindex;
threadgroup float tmp_acc_0[32];
float tmp_acc_1 = 0;
for(auto r0_1_cnt = 0; r0_1_cnt < static_cast<int>(metal::floor(static_cast<float>(0.99902343750000000 + 0.00097656250000000000*r0_numel))); ++r0_1_cnt) {
int r0_1 = 1024 * r0_1_cnt + r0_index;
if (r0_1 >= r0_numel) break;
auto tmp0 = in_ptr0[x0 + 5*r0_1];
tmp_acc_1 += tmp0;
}
auto tmp1 = c10::metal::threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, metal::min(static_cast<decltype(1024+r0_numel)>(1024), static_cast<decltype(1024+r0_numel)>(r0_numel)));
if (r0_index == 0) out_ptr0[x0] = static_cast<float>(tmp1);
}
void AOTInductorModel::run_impl(...) {
...
auto arg0_1_size = arg0_1.sizes();
int64_t s77 = arg0_1_size[0];
inputs.clear();
[[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
static constexpr int64_t int_array_0[] = {5LL, };
static constexpr int64_t int_array_1[] = {1LL, };
AtenTensorHandle buf0_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
RAIIAtenTensorHandle buf0(buf0_handle);
auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
mps_lib_0_func->runCommandBlock([&] {
mps_lib_0_func->startEncoding();
aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
aoti_torch_mps_set_arg_int(mps_lib_0_func_handle, 2, s77);
mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))}, {static_cast<uint64_t>(1), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))});
});
arg0_1.reset();
output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
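One detail worth noting in the dynamic kernel above: the loop bound floor(0.9990234375 + 0.0009765625 * r0_numel) is just ceil(r0_numel / 1024), i.e. the number of 1024-wide chunks each thread walks. A quick sanity check of that identity (my own verification, not part of the generated code):

```python
import math

def generated_bound(r0_numel: int) -> int:
    # Mirrors the expression emitted in the Metal kernel above:
    # floor(1023/1024 + r0_numel/1024)
    return math.floor(0.99902343750000000 + 0.00097656250000000000 * r0_numel)

# Both constants are exact powers-of-two fractions, so the identity holds exactly.
for n in range(1, 10_000):
    assert generated_bound(n) == math.ceil(n / 1024)
```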
Static kernel:
```cpp
kernel void generated_kernel(
device float* out_ptr0,
constant float* in_ptr0,
uint xindex [[thread_position_in_grid]]
) {
int x0 = xindex;
auto tmp0 = in_ptr0[x0];
auto tmp1 = in_ptr0[5 + x0];
auto tmp3 = in_ptr0[10 + x0];
auto tmp5 = in_ptr0[15 + x0];
auto tmp2 = tmp0 + tmp1;
auto tmp4 = tmp2 + tmp3;
auto tmp6 = tmp4 + tmp5;
out_ptr0[x0] = static_cast<float>(tmp6);
}
void AOTInductorModel::run_impl(...) {
...
static constexpr int64_t int_array_0[] = {5LL, };
static constexpr int64_t int_array_1[] = {1LL, };
AtenTensorHandle buf0_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
RAIIAtenTensorHandle buf0(buf0_handle);
auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
mps_lib_0_func->runCommandBlock([&] {
mps_lib_0_func->startEncoding();
aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL)});
});
arg0_1.reset();
output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
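For context, a repro along these lines triggers the codegen shown above; this is my own sketch, and the exact module, sizes, and the aoti_compile_and_package call are assumptions rather than something taken from this PR:

```python
# Hedged sketch: export a sum reduction with a dynamic dim and compile it with
# AOTInductor for MPS. Assumes a recent PyTorch build with MPS available.
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        # Reduce over the dynamic (first) dimension; the second dim stays static (5).
        return x.sum(dim=0)

m = Model().to("mps")
x = torch.randn(4, 5, device="mps")
dim0 = torch.export.Dim("dim0", min=1, max=4096)
ep = torch.export.export(m, (x,), dynamic_shapes={"x": {0: dim0}})

# Produces the AOTInductor package whose run_impl looks like the code above.
pkg_path = torch._inductor.aoti_compile_and_package(ep)
print(pkg_path)
```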
Pull Request resolved: #159355
Approved by: https://github.com/malfet
…kage" Fixes issues introduced by #159355 The issue got past OSS CI because the H100 tag wasn't added, not sure how to prevent these kinds of issues in the future, perhaps we should run H100 on Inductor PRs? cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Fixes issues introduced by #159355 The issue got past OSS CI because the H100 tag wasn't added, not sure how to prevent these kinds of issues in the future, perhaps we should run H100 on Inductor PRs? cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Fixes issues introduced by #159355 The issue got past OSS CI because the H100 tag wasn't added, not sure how to prevent these kinds of issues in the future, perhaps we should run H100 on Inductor PRs? Pull Request resolved: #159760 Approved by: https://github.com/angelayi
…kage" Fixes issues introduced by #159355 The issue got past OSS CI because the H100 tag wasn't added, not sure how to prevent these kinds of issues in the future, perhaps we should run H100 on Inductor PRs? cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Fixes issues introduced by #159355 The issue got past OSS CI because the H100 tag wasn't added, not sure how to prevent these kinds of issues in the future, perhaps we should run H100 on Inductor PRs? cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Fixes issues introduced by pytorch#159355 The issue got past OSS CI because the H100 tag wasn't added, not sure how to prevent these kinds of issues in the future, perhaps we should run H100 on Inductor PRs? Pull Request resolved: pytorch#159760 Approved by: https://github.com/angelayi
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben