-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[ROCm] Ck backend UX refactor #152951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ROCm] Ck backend UX refactor #152951
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152951
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 11 PendingAs of commit f95d117 with merge base f3a4d74 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot rebase |
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
|
Successfully rebased |
8df618b to
77cd6e8
Compare
|
After some internal discussion, we are refactoring some of this code to provide a better UX. Working on the implementation now. |
|
Looking through CI failures now. I missed some files in the commit causing the errors. fixing now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved, pending my question/suggestion.
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
@pytorchbot merge -f "All ROCm CI/lint/pull workflow runs passed. Force merging to prevent PR getting stale again, plus multiple dependent PRs in queue" |
|
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
|
@pytorchbot merge -f "All ROCm CI/lint/pull workflow runs passed. Force merging to prevent PR getting stale again, plus multiple dependent PRs in queue" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Refactors how the enablement/disablement of CK Gemms and SDPA works. - Adds USE_ROCM_CK_GEMM compile flag for enabling CK gemms. - USE_ROCM_CK_GEMM is set to True by default on Linux - Updates USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA. - USE_ROCM_CK_SDPA is set to False by default - (USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release) - Prevents these CK libraries from being used unless pytorch has been built specifically with the functionality AND is running on a system architecture that supports it. - the getters for these library backends will also do some validity checking in case the user used an environment variable to change the backend. If invalid, (i.e. one of the cases mentioned above is false) the backend will be set as the current non-CK default Pull Request resolved: pytorch#152951 Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus Co-authored-by: Jeff Daily <jeff.daily@amd.com> Co-authored-by: Jithun Nair <jithun.nair@amd.com> Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
| if(WIN32) # Windows doesn't support Composable Kernels and Triton | ||
| exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" | ||
| ${native_transformers_hip_hip} ${native_transformers_hip_cpp}) | ||
| endif() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm observing link errors downstream building torch on Windows with ROCm that are resolved by removing these four lines and keeping those sources. If accelerated versions of the transfomers/ functions are not available, the native versions should still be included.
I'm building with these options, though we are working on getting aotriton enabled for flash attention (ROCm/TheRock#1207):
-- USE_ROCM : ON
-- ROCM_VERSION :
-- USE_FLASH_ATTENTION : OFF
-- USE_MEM_EFF_ATTENTION : 0
-- USE_ROCM_CK_SDPA : OFF
-- USE_ROCM_CK_GEMM : OFF
Logs snippet:
[7019/7028] Linking CXX shared library bin\torch_hip.dll
FAILED: [code=4294967295] bin/torch_hip.dll lib/torch_hip.lib
C:\Windows\system32\cmd.exe /C "cd . && D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\cmake\data\bin\cmake.exe -E vs_link_dll --msvc-ver=1942 --intdir=caffe2\CMakeFiles\torch_hip.dir --rc=C:\PROGRA~2\WI3CF2~1\10\bin\100261~1.0\x64\rc.exe --mt=C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\Tools\Llvm\x64\bin\llvm-mt.exe --manifests -- D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO && cd ."
LINK: command "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO /MANIFEST:EMBED,ID=2" failed (exit code 1) with the following output:
lld-link: error: undefined symbol: __declspec(dllimport) class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::native::transform_bias_rescale_qkv_cuda(class at::Tensor const &, class at::Tensor const &, __int64)
>>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_CUDA___transform_bias_rescale_qkv(class 0xE9BF7323::Tensor const &, class 0xE9BF7323::Tensor const &, __int64))
>>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterNestedTensorCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_NestedTensorCUDA___transform_bias_rescale_qkv(class 0xEFEB5304::Tensor const &, class 0xEFEB5304::Tensor const &, __int64))
More logs: https://gist.github.com/ScottTodd/195d7ff3e5d6a6480ef6289537629c7c
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I sent a "revert" of those 4 lines here: #160373
…`. (#1227) ## Motivation Windows nightly PyTorch builds are currently broken: https://github.com/ROCm/TheRock/actions/workflows/release_windows_pytorch_wheels.yml?query=branch%3Amain. This fixes them. ## Technical Details Builds without `--enable-pytorch-flash-attention-windows` were failing with `error: use of undeclared identifier 'mha_fwd_aot'`, but we also needed to remove the `exclude(ATen_HIP_SRCS` code that was added upstream to resolve `lld-link: error: undefined symbol ... transform_bias_rescale_qkv_cuda`. This is a minimal fix-forward for the first issue and restructures the patches for the second issue. Context: * #1207 (comment) * pytorch/pytorch#152951 (comment) ## Test Plan Built locally with `python build_prod_wheels.py build --pytorch-dir D:/b/pytorch_main --pytorch-audio-dir D:/b/audio_main --pytorch-vision-dir D:/b/vision_main --output-dir %HOME%/.therock/pytorch --no-enable-pytorch-flash-attention-windows` ## Test Result Local build succeeded. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
…`. (ROCm#1227) ## Motivation Windows nightly PyTorch builds are currently broken: https://github.com/ROCm/TheRock/actions/workflows/release_windows_pytorch_wheels.yml?query=branch%3Amain. This fixes them. ## Technical Details Builds without `--enable-pytorch-flash-attention-windows` were failing with `error: use of undeclared identifier 'mha_fwd_aot'`, but we also needed to remove the `exclude(ATen_HIP_SRCS` code that was added upstream to resolve `lld-link: error: undefined symbol ... transform_bias_rescale_qkv_cuda`. This is a minimal fix-forward for the first issue and restructures the patches for the second issue. Context: * ROCm#1207 (comment) * pytorch/pytorch#152951 (comment) ## Test Plan Built locally with `python build_prod_wheels.py build --pytorch-dir D:/b/pytorch_main --pytorch-audio-dir D:/b/audio_main --pytorch-vision-dir D:/b/vision_main --output-dir %HOME%/.therock/pytorch --no-enable-pytorch-flash-attention-windows` ## Test Result Local build succeeded. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
…160373) Following up on #152951 (comment), this removes a few lines added in that pull request, fixing link errors like ``` [7019/7028] Linking CXX shared library bin\torch_hip.dll FAILED: [code=4294967295] bin/torch_hip.dll lib/torch_hip.lib C:\Windows\system32\cmd.exe /C "cd . && D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\cmake\data\bin\cmake.exe -E vs_link_dll --msvc-ver=1942 --intdir=caffe2\CMakeFiles\torch_hip.dir --rc=C:\PROGRA~2\WI3CF2~1\10\bin\100261~1.0\x64\rc.exe --mt=C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\Tools\Llvm\x64\bin\llvm-mt.exe --manifests -- D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO && cd ." LINK: command "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO /MANIFEST:EMBED,ID=2" failed (exit code 1) with the following output: lld-link: error: undefined symbol: __declspec(dllimport) class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::native::transform_bias_rescale_qkv_cuda(class at::Tensor const &, class at::Tensor const &, __int64) >>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_CUDA___transform_bias_rescale_qkv(class 0xE9BF7323::Tensor const &, class 0xE9BF7323::Tensor const &, __int64)) >>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterNestedTensorCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_NestedTensorCUDA___transform_bias_rescale_qkv(class 0xEFEB5304::Tensor const &, class 0xEFEB5304::Tensor const &, __int64)) ``` The `native_transformers_hip_hip` and `native_transformers_hip_cpp` sources are okay to define (and are required) even if accelerated versions of these operations are not available. I've tested downstream builds of torch with ROCm on native Windows via https://github.com/ROCm/TheRock both with and without aotriton and these changes were needed for the build to succeed in both cases. I have _not_ tested Linux, WSL, or with the HIP SDK. Pull Request resolved: #160373 Approved by: https://github.com/alugorey, https://github.com/jeffdaily
…ytorch#160373) Following up on pytorch#152951 (comment), this removes a few lines added in that pull request, fixing link errors like ``` [7019/7028] Linking CXX shared library bin\torch_hip.dll FAILED: [code=4294967295] bin/torch_hip.dll lib/torch_hip.lib C:\Windows\system32\cmd.exe /C "cd . && D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\cmake\data\bin\cmake.exe -E vs_link_dll --msvc-ver=1942 --intdir=caffe2\CMakeFiles\torch_hip.dir --rc=C:\PROGRA~2\WI3CF2~1\10\bin\100261~1.0\x64\rc.exe --mt=C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\Tools\Llvm\x64\bin\llvm-mt.exe --manifests -- D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO && cd ." LINK: command "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO /MANIFEST:EMBED,ID=2" failed (exit code 1) with the following output: lld-link: error: undefined symbol: __declspec(dllimport) class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::native::transform_bias_rescale_qkv_cuda(class at::Tensor const &, class at::Tensor const &, __int64) >>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_CUDA___transform_bias_rescale_qkv(class 0xE9BF7323::Tensor const &, class 0xE9BF7323::Tensor const &, __int64)) >>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterNestedTensorCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_NestedTensorCUDA___transform_bias_rescale_qkv(class 0xEFEB5304::Tensor const &, class 0xEFEB5304::Tensor const &, __int64)) ``` The `native_transformers_hip_hip` and `native_transformers_hip_cpp` sources are okay to define (and are required) even if accelerated versions of these operations are not available. I've tested downstream builds of torch with ROCm on native Windows via https://github.com/ROCm/TheRock both with and without aotriton and these changes were needed for the build to succeed in both cases. I have _not_ tested Linux, WSL, or with the HIP SDK. Pull Request resolved: pytorch#160373 Approved by: https://github.com/alugorey, https://github.com/jeffdaily
Refactors how the enablement/disablement of CK Gemms and SDPA works. - Adds USE_ROCM_CK_GEMM compile flag for enabling CK gemms. - USE_ROCM_CK_GEMM is set to True by default on Linux - Updates USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA. - USE_ROCM_CK_SDPA is set to False by default - (USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release) - Prevents these CK libraries from being used unless pytorch has been built specifically with the functionality AND is running on a system architecture that supports it. - the getters for these library backends will also do some validity checking in case the user used an environment variable to change the backend. If invalid, (i.e. one of the cases mentioned above is false) the backend will be set as the current non-CK default Pull Request resolved: pytorch#152951 Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus Co-authored-by: Jeff Daily <jeff.daily@amd.com> Co-authored-by: Jithun Nair <jithun.nair@amd.com> Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
…ytorch#160373) Following up on pytorch#152951 (comment), this removes a few lines added in that pull request, fixing link errors like ``` [7019/7028] Linking CXX shared library bin\torch_hip.dll FAILED: [code=4294967295] bin/torch_hip.dll lib/torch_hip.lib C:\Windows\system32\cmd.exe /C "cd . && D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\cmake\data\bin\cmake.exe -E vs_link_dll --msvc-ver=1942 --intdir=caffe2\CMakeFiles\torch_hip.dir --rc=C:\PROGRA~2\WI3CF2~1\10\bin\100261~1.0\x64\rc.exe --mt=C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\Tools\Llvm\x64\bin\llvm-mt.exe --manifests -- D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO && cd ." LINK: command "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO /MANIFEST:EMBED,ID=2" failed (exit code 1) with the following output: lld-link: error: undefined symbol: __declspec(dllimport) class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::native::transform_bias_rescale_qkv_cuda(class at::Tensor const &, class at::Tensor const &, __int64) >>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_CUDA___transform_bias_rescale_qkv(class 0xE9BF7323::Tensor const &, class 0xE9BF7323::Tensor const &, __int64)) >>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterNestedTensorCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_NestedTensorCUDA___transform_bias_rescale_qkv(class 0xEFEB5304::Tensor const &, class 0xEFEB5304::Tensor const &, __int64)) ``` The `native_transformers_hip_hip` and `native_transformers_hip_cpp` sources are okay to define (and are required) even if accelerated versions of these operations are not available. I've tested downstream builds of torch with ROCm on native Windows via https://github.com/ROCm/TheRock both with and without aotriton and these changes were needed for the build to succeed in both cases. I have _not_ tested Linux, WSL, or with the HIP SDK. Pull Request resolved: pytorch#160373 Approved by: https://github.com/alugorey, https://github.com/jeffdaily
Refactors how the enablement/disablement of CK Gemms and SDPA works.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @mcarilli @ptrblck @leslie-fang-intel @EikanWang @voznesenskym @penguinwu @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @xmfan