[ROCm][Inductor][CK] Add ck-tile based universal gemm kernels to torch.mm autotune choices #152341
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152341
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit d36ad37 with merge base 72a3c8d.
BROKEN TRUNK: the following job failed but was also failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull Request Overview
This PR introduces initial support for CK tile universal GEMM by adding new template functions and integrating them into the GEMM selection flow.
- Added new utility functions (use_ck_tile_template and use_ck_tile_gemm_template) in torch/_inductor/utils.py.
- Updated torch/_inductor/kernel/mm.py to include CK tile choices in the GEMM tuning flow.
- Introduced the CKTileTemplate class in torch/_inductor/codegen/rocm/ck_tile_template.py to generate CK tile–specific code.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
File | Description
---|---
torch/_inductor/utils.py | Added utility functions to enable CK tile templates.
torch/_inductor/kernel/mm.py | Updated GEMM choice logic to consider CK tile templates.
torch/_inductor/codegen/rocm/ck_tile_template.py | New file implementing the CKTileTemplate for ROCm.
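For orientation, here is a minimal sketch of what such a CKTileTemplate base class might look like. The import paths, the ROCmTemplate base, and the header() override are assumptions modeled on the existing CK templates, not the merged code:

```python
# Sketch only: names and import paths are assumptions, not the actual file.
from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate
from torch._inductor.utils import IndentedBuffer


class CKTileTemplate(ROCmTemplate):
    """Shared codegen scaffolding for CK-tile kernels on ROCm."""

    def header(self) -> IndentedBuffer:
        # Start from the common ROCm header and splice in the ck_tile
        # includes shared by all generated kernels.
        res = super().header()
        res.splice(
            """
            #include "ck_tile/core.hpp"
            #include "ck_tile/ops/gemm.hpp"
            """
        )
        return res
```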
Force-pushed d52baca to 54ee7d8
@pytorchbot label 'topic: not user facing'
Pull Request Overview
This PR adds support for CK-tile based universal GEMM kernels to the ROCm CK backend for Inductor by extending the autotuning choices and introducing a new code generation template.
- Adds a new CKTileGemmTemplate choice in the mm kernel dispatch.
- Updates CK universal GEMM template to handle dynamic input nodes in kBatch computation.
- Introduces a new file (ck_tile_template.py) for CK tile-based kernel code generation.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
File | Description
---|---
torch/_inductor/kernel/mm.py | Adds import and invocation of CKTileGemmTemplate to extend autotune choices.
torch/_inductor/codegen/rocm/ck_universal_gemm_template.py | Incorporates a dynamic-input check in _get_kBatch to enforce a kBatch value of [1] for dynamic inputs.
torch/_inductor/codegen/rocm/ck_tile_template.py | Provides a new template for generating CK-tile based universal GEMM kernel code.
Comments suppressed due to low confidence (1)

torch/_inductor/codegen/rocm/ck_universal_gemm_template.py:890
The early return for dynamic input nodes in _get_kBatch ensures a kBatch of [1]. Confirm that this is the intended behavior for all dynamic cases.

    if is_dynamic(*self.input_nodes):
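A plausible reading of that early return, as a hedged sketch rather than the exact diff:

```python
def _get_kBatch(self, op):
    # With dynamic shapes the K extent is unknown at compile time, so no
    # K-split other than 1 can be validated; fall back to kBatch = [1].
    if is_dynamic(*self.input_nodes):
        return [1]
    ...  # otherwise derive candidate K-splits from the static problem size
```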
CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2])
Copilot AI, May 5, 2025
[nitpick] Consider adding a brief comment explaining the usage and behavior of CKTileGemmTemplate.add_choices to help maintainers understand its relationship to the legacy CKGemmTemplate.
Suggested change:
-CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2])
+# Add GEMM configurations using CKTileGemmTemplate. This template builds on the
+# legacy CKGemmTemplate by introducing tiled GEMM operations for better performance
+# in certain scenarios. It complements the configurations added by CKGemmTemplate.
+CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2])
@@ -0,0 +1,60 @@
import torch
Copilot AI, May 5, 2025
[nitpick] Consider adding a file-level comment describing the purpose of CKTileTemplate to improve maintainability and assist new contributors.
hmm, I'd actually rather filter this out in torch than have it fail at runtime, but maybe the "check" can live in the rocm codebase?
Can we make this another backend altogether? I'd like to do something like `= "TRITON,CK,CKTILE"` if possible. Over time we can see whether that makes sense to keep, but for now the extra knob will be helpful.
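A hypothetical illustration of that extra knob, assuming it would reuse the existing max_autotune_gemm_backends convention (the CKTILE value is the reviewer's proposal, not confirmed API):

```python
import torch._inductor.config as inductor_config

# Opt the CK-tile kernels in (or out) independently of the legacy CK backend.
inductor_config.max_autotune_gemm_backends = "TRITON,CK,CKTILE"
```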
Pull Request Overview
This PR integrates CK-tile based universal gemm kernels into Inductor’s CK backend by adding a new code generation template and updating tuning configurations. The changes include:
- Adding a new CKTileGemmTemplate (in mm.py and a new file ck_tile_template.py) to expand autotune choices.
- Updating configuration keys (e.g. replacing n_max_profiling_configs with ck_max_profiling_configs) and compile command flags.
- Adjusting test settings (e.g. compile_threads and profiling configs) to validate the new kernel implementations.
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
File | Description
---|---
torch/_inductor/kernel/mm.py | Added CKTileGemmTemplate to autotune choices alongside legacy CK kernel selection.
torch/_inductor/codegen/rocm/compile_command.py | Added the "-fvisibility=hidden" flag to the ROCm compiler options.
torch/_inductor/codegen/rocm/ck_universal_gemm_template.py | Updated dynamic problem handling and replaced profiling configuration keys with ck_max_profiling_configs.
torch/_inductor/codegen/rocm/ck_tile_template.py | New template for CK-tile based gemm kernels, providing the necessary header and global setup.
torch/_inductor/codegen/rocm/ck_conv_template.py | Updated profiling configuration keys for CK conv instances.
test/inductor/test_ck_backend.py | Modified test parameters to leverage the new configuration keys and increased compile threads.
Comments suppressed due to low confidence (4)

torch/_inductor/kernel/mm.py:654
[nitpick] Confirm that the ordering between the legacy CKGemmTemplate.add_ck_gemm_choices and the new CKTileGemmTemplate.add_choices is intentional and does not inadvertently affect the autotuning selection.

    CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2])

torch/_inductor/codegen/rocm/ck_universal_gemm_template.py:890
[nitpick] The early return of [1] for dynamic input cases is a clear design decision; please ensure that this behavior aligns with overall runtime dispatch expectations.

    if is_dynamic(*self.input_nodes):

torch/_inductor/codegen/rocm/ck_universal_gemm_template.py:950
[nitpick] Ensure that replacing 'n_max_profiling_configs' with 'ck_max_profiling_configs' is consistently applied throughout the codebase and that related documentation is updated accordingly.

    min(len(filtered_instances), config.rocm.ck_max_profiling_configs),

test/inductor/test_ck_backend.py:102
[nitpick] The increased compile_threads value and updated profiling configuration keys in tests should be verified against typical test hardware to avoid potential resource over-allocation.

    "compile_threads": 16,
Copilot AI, May 7, 2025
[nitpick] Please document the rationale and potential implications of adding the "-fvisibility=hidden" flag in this context to aid future maintainability.
"-fPIC", | |
"-fPIC", | |
# The -fvisibility=hidden flag is used to hide symbols by default in shared libraries. | |
# This reduces symbol export overhead, improves load times, and minimizes potential | |
# symbol conflicts. Symbols that need to be exported must be explicitly marked. |
@pytorchbot rebase -s
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased
Force-pushed 1e126fa to 859c6fe
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

template <ck_tile::index_t PrefetchStages, typename Dispatcher>
void dispatch_memory_pipeline_hot_loop(const ck_tile::TailNumber tail_num, Dispatcher dispatch)
We have a new way of configuring TailNumber in the CK Tile engine. You could use it to decrease the complexity of the code here.
Thanks! I'd rather skip it for now, as relying on C++ metaprogramming doesn't quite make it simpler.
("Row", "Col", "Row"), | ||
] | ||
for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] | ||
for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] |
We are preparing a search space for the ck tile gemm. Do we need to make it dynamic?
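For context, the quoted fragment reads like the tail of a comprehension enumerating the instance search space; a reconstruction along these lines (the constructor signature and the full layout list are assumptions):

```python
instances = [
    CKTileGemmOperation(
        layout_a, layout_b, layout_c,
        datatype_a, datatype_b, datatype_c,
        tile_m, tile_n, tile_k,
    )
    for (layout_a, layout_b, layout_c) in [
        ("Row", "Row", "Row"),
        ("Row", "Col", "Row"),
    ]
    for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3]
    for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)]
]
```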
    return True

def filter_op(self, op: "CKTileGemmOperation"):
I think the filter function also needs the following filters (see the sketch after this list):
- Filter out instances where the total size of the A and B block tiles exceeds 64KB.
- Filter that, if we split the warps along a certain dimension, the block tile in that dimension is a multiple of the warp count there (e.g. 2 warps on M/N) times the warp tile size in that dimension.
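A minimal sketch of those two filters, assuming illustrative field names on the operation (tile_m/n/k, warp_m/n, warp_tile_m/n); the real instance definition may differ:

```python
def filter_op(self, op: "CKTileGemmOperation") -> bool:
    dtype_bytes = {"FP16": 2, "BF16": 2}[op.datatype_a]
    # 1. The A and B block tiles together must fit in the 64KB of LDS.
    a_tile_bytes = op.tile_m * op.tile_k * dtype_bytes
    b_tile_bytes = op.tile_n * op.tile_k * dtype_bytes
    if a_tile_bytes + b_tile_bytes > 64 * 1024:
        return False
    # 2. Each block-tile dimension must be a multiple of
    #    (warps along that dimension) * (per-warp tile size along it).
    if op.tile_m % (op.warp_m * op.warp_tile_m) != 0:
        return False
    if op.tile_n % (op.warp_n * op.warp_tile_n) != 0:
        return False
    return True
```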
done, thanks!
@tenpercent Thank you for the work, Max! Just a few comments on the differences between the current in-progress ck tile codegen and the profiler.
@pytorchbot rebase -s
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
@pytorchbot rebase -s
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
max-autotune torch.mm
to make it more similar to cutlass and to distinguish between ck and ck-tile
Successfully rebased
Force-pushed ca51f78 to d36ad37
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR adds code generation for CK-tile based universal gemm kernels to the CK backend for Inductor, and adds these kernels to the autotune choices.
Unlike the legacy-CK based kernels (which are generated by parsing instances from the CK library), the set of instances here is generated by manually specifying the tuning parameters.
This PR introduces a new template for code generation; compilation and autotuning are handled by the existing infrastructure.
Points of discussion:
**Testing**
Use the existing tests in test/inductor/test_ck_backend.py.
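A hedged end-to-end sketch in the spirit of those tests; the knob values and the CKTILE backend string are assumptions taken from the review discussion, not confirmed API:

```python
import torch
import torch._inductor.config as inductor_config

with inductor_config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CKTILE",  # assumed backend name
        "compile_threads": 16,
    }
):
    a = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    out = torch.compile(torch.mm)(a, b)
    torch.testing.assert_close(out, a @ b, rtol=1e-2, atol=1e-2)
```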
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
cc @zjing14 @coconutruben @ThomasNing @amd-khushbu