[ROCm] Bump AOTriton to 0.11b by xinyazhang · Pull Request #161754 · pytorch/pytorch · GitHub

Conversation

@xinyazhang
Collaborator

@xinyazhang xinyazhang commented Aug 29, 2025

Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.11b:

  • Invoke AITER assembly kernels on gfx942/gfx950 when inputs meet requirements
    • AITER ASM kernels deliver over 500 TFLOPS training performance. See the
      AOTriton 0.11b Release Page for more details.
  • Now returns a natural-log-based logsumexp tensor, matching CUDA's behavior
    • PR #156903 is reverted in this PR as well, since it is no longer needed.
  • Enables CausalVariant.LOWER_RIGHT

The build system changes drastically along with the new packaging scheme of
AOTriton 0.11:

  • AOTriton 0.11 packs GPU images separately from the AOTriton runtime
  • aotriton.cmake now selectively downloads image packs according to
    PYTORCH_ROCM_ARCH
  • aotriton.cmake now only uses a pre-compiled runtime library that exactly
    matches the ROCm version in the build environment. For PyTorch builds with
    ROCm versions not listed in the file, the build process builds the AOTriton
    runtime from source, without GPU images
    • This avoids any further ABI breaks like ROCm 6.4 -> 7.0
    • Recursive git clone is disabled, since building the AOTriton runtime does
      not require submodules
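
To illustrate the arch-selective download described above, here is a minimal
CMake sketch of the idea. The variable handling mirrors PYTORCH_ROCM_ARCH, but
the pack names and URL layout are hypothetical, not AOTriton's actual scheme:

```cmake
# Hypothetical sketch of per-arch image-pack selection; the real
# aotriton.cmake logic, pack names, and URL layout differ.
set(PYTORCH_ROCM_ARCH "gfx90a;gfx942" CACHE STRING "Requested ROCm architectures")

foreach(__arch IN LISTS PYTORCH_ROCM_ARCH)
  # One image pack per architecture, fetched only when that arch is requested.
  set(__pack_url "https://example.com/aotriton/0.11b/aotriton-images-${__arch}.tar.gz")
  message(STATUS "Would download image pack for ${__arch}: ${__pack_url}")
  # A real implementation would call file(DOWNLOAD ...) and verify a checksum here.
endforeach()
```

Downloading only the requested packs keeps build downloads proportional to the
number of target architectures instead of shipping every GPU image to every build.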

Bug fixes:

  • Fix a kernel bug introduced when implementing SWA (sliding-window attention)

Known Problems:

  • gfx1100 target (Radeon RX 7000 Series) is moved back to experimental status
    due to accuracy issues. Triton compiler fixes are needed to restore the
    support status.
  • Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCM 7.0.
    This issue is under investigation.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Aug 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161754

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 008f831 with merge base 403a3a3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: rocm AMD GPU support for Pytorch oncall: distributed Add this issue/PR to distributed oncall triage queue labels Aug 29, 2025
@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Sep 3, 2025
@jithunnair-amd
Collaborator

jithunnair-amd commented Sep 3, 2025

rocm-mi300 workflow passed (except for one unrelated timeout): https://hud.pytorch.org/pytorch/pytorch/pull/161754?sha=0add8c2ad827f5562b94b22a1e17aa3d5092951d#rocm-mi300

rocm workflow passed: https://hud.pytorch.org/pytorch/pytorch/pull/161754?sha=0add8c2ad827f5562b94b22a1e17aa3d5092951d#rocm

trunk passed: https://hud.pytorch.org/pytorch/pytorch/pull/161754?sha=0add8c2ad827f5562b94b22a1e17aa3d5092951d#trunk

Latest commit 008f831 merely moves the message, so test results from the previous commit should be good enough. Merging now to allow sufficient time for internal builds to adjust to the aotriton 0.11b changes.

@jithunnair-amd
Collaborator

@jeffdaily please approve and merge

@jeffdaily
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 3, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

CMAKE_CACHE_ARGS
-DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
-DCMAKE_INSTALL_PREFIX:FILEPATH=${__AOTRITON_INSTALL_DIR}
CMAKE_ARGS
Contributor

This aotriton_build_from_source function is at least going to need -DHIP_PLATFORM=amd set as well, to avoid errors on Linux and Windows like https://github.com/ROCm/TheRock/actions/runs/17467234227/job/49606011053#step:11:66811

 [1902/7918] Performing configure step for 'aotriton_runtime'
CMake Error at /opt/python/cp313-cp313/lib/python3.13/site-packages/_rocm_sdk_devel/lib/cmake/hip/hip-config.cmake:144 (message):
  Unexpected HIP_PLATFORM:
Call Stack (most recent call first):
  CMakeLists.txt:64 (find_package)

The source for that hip-config.cmake is https://github.com/ROCm/rocm-systems/blob/2202dcfe806766804648a9f38de35f555351e7fa/projects/clr/hipamd/hip-config.cmake.in#L111-L121

if(HIP_PLATFORM STREQUAL "amd")
  set(HIP_RUNTIME "rocclr")
  set(HIP_COMPILER "clang")
  include( "${hip_LIB_INSTALL_DIR}/cmake/hip/hip-config-amd.cmake" )
elseif(HIP_PLATFORM STREQUAL "nvidia")
  set(HIP_RUNTIME "cuda")
  set(HIP_COMPILER "nvcc")
  include( "${hip_LIB_INSTALL_DIR}/cmake/hip/hip-config-nvidia.cmake" )
else()
  message(FATAL_ERROR "Unexpected HIP_PLATFORM: " ${HIP_PLATFORM})
endif()
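
A minimal sketch of the workaround suggested above, assuming the runtime is
configured through an ExternalProject-style call like the one in this diff.
The surrounding arguments are abbreviated; only the HIP_PLATFORM cache entry
is new, and this is a sketch rather than the actual aotriton.cmake change:

```cmake
# Sketch only: pin HIP_PLATFORM explicitly so hip-config.cmake takes its
# "amd" branch even when hipconfig-based auto-detection fails.
# Other ExternalProject_Add arguments (URL, build/install commands) omitted.
ExternalProject_Add(aotriton_runtime
  CMAKE_CACHE_ARGS
    -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
    -DCMAKE_INSTALL_PREFIX:FILEPATH=${__AOTRITON_INSTALL_DIR}
    -DHIP_PLATFORM:STRING=amd  # normally auto-detected; set here as a workaround
)
```

Pinning the value sidesteps the FATAL_ERROR branch shown above, at the cost of
hard-coding the AMD platform (which is what the PyTorch ROCm build targets anyway).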

I'm working on rebasing TheRock's downstream patch https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/patches/pytorch/main/pytorch/hipified/0002-Support-FLASH_ATTENTION-MEM_EFF_ATTENTION-via.-aotri.patch . I'll try to split it into baseline Linux fixes like that one and the deeper changes needed for Windows support.

Collaborator Author

@xinyazhang xinyazhang Sep 4, 2025

The real problem is

execute_process(COMMAND ${hip_HIPCONFIG_EXECUTABLE} --platform
      OUTPUT_VARIABLE HIP_PLATFORM
      OUTPUT_STRIP_TRAILING_WHITESPACE)

does not work as expected on Windows. Normally, setting HIP_PLATFORM is not necessary.

Collaborator Author

Normally setting HIP_PLATFORM is not necessary

In addition, it's not due to a missing GPU installation. I checked with the 7.0RC2 Docker image; CPU-only instances correctly report the platform as amd.

Contributor

Huh... we're seeing issues on both Linux and Windows there. I'll debug a bit. I see what you're saying - the logic in there should infer the HIP_PLATFORM CMake variable via hipconfig --platform.

Collaborator Author

The interesting part is that HIP_PLATFORM becomes mandatory even on Linux. In my experience this variable is always auto-configured.

Contributor

🤦

if("@HIP_INSTALLS_HIPCC@")

is getting templated as this in TheRock's builds:

if("OFF")
  if (WIN32)
    set_and_check(hip_HIPCC_EXECUTABLE "${hip_BIN_INSTALL_DIR}/hipcc.exe")
    set_and_check(hip_HIPCONFIG_EXECUTABLE "${hip_BIN_INSTALL_DIR}/hipconfig.exe")
  else()
    set_and_check(hip_HIPCC_EXECUTABLE "${hip_BIN_INSTALL_DIR}/hipcc")
    set_and_check(hip_HIPCONFIG_EXECUTABLE "${hip_BIN_INSTALL_DIR}/hipconfig")
  endif()
endif()
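
For context, CMake's if() treats the literal string OFF as a false constant
even when quoted, so once @HIP_INSTALLS_HIPCC@ is templated to OFF the whole
guarded block is skipped and hip_HIPCONFIG_EXECUTABLE is never defined, which
is why the later hipconfig --platform detection fails. A standalone sketch
(runnable with cmake -P):

```cmake
# if("OFF") is false: quoted "OFF" is still a CMake false constant,
# exactly what happens when if("@HIP_INSTALLS_HIPCC@") templates to if("OFF").
if("OFF")
  set(hip_HIPCONFIG_EXECUTABLE "/opt/rocm/bin/hipconfig")  # never reached
endif()

if(NOT DEFINED hip_HIPCONFIG_EXECUTABLE)
  message(STATUS "hip_HIPCONFIG_EXECUTABLE undefined; platform detection would fail")
endif()
```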

Contributor

Filed ROCm/TheRock#1402 to solve that. What you have here (omitting -DHIP_PLATFORM=amd) should be fine then.

We have other issues downstream to triage though: ROCm/TheRock#1401 (comment)

Contributor

We're investigating that and other failures on Linux when building using TheRock at ROCm/TheRock#1408

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.11b:

* Invoke AITER Assembly kernels on gfx942/gfx950 when inputs meet requirements
  - AITER ASM kernels deliver over 500TFLOPS training performance. See
    [AOTriton 0.11b Release Page](https://github.com/ROCm/aotriton/releases/tag/0.11b) for more
    details.
* Now returns natural based `logsumexp` tensor, matching CUDA's behavior
  - PR pytorch#156903 is reverted in this PR as well since it is not needed anymore.
* Enables `CausalVariant.LOWER_RIGHT`

The build system changes drastically along with new packaging scheme of
AOTriton 0.11

* AOTriton 0.11 packs GPU images separately from AOTriton runtime
* `aotriton.cmake` now selectively downloads image packs according to
  `PYTORCH_ROCM_ARCH`
* `aotriton.cmake` now only uses a pre-compiled runtime library that exactly
  matches the ROCm version in the build environment. For PyTorch builds with
  ROCm versions not listed in the file, the build process builds the AOTriton
  runtime from source, without GPU images
  - This avoids any further ABI breaks like ROCM 6.4 -> 7.0
  - recursive git clone is disabled since building AOTriton runtime does not
    require submodules.

Bug fixes:

* Fix a kernel bug introduced when implementing SWA

Known Problems:

* gfx1100 target (Radeon RX 7000 Series) is moved back to experimental status
  due to accuracy issues. Triton compiler fixes are needed to restore the
  support status.
* Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCM 7.0.
  This issue is under investigation.

Pull Request resolved: pytorch#161754
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
xinyazhang added a commit to ROCm/pytorch that referenced this pull request Sep 29, 2025
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Oct 7, 2025
Fixes: pytorch#163958

Cherry-pick pytorch#161754
Cherry-pick pytorch#162330
Cherry-pick pytorch#163373
Cherry-pick pytorch#163745

Note TF32 support is still being plagued by `HIPBLASLT_ALLOW_TF32`,
which should be handled by another PR due to its complexity.

---------

Co-authored-by: Aaryaman Vasishta <aaryaman.vasishta@amd.com>
Co-authored-by: Scott Todd <scott.todd0@gmail.com>

Labels

ci-no-td Do not run TD on this PR ciflow/trunk Trigger trunk jobs on your pull request keep-going Don't stop on first failure, keep running tests until the end Merged module: inductor module: rocm AMD GPU support for Pytorch oncall: distributed Add this issue/PR to distributed oncall triage queue open source release notes: rocm topic: performance topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
