[ROCm] Limit number of values per thread for reductions on three dimensions by doru1004 · Pull Request #159652 · pytorch/pytorch · GitHub

Conversation

@doru1004
Contributor

@doru1004 doru1004 commented Aug 1, 2025

In the current implementation of three-dimensional reductions for AMD GPUs, the number of values per thread is unbounded and can reach the hundreds of thousands for certain tensors, which is bad for performance. This patch fixes the issue by increasing parallelism and thereby lowering the number of values per thread to a reasonable limit, i.e. fewer than 2048 values per thread. The performance gains can be between 10x-17x for examples where the number of values per thread was originally very high.
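For context, the fix amounts to capping the per-thread workload by growing the number of CTAs (thread blocks) that cooperate on each output element. The sketch below is a hypothetical Python model of that logic, not the actual C++ code in PyTorch's reduction kernel: the 256-thread block size, the function names, and the ceil-division layout are assumptions; only the 2048 cap comes from this PR.

```python
MAX_VALUES_PER_THREAD = 2048  # empirically chosen threshold from this PR


def limit_values_per_thread(reduction_size: int, ctas_per_output: int,
                            threads_per_cta: int = 256) -> int:
    """Double ctas_per_output until each thread reduces at most
    MAX_VALUES_PER_THREAD values (hypothetical model, not kernel code)."""

    def values_per_thread(ctas: int) -> int:
        # ceil-divide the reduction among all threads cooperating on one output
        total_threads = ctas * threads_per_cta
        return -(-reduction_size // total_threads)

    while values_per_thread(ctas_per_output) > MAX_VALUES_PER_THREAD:
        ctas_per_output *= 2
    return ctas_per_output


# With the (5079670, 128) reproducer shape below, reducing over dim 1 means
# each output element sums 5079670 values; under these assumptions the model
# grows ctas_per_output from 1 to 16 to get under the cap.
print(limit_values_per_thread(5_079_670, 1))
```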

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@doru1004 doru1004 requested review from eqy and syed-ahmed as code owners August 1, 2025 16:35
@pytorch-bot

pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159652

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 70792b5 with merge base 1465757:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the `release notes: cuda` label Aug 1, 2025
Contributor

@petrex petrex left a comment


Question: Was the choice of 2048 as the threshold for "values per thread" purely heuristic? It would be helpful to add a comment or reference explaining why this value was chosen and whether it is empirically optimal.

Contributor

@petrex petrex left a comment


Another question: Is there an upper bound for `config.ctas_per_output *= 2;`?

@jerrymannil
Contributor

jerrymannil commented Aug 1, 2025

Reproducer:

import time
import torch

shapes = [(1, 2, 3, 420, 648, 128),
    (1, 2, 3, 420, 648, 128),
    (5079670, 128)
]

dims = [(3, 4),
    (-3, -2),
    (1,)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(20):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()

    start_time = time.time()
    for _ in range(100):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/100
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")

Before:
Avg time for shape (1, 2, 3, 420, 648, 128): 4408.10 us
Avg time for shape (1, 2, 3, 420, 648, 128): 4428.89 us
Avg time for shape (5079670, 128): 1458.86 us

After:
Avg time for shape (1, 2, 3, 420, 648, 128): 223.73 us
Avg time for shape (1, 2, 3, 420, 648, 128): 218.85 us
Avg time for shape (5079670, 128): 1461.55 us

@pruthvistony pruthvistony added the `topic: not user facing`, `rocm`, `ciflow/rocm`, `ciflow/inductor-rocm`, `ciflow/rocm-mi300`, and `ciflow/periodic-rocm-mi300` labels and removed the `release notes: cuda` label Aug 1, 2025
@pytorch-bot

pytorch-bot bot commented Aug 1, 2025

To add the ciflow label ciflow/periodic-rocm-mi300 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the `ciflow/periodic-rocm-mi300` label Aug 1, 2025
@pruthvistony pruthvistony added the `ciflow/periodic-rocm-mi300` label Aug 1, 2025
@doru1004
Contributor Author

doru1004 commented Aug 4, 2025

Question: Was the choice of 2048 as the threshold for "values per thread" purely heuristic? It would be helpful to add a comment or reference explaining why this value was chosen and whether it is empirically optimal.

It was indeed empirically determined. I'll add a comment.

@doru1004
Contributor Author

doru1004 commented Aug 4, 2025

Another question: Is there an upper bound for `config.ctas_per_output *= 2;`?

Based on the previous semantics, there does not appear to be one.
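Although no explicit cap is placed on `config.ctas_per_output`, the doubling is implicitly bounded: each doubling roughly halves the values-per-thread count, so the loop can run at most about log2(initial_values_per_thread / 2048) times before the cap is reached. The snippet below is an illustrative back-of-the-envelope model of that bound, not PyTorch code; the function name and the assumption that each doubling halves the workload are hypothetical.

```python
import math


def max_doublings(initial_values_per_thread: int, cap: int = 2048) -> int:
    """Upper bound on how many times ctas_per_output doubles, assuming each
    doubling halves the per-thread workload (illustrative model only)."""
    if initial_values_per_thread <= cap:
        return 0
    return math.ceil(math.log2(initial_values_per_thread / cap))


# A workload starting at ~300,000 values per thread needs only 8 doublings
# to get under the 2048 cap (300000 / 2**8 is about 1172).
print(max_doublings(300_000))
```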

@pytorch-bot pytorch-bot bot removed the `ciflow/rocm`, `ciflow/inductor-rocm`, `ciflow/rocm-mi300`, and `ciflow/periodic-rocm-mi300` labels Aug 4, 2025
@janeyx99 janeyx99 added the `triaged` label Aug 4, 2025
@doru1004 doru1004 changed the title [AMDGPU] Limit number of values per thread for reductions on three dimensions [ROCm] Limit number of values per thread for reductions on three dimensions Aug 5, 2025
@pytorch-bot pytorch-bot bot added the `module: rocm` label Aug 5, 2025
jerrymannil added a commit to ROCm/pytorch that referenced this pull request Aug 6, 2025
…nsions (#2460)


cherry-pick of pytorch#159652
okakarpa pushed a commit to ROCm/pytorch that referenced this pull request Aug 6, 2025
…nsions (#2460)


cherry-pick of pytorch#159652
@jerrymannil
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the `ciflow/trunk` label Aug 12, 2025
@pytorch-bot

pytorch-bot bot commented Aug 12, 2025

To add the ciflow label ciflow/trunk please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the `ciflow/trunk` label Aug 12, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
…nsions (#159652)


Pull Request resolved: #159652
Approved by: https://github.com/jeffdaily
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Aug 15, 2025
…nsions (pytorch#159652)


Pull Request resolved: pytorch#159652
Approved by: https://github.com/jeffdaily
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
…nsions (#159652)


Pull Request resolved: #159652
Approved by: https://github.com/jeffdaily
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
…nsions (pytorch#159652)


Pull Request resolved: pytorch#159652
Approved by: https://github.com/jeffdaily
jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request Sep 5, 2025
…nsions (pytorch#159652)


Pull Request resolved: pytorch#159652
Approved by: https://github.com/jeffdaily
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…nsions (pytorch#159652)


Pull Request resolved: pytorch#159652
Approved by: https://github.com/jeffdaily

Merged

Labels

`module: rocm`, `open source`, `rocm`, `topic: not user facing`, `triaged`
