[ROCm] Ck backend UX refactor by alugorey · Pull Request #152951 · pytorch/pytorch · GitHub

Conversation

@alugorey
Contributor

@alugorey alugorey commented May 6, 2025

Refactors how CK GEMMs and SDPA are enabled and disabled. (A hedged CMake sketch of the new build options follows the list below.)

  • Adds a USE_ROCM_CK_GEMM compile flag for enabling CK GEMMs.
  • USE_ROCM_CK_GEMM is set to True by default on Linux.
  • Renames USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA.
  • USE_ROCM_CK_SDPA is set to False by default.
  • (USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release.)
  • Prevents these CK libraries from being used unless PyTorch has been built with the functionality AND is running on a system architecture that supports it.
  • The getters for these library backends also perform validity checking in case the user changed the backend via an environment variable. If the request is invalid (i.e., one of the conditions above is false), the backend falls back to the current non-CK default.
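As a concrete illustration of the flags above, here is a minimal CMake sketch. The option names and defaults come from this PR's description; everything else (the Linux-conditional default and the compatibility shim for the deprecated flag) is assumed wiring, not the merged CMakeLists.txt.

```cmake
# Hedged sketch of the build options described above -- assumed wiring,
# not the actual pytorch CMakeLists.txt.

# USE_ROCM_CK_GEMM defaults to ON on Linux only.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
  set(_use_rocm_ck_gemm_default ON)
else()
  set(_use_rocm_ck_gemm_default OFF)
endif()
option(USE_ROCM_CK_GEMM "Enable Composable Kernel GEMM backends" ${_use_rocm_ck_gemm_default})

# USE_ROCM_CK_SDPA defaults to OFF everywhere.
option(USE_ROCM_CK_SDPA "Enable Composable Kernel SDPA backends" OFF)

# Assumed compatibility shim: honor the deprecated flag, but warn.
if(USE_CK_FLASH_ATTENTION)
  message(WARNING "USE_CK_FLASH_ATTENTION is deprecated; use USE_ROCM_CK_SDPA instead")
  set(USE_ROCM_CK_SDPA ON CACHE BOOL "Enable Composable Kernel SDPA backends" FORCE)
endif()
```

The runtime half (the backend getters validating an environment-variable override and falling back to the non-CK default) lives in C++ and is not sketched here.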

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @mcarilli @ptrblck @leslie-fang-intel @EikanWang @voznesenskym @penguinwu @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @xmfan

@alugorey alugorey requested review from eqy and syed-ahmed as code owners May 6, 2025 14:53
@pytorch-bot

pytorch-bot bot commented May 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152951

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 11 Pending

As of commit f95d117 with merge base f3a4d74:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 7, 2025
@jeffdaily
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased ck_gemm_guard onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout ck_gemm_guard && git pull --rebase)

@jeffdaily jeffdaily added release notes: rocm mandatorylabel ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Jun 16, 2025
@alugorey
Contributor Author

After some internal discussion, we are refactoring some of this code to provide a better UX. Working on the implementation now.

@alugorey alugorey changed the title [ROCm] Ck gemm architecture guard [ROCm] Ck backend UX refactor Jul 14, 2025
@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Jul 14, 2025
@jeffdaily jeffdaily added ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Jul 14, 2025
@alugorey
Contributor Author

Looking through CI failures now. I missed some files in the commit, causing the errors. Fixing now.

@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Jul 15, 2025
@jeffdaily jeffdaily added ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Jul 21, 2025
Collaborator

@jeffdaily jeffdaily left a comment

Approved, pending my question/suggestion.

@jithunnair-amd jithunnair-amd added ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 and removed release notes: distributed (checkpoint) module: compiled autograd compiled_autograd release notes: inductor (aoti) labels Aug 7, 2025
@alugorey
Contributor Author

alugorey commented Aug 8, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 8, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@jithunnair-amd
Collaborator

@pytorchbot merge -f "All ROCm CI/lint/pull workflow runs passed. Force merging to prevent PR getting stale again, plus multiple dependent PRs in queue"

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@jithunnair-amd
Collaborator

@pytorchbot merge -f "All ROCm CI/lint/pull workflow runs passed. Force merging to prevent PR getting stale again, plus multiple dependent PRs in queue"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025

Pull Request resolved: pytorch#152951
Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Comment on lines +471 to +474
if(WIN32) # Windows doesn't support Composable Kernels and Triton
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
endif()
Contributor

I'm observing link errors downstream when building torch on Windows with ROCm; they are resolved by removing these four lines and keeping those sources. If accelerated versions of the transformers/ functions are not available, the native versions should still be included.

I'm building with these options, though we are working on getting aotriton enabled for flash attention (ROCm/TheRock#1207):

--   USE_ROCM              : ON
--     ROCM_VERSION          : 
--     USE_FLASH_ATTENTION   : OFF
--     USE_MEM_EFF_ATTENTION : 0
--     USE_ROCM_CK_SDPA      : OFF
--     USE_ROCM_CK_GEMM      : OFF

Logs snippet:

[7019/7028] Linking CXX shared library bin\torch_hip.dll
FAILED: [code=4294967295] bin/torch_hip.dll lib/torch_hip.lib 
C:\Windows\system32\cmd.exe /C "cd . && D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\cmake\data\bin\cmake.exe -E vs_link_dll --msvc-ver=1942 --intdir=caffe2\CMakeFiles\torch_hip.dir --rc=C:\PROGRA~2\WI3CF2~1\10\bin\100261~1.0\x64\rc.exe --mt=C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\Tools\Llvm\x64\bin\llvm-mt.exe --manifests  -- D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp  /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO && cd ."
LINK: command "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO /MANIFEST:EMBED,ID=2" failed (exit code 1) with the following output:
lld-link: error: undefined symbol: __declspec(dllimport) class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::native::transform_bias_rescale_qkv_cuda(class at::Tensor const &, class at::Tensor const &, __int64)
>>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_CUDA___transform_bias_rescale_qkv(class 0xE9BF7323::Tensor const &, class 0xE9BF7323::Tensor const &, __int64))
>>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterNestedTensorCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_NestedTensorCUDA___transform_bias_rescale_qkv(class 0xEFEB5304::Tensor const &, class 0xEFEB5304::Tensor const &, __int64))

More logs: https://gist.github.com/ScottTodd/195d7ff3e5d6a6480ef6289537629c7c

Contributor

I sent a "revert" of those 4 lines here: #160373
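Below is a hedged CMake sketch of what that revert amounts to, using the variable names from the snippet quoted above; it is an illustration of the reasoning, not the actual #160373 patch.

```cmake
# Hedged sketch of the suggested fix -- not the actual #160373 patch.
if(WIN32)
  # Previously (the four lines quoted above):
  #   exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
  #           ${native_transformers_hip_hip} ${native_transformers_hip_cpp})
  # Dropping the exclusion keeps at::native fallbacks such as
  # transform_bias_rescale_qkv_cuda in the build; the generated
  # RegisterCUDA_0.cpp / RegisterNestedTensorCUDA_0.cpp objects reference
  # them unconditionally, which is what produced the lld-link errors above.
endif()
```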

ScottTodd added a commit to ROCm/TheRock that referenced this pull request Aug 11, 2025
…`. (#1227)

## Motivation

Windows nightly PyTorch builds are currently broken:
https://github.com/ROCm/TheRock/actions/workflows/release_windows_pytorch_wheels.yml?query=branch%3Amain.
This fixes them.

## Technical Details

Builds without `--enable-pytorch-flash-attention-windows` were failing
with `error: use of undeclared identifier 'mha_fwd_aot'`, but we also
needed to remove the `exclude(ATen_HIP_SRCS` code that was added
upstream to resolve `lld-link: error: undefined symbol ...
transform_bias_rescale_qkv_cuda`.

This is a minimal fix-forward for the first issue and restructures the
patches for the second issue.

Context:
* #1207 (comment)
* pytorch/pytorch#152951 (comment)

## Test Plan

Built locally with `python build_prod_wheels.py build --pytorch-dir
D:/b/pytorch_main --pytorch-audio-dir D:/b/audio_main
--pytorch-vision-dir D:/b/vision_main --output-dir
%HOME%/.therock/pytorch --no-enable-pytorch-flash-attention-windows`

## Test Result

Local build succeeded.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
TaiXeflar pushed a commit to TaiXeflar/TheRock that referenced this pull request Aug 17, 2025
…`. (ROCm#1227)
pytorchmergebot pushed a commit that referenced this pull request Aug 22, 2025
…160373)

Following up on #152951 (comment), this removes a few lines added in that pull request, fixing link errors like
```
[7019/7028] Linking CXX shared library bin\torch_hip.dll
FAILED: [code=4294967295] bin/torch_hip.dll lib/torch_hip.lib
C:\Windows\system32\cmd.exe /C "cd . && D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\cmake\data\bin\cmake.exe -E vs_link_dll --msvc-ver=1942 --intdir=caffe2\CMakeFiles\torch_hip.dir --rc=C:\PROGRA~2\WI3CF2~1\10\bin\100261~1.0\x64\rc.exe --mt=C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\Tools\Llvm\x64\bin\llvm-mt.exe --manifests  -- D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp  /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO && cd ."
LINK: command "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO /MANIFEST:EMBED,ID=2" failed (exit code 1) with the following output:
lld-link: error: undefined symbol: __declspec(dllimport) class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::native::transform_bias_rescale_qkv_cuda(class at::Tensor const &, class at::Tensor const &, __int64)
>>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_CUDA___transform_bias_rescale_qkv(class 0xE9BF7323::Tensor const &, class 0xE9BF7323::Tensor const &, __int64))
>>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\RegisterNestedTensorCUDA_0.cpp.obj:(class std::tuple<class at::Tensor, class at::Tensor, class at::Tensor> __cdecl at::`anonymous namespace'::`anonymous namespace'::wrapper_NestedTensorCUDA___transform_bias_rescale_qkv(class 0xEFEB5304::Tensor const &, class 0xEFEB5304::Tensor const &, __int64))
```

The `native_transformers_hip_hip` and `native_transformers_hip_cpp` sources are okay to define (and are required) even if accelerated versions of these operations are not available.

I've tested downstream builds of torch with ROCm on native Windows via https://github.com/ROCm/TheRock both with and without aotriton and these changes were needed for the build to succeed in both cases. I have _not_ tested Linux, WSL, or with the HIP SDK.

Pull Request resolved: #160373
Approved by: https://github.com/alugorey, https://github.com/jeffdaily
wincent8 pushed a commit to wincent8/pytorch that referenced this pull request Aug 22, 2025
…ytorch#160373)
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…ytorch#160373)
