[ROCm] logsumexp on ROCm needs scaling back to natural base. by xinyazhang · Pull Request #156903 · pytorch/pytorch · GitHub

Conversation

@xinyazhang
Collaborator

@xinyazhang xinyazhang commented Jun 26, 2025

Fixes #156012

This is a temporary solution that makes context parallelism work before the logsumexp behavior change lands in AOTriton (see the sketch below for the scaling relationship).

After discussion, we are not going to release AOTriton 0.10.1 to fix this because:

  • Even if the interface is not changed, changing the behavior of the returned logsumexp tensor should still be considered an ABI break. Such changes do not fall into the "ABI compatible" category and should be postponed to the next release.
  • AOTriton 0.11 is scheduled to be released before the end of July, which is less than five weeks away.
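For reference, a minimal sketch of the scaling relationship involved (illustrative tensors only; AOTriton < 0.11 returns the base-2 value, and this PR scales it back to the natural base by multiplying by ln 2):

```python
import math

import torch

# Illustrative attention scores (QK^T / sqrt(d)); shapes are arbitrary here.
scores = torch.randn(2, 8, 64, 64)

lse_e = torch.logsumexp(scores, dim=-1)  # natural base, what context parallelism expects
lse_2 = lse_e / math.log(2)              # base-2 value, what AOTriton < 0.11 returns

# The workaround in this PR: scale the base-2 value back to the natural base.
assert torch.allclose(lse_2 * 0.6931471805599453, lse_e, atol=1e-6)
```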

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@pytorch-bot

pytorch-bot bot commented Jun 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156903

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit cfa0de7 with merge base 9894d43:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: rocm AMD GPU support for Pytorch oncall: distributed Add this issue/PR to distributed oncall triage queue labels Jun 26, 2025
@xinyazhang xinyazhang requested review from XilunWu and fegin and removed request for fegin June 26, 2025 00:04
@xinyazhang
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jun 26, 2025
@functionstackx
Contributor

thanks for the fix @xinyazhang

can a unit test be added for the previous regression in the future?

probably the OG repro script would be a good unit test as it doesn't take that much time to run: #156012

@xinyazhang
Collaborator Author

can a unit test be added for the previous regression in the future?

If you mean the logsumexp tensor's behavior alignment with the CUTLASS backend, it will be part of the AOTriton 0.11 integration PR.

We need to test the behavior change in AOTriton's own unit tests first.
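For context, a rough sketch of what such a regression check could look like at the PyTorch level (an illustration only, not the planned AOTriton UT; it assumes the private op `torch.ops.aten._scaled_dot_product_flash_attention` exposes `logsumexp` as its second output, an internal detail that may change):

```python
import math

import torch

def check_sdpa_logsumexp_is_natural_base(device="cuda"):
    torch.manual_seed(0)
    b, h, s, d = 2, 4, 128, 64
    q, k, v = (torch.randn(b, h, s, d, device=device, dtype=torch.float16)
               for _ in range(3))
    # Private op; the second return value is the logsumexp tensor.
    out, lse, *_ = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)
    # Natural-base reference computed directly from the attention scores.
    scores = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(d)
    lse_ref = torch.logsumexp(scores, dim=-1)
    # Loose tolerances since the kernel accumulates in reduced precision.
    torch.testing.assert_close(lse.float(), lse_ref, rtol=2e-2, atol=2e-2)
```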

@xinyazhang xinyazhang marked this pull request as ready for review June 26, 2025 20:28
@xinyazhang
Collaborator Author

cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable) is unstable ATM

@functionstackx
Contributor

can a unit test be added for the previous regression in the future?

If you mean the logsumexp tensor's behavior alignment with the CUTLASS backend, it will be part of the AOTriton 0.11 integration PR.

I think I was pointing more at a general unit test checking that context-parallel SDPA has the same numerics as single-GPU SDPA for both NVIDIA & AMD.

jeffdaily previously approved these changes Jun 30, 2025
@jeffdaily jeffdaily changed the title [ROCM] logsumexp on ROCM needs scaling back to natural base. [ROCm] logsumexp on ROCm needs scaling back to natural base. Jun 30, 2025
@pytorch-bot pytorch-bot bot added ciflow/inductor ciflow/rocm Trigger "default" config CI on ROCm labels Jun 30, 2025
XilunWu previously approved these changes Jul 2, 2025
Contributor

@XilunWu XilunWu left a comment


overall looks good to me!

Comment on lines 571 to 473
need_scaling = True
# Note: it is possible that CK is selected but not compiled in the binary.
if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
    # Unsure about CK's behavior, keep logsumexp untouched
    need_scaling = False
if need_scaling:
    logsumexp *= 0.6931471805599453
Contributor


is this equivalent to:

if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
    logsumexp *= 0.6931471805599453

Collaborator Author

@xinyazhang xinyazhang Jul 2, 2025


if not(_is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND):
    logsumexp *= 0.6931471805599453

This is the equivalent

Collaborator Author

@xinyazhang xinyazhang Jul 2, 2025


I used the more verbose form to make the logic easier to read.

xinyazhang added a commit to ROCm/aotriton that referenced this pull request Jul 11, 2025
# Overview 

Previously we were using a base-2 logsumexp (L) tensor between the forward and
backward passes to eliminate unnecessary conversions.

However, this causes quite a few problems:

* PyTorch's Context Parallelism system requires a natural-base (e-based) L
  tensor
  + See pytorch/pytorch#156012 for the bug report and
    pytorch/pytorch#156903 for a temporary solution.
* The AITER ASM backward kernel uses a natural-base L tensor

# Major Changes

* [kernel] Return a natural-base L tensor from the forward kernel, and translate
  it to base-2 in the backward kernel when loading
* [test] Add `test_logsumexp_scaling` to confirm the scaling is correct.
* [build] Set `TRITON_STORE_BINARY_ONLY=1` to avoid caching intermediate files.
  This massively reduces the size of the `triton-cache` directory.
* [compiler] Bump to the latest Triton compiler to avoid the updated kernel
  causing a GPU segmentation fault in the UT
  `Split-False-l1-dtype2-0.5-CausalOff-64-64-hdim160-5-3` on MI300X
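A small sketch of the base change described above (plain PyTorch, not AOTriton's kernel code; the `exp2`-based softmax recomputation shown here is the usual reason such kernels keep L in base 2, stated as an assumption):

```python
import torch

LOG2E = 1.4426950408889634  # log2(e) = 1 / ln(2)

scores = torch.randn(2, 4, 128, 128)     # illustrative attention scores
lse_e = torch.logsumexp(scores, dim=-1)  # natural-base L, what the forward kernel now returns

# What the backward kernel reconstructs when loading L:
lse_2 = lse_e * LOG2E
# The softmax can then be recomputed with exp2 instead of exp:
p = torch.exp2(scores * LOG2E - lse_2.unsqueeze(-1))
assert torch.allclose(p, torch.softmax(scores, dim=-1), atol=1e-6)
```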
@jithunnair-amd
Collaborator

@pytorchbot merge -f "CI failures unrelated"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@xinyazhang xinyazhang force-pushed the xinyazhang/issue-156012 branch from 4a2fc29 to 21e735c Compare July 21, 2025 17:00
@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm labels Jul 21, 2025
@xinyazhang xinyazhang requested a review from jeffdaily July 21, 2025 17:00
@jeffdaily
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added ciflow/trunk Trigger trunk jobs on your pull request ciflow/inductor ciflow/rocm Trigger "default" config CI on ROCm labels Jul 22, 2025
@pytorch-bot

pytorch-bot bot commented Jul 22, 2025

To add the ciflow label ciflow/rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

pytorch-bot bot commented Jul 22, 2025

To add the ciflow label ciflow/inductor please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor labels Jul 22, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

xinyazhang added a commit to ROCm/pytorch that referenced this pull request Aug 28, 2025
xinyazhang added a commit to ROCm/pytorch that referenced this pull request Aug 29, 2025
pytorchmergebot pushed a commit that referenced this pull request Sep 3, 2025
Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.11b:

* Invoke AITER Assembly kernels on gfx942/gfx950 when inputs meet requirements
  - AITER ASM kernels deliver over 500 TFLOPS of training performance. See the
    [AOTriton 0.11b Release Page](https://github.com/ROCm/aotriton/releases/tag/0.11b) for more
    details.
* Now returns a natural-base `logsumexp` tensor, matching CUDA's behavior
  - PR #156903 is reverted in this PR as well since it is no longer needed.
* Enables `CausalVariant.LOWER_RIGHT`

The build system changes drastically along with the new packaging scheme of
AOTriton 0.11:

* AOTriton 0.11 packs GPU images separately from the AOTriton runtime
* `aotriton.cmake` now selectively downloads image packs according to
  `PYTORCH_ROCM_ARCH`
* `aotriton.cmake` now only uses a pre-compiled runtime library that exactly
  matches the ROCm version in the build environment. For PyTorch builds with
  ROCm versions not listed in the file, the build process will build the
  AOTriton runtime (without GPU images) from source
  - This avoids any further ABI breaks like the ROCm 6.4 -> 7.0 one
  - Recursive git clone is disabled since building the AOTriton runtime does
    not require submodules.

Bug fixes:

* Fix a kernel bug introduced when implementing SWA

Known Problems:

* gfx1100 target (Radeon RX 7000 Series) is moved back to experimental status
  due to accuracy issues. Triton compiler fixes are needed to restore the
  support status.
* Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCm 7.0.
  This issue is under investigation.

Pull Request resolved: #161754
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
xinyazhang added a commit to ROCm/pytorch that referenced this pull request Sep 29, 2025

Labels

ci-no-td Do not run TD on this PR
ciflow/trunk Trigger trunk jobs on your pull request
Merged
module: rocm AMD GPU support for Pytorch
oncall: distributed Add this issue/PR to distributed oncall triage queue
open source
Reverted
topic: not user facing topic category
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ROCm] BF16 Context Parallelism MI300X Not Numerically Accurate