[aoti][mps] Dynamic reductions by angelayi · Pull Request #159355 · pytorch/pytorch · GitHub

Conversation

@angelayi
Contributor

@angelayi angelayi commented Jul 29, 2025

Stack from ghstack (oldest at bottom):

Dynamic kernel:

```cpp
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    constant long& r0_numel,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int x0 = xindex;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
    for(auto r0_1_cnt = 0; r0_1_cnt < static_cast<int>(metal::floor(static_cast<float>(0.99902343750000000 + 0.00097656250000000000*r0_numel))); ++r0_1_cnt) {
        int r0_1 = 1024 * r0_1_cnt + r0_index;
        if (r0_1 >= r0_numel) break;
        auto tmp0 = in_ptr0[x0 + 5*r0_1];
        tmp_acc_1 += tmp0;
    }
    auto tmp1 = c10::metal::threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, metal::min(static_cast<decltype(1024+r0_numel)>(1024), static_cast<decltype(1024+r0_numel)>(r0_numel)));
    if (r0_index == 0) out_ptr0[x0] = static_cast<float>(tmp1);
}

void AOTInductorModel::run_impl(...) {
    ...
    auto arg0_1_size = arg0_1.sizes();
    int64_t s77 = arg0_1_size[0];
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        aoti_torch_mps_set_arg_int(mps_lib_0_func_handle, 2, s77);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))}, {static_cast<uint64_t>(1), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
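
For reference, the loop bound in the dynamic kernel is a floating-point ceil-division: 0.0009765625 is 1/1024 and 0.9990234375 is 1023/1024, so the floor expression evaluates to ceil(r0_numel / 1024), the number of 1024-wide reduction strides each thread walks. A standalone sanity check of that identity (not code from this PR):

```python
import math

# floor(1023/1024 + n/1024) == ceil(n/1024) for positive integer n:
# this is the per-thread trip count of the dynamic reduction loop.
for r0_numel in range(1, 1 << 16):
    fp_bound = math.floor(0.9990234375 + 0.0009765625 * r0_numel)
    assert fp_bound == -(-r0_numel // 1024)  # ceil division
```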

Static kernel:

```cpp
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = in_ptr0[5 + x0];
    auto tmp3 = in_ptr0[10 + x0];
    auto tmp5 = in_ptr0[15 + x0];
    auto tmp2 = tmp0 + tmp1;
    auto tmp4 = tmp2 + tmp3;
    auto tmp6 = tmp4 + tmp5;
    out_ptr0[x0] = static_cast<float>(tmp6);
}

void AOTInductorModel::run_impl(...) {
    ...
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL)});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
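
Kernels like the two above come from AOT-compiling a reduction over the dynamic dimension, e.g. x.sum(dim=0) on an (s77, 5) MPS tensor. A minimal sketch of how one might produce both variants (the module and dim names here are illustrative, not taken from this PR's tests):

```python
import torch

class Sum(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=0)

x = torch.randn(4, 5, device="mps")

# Marking dim 0 as dynamic should produce the dynamic reduction kernel;
# exporting without dynamic_shapes produces the unrolled static kernel.
batch = torch.export.Dim("batch")
ep = torch.export.export(Sum(), (x,), dynamic_shapes={"x": {0: batch}})
package = torch._inductor.aoti_compile_and_package(ep)
```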

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Jul 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159355

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit ecaeb9a with merge base bb62e1f:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

angelayi added a commit that referenced this pull request Jul 29, 2025
ghstack-source-id: 7676f37
Pull-Request-resolved: #159355
@pytorch-bot pytorch-bot bot added the ciflow/inductor, ciflow/mps (Run MPS tests, subset of trunk), and module: inductor labels Jul 29, 2025
@angelayi angelayi marked this pull request as draft July 29, 2025 08:39
@angelayi angelayi added the keep-going (Don't stop on first failure, keep running tests until the end) label Jul 29, 2025
[ghstack-poisoned]
angelayi added a commit that referenced this pull request Jul 29, 2025
ghstack-source-id: 5718d47
Pull-Request-resolved: #159355
@angelayi angelayi added the topic: not user facing (topic category) label Jul 30, 2025
[ghstack-poisoned]
angelayi added a commit that referenced this pull request Jul 30, 2025
ghstack-source-id: a8bcb13
Pull-Request-resolved: #159355
[ghstack-poisoned]
@angelayi angelayi requested review from desertfire and malfet July 30, 2025 16:32
@angelayi angelayi changed the title [wip][aoti][mps] Dynamic reductions [aoti][mps] Dynamic reductions Jul 30, 2025
@angelayi angelayi marked this pull request as ready for review July 30, 2025 16:33
Contributor

@malfet malfet left a comment


LGTM, but see some comments

def load(self, name: str, index: sympy.Expr) -> CSEVariable:
    """Codegen a load from an InputBuffer"""
    var = self.args.input(name)
    index = self.prepare_indexing(index)
Contributor

Why rename this one to index_str? It's supposed to return a CSEVariable, isn't it?

Contributor Author

Oh yes, you're right.

    kernel_name = f"{mps_lib_name}_func"
else:
-    kernel_name = f"{mps_lib_name}.generated_kernel"
+    kernel_name = f"{mps_lib_name}"
Contributor

Hmm, how is this related to dynamic shapes?

Contributor Author

Ah yeah! When generating the dynamic input variables, there's some logic in the wrapper that emits f"{kernel_name}_xnumel". Previously the kernel_name would be something like "mps_lib_0.generated_kernel", which produced the variable name "mps_lib_0.generated_kernel_xnumel", an invalid identifier because of the period. So I changed the kernel_name to just "mps_lib_0" so that the variable names become "mps_lib_0_xnumel".
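
Roughly, the naming logic looks like this (a simplified sketch, not the actual wrapper code):

```python
mps_lib_name = "mps_lib_0"

# Before: kernel_name carried the attribute path, so the derived
# dynamic-size variable contained a period and was not a valid identifier.
kernel_name = f"{mps_lib_name}.generated_kernel"
assert f"{kernel_name}_xnumel" == "mps_lib_0.generated_kernel_xnumel"  # invalid C++ name

# After: kernel_name is just the library name, so derived names are valid.
kernel_name = mps_lib_name
assert f"{kernel_name}_xnumel" == "mps_lib_0_xnumel"  # valid C++ name
```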

Comment on lines +2544 to +2552
if not triton:
    if device.type == "cpu":
        self.writeline(self.wrap_kernel_call(kernel_name, call_args))
    elif device.type == "mps":
        # TODO: Fix me, MPS does not expose streams now
        self.writeline(
            self.wrap_kernel_call(f"{kernel_name}.generated_kernel", call_args)
        )
    else:
Contributor

This seems unrelated to dynamic shapes, isn't it?

Contributor Author

Since the change above renamed the kernel from "mps_lib_0.generated_kernel" to "mps_lib_0", we now need to append ".generated_kernel" when calling the kernel so that the call site is correct.
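
In other words, the attribute path is re-attached only at call-emission time; a hypothetical sketch of the emitted line (names assumed from the hunk above, not the actual wrapper output):

```python
kernel_name = "mps_lib_0"
call_args = ["buf0", "arg0_1", "s77"]

# The ".generated_kernel" suffix is appended when emitting the call,
# so the generated C++ still invokes the right kernel function.
line = f"{kernel_name}.generated_kernel({', '.join(call_args)})"
assert line == "mps_lib_0.generated_kernel(buf0, arg0_1, s77)"
```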

[ghstack-poisoned]
angelayi added a commit that referenced this pull request Jul 30, 2025
ghstack-source-id: 2a46e4c
Pull-Request-resolved: #159355
@angelayi
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jul 31, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x fbd37d1658bf436b160208d6b5d9cd2602d5198c returned non-zero exit code 1

Auto-merging test/inductor/test_aot_inductor.py
Auto-merging test/inductor/test_torchinductor.py
Auto-merging torch/_inductor/codegen/mps.py
CONFLICT (content): Merge conflict in torch/_inductor/codegen/mps.py
Auto-merging torch/_inductor/codegen/wrapper.py
error: could not apply fbd37d1658b... [aoti][mps] Dynamic reductions
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"

[ghstack-poisoned]
@angelayi
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


yangw-dev pushed a commit that referenced this pull request Aug 1, 2025

Pull Request resolved: #159355
Approved by: https://github.com/malfet
mlazos added a commit that referenced this pull request Aug 4, 2025
…kage"

Fixes issues introduced by #159355

The issue got past OSS CI because the H100 tag wasn't added; not sure how to prevent these kinds of issues in the future. Perhaps we should run H100 on Inductor PRs?
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Aug 4, 2025
Fixes issues introduced by #159355

Pull Request resolved: #159760
Approved by: https://github.com/angelayi
mlazos added a commit that referenced this pull request Aug 5, 2025
@github-actions github-actions bot deleted the gh/angelayi/106/head branch August 31, 2025 02:16
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025

Pull Request resolved: pytorch#159760
Approved by: https://github.com/angelayi

Labels

ciflow/inductor · ciflow/mps (Run MPS tests, subset of trunk) · ciflow/trunk (Trigger trunk jobs on your pull request) · keep-going (Don't stop on first failure, keep running tests until the end) · Merged · module: inductor · topic: not user facing (topic category)

3 participants