[ROCm] Unroll loads in global_reduce by amd-hhashemi · Pull Request #161181 · pytorch/pytorch · GitHub

Conversation

@amd-hhashemi
Contributor

@amd-hhashemi commented Aug 21, 2025

@pytorch-bot

pytorch-bot bot commented Aug 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161181

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4eb3f62 with merge base 34358f3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the release notes: cuda label Aug 21, 2025
@jeffdaily added the ciflow/rocm and ciflow/rocm-mi300 labels Aug 21, 2025
@pytorch-bot removed the ciflow/rocm and ciflow/rocm-mi300 labels Aug 21, 2025
@jerrymannil
Contributor

Reproducer:

import torch

shapes = [(2, 896, 59, 91),
]

dims = [(2, 3),
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)
    # Warmup: run the kernel a few times before timing.
    for _ in range(20):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    torch.cuda.synchronize()

    # Time 100 iterations with CUDA events; elapsed_time returns milliseconds.
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(100):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    end_evt.record()
    torch.cuda.synchronize()
    # Convert ms -> us and average over the 100 iterations.
    print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
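As an aside on why the reproducer converts to channels_last: the conversion changes the tensor's strides so that the reduced dims (2, 3) are no longer the fastest-varying ones, which exercises a different (strided) load pattern in the reduction kernel. A quick stride check, runnable on CPU (this is only an illustration added here, not part of the original reproducer):

```python
import torch

x = torch.randn(2, 896, 59, 91)               # NCHW, contiguous
print(x.stride())                              # (4810624, 5369, 91, 1)

xc = x.to(memory_format=torch.channels_last)   # NHWC physical layout
print(xc.stride())                             # (4810624, 1, 81536, 896)
```

In the contiguous case, dims 2 and 3 have strides 91 and 1; after the channels_last conversion they have strides 81536 and 896, so summing over them walks memory with large strides.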

Results (MI300X):

Before:
Avg time for shape (2, 896, 59, 91): 82.13 us

After:
Avg time for shape (2, 896, 59, 91): 72.47 us
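For context on the optimization the PR title names, the effect of unrolling the load loop in a reduction can be sketched in pure Python. This is only an illustration of the accumulator-splitting idea with made-up function names, not the actual HIP/C++ kernel code: issuing several independent loads per iteration into separate accumulators shortens the serial dependency chain, so the hardware can keep more memory requests in flight.

```python
def reduce_naive(xs):
    # One load per iteration; every add depends on the previous result.
    acc = 0.0
    for v in xs:
        acc += v
    return acc

def reduce_unroll4(xs):
    # Four independent accumulators: the four loads in each iteration
    # carry no dependency on each other, so their latency can overlap.
    a0 = a1 = a2 = a3 = 0.0
    n = len(xs) - len(xs) % 4
    for i in range(0, n, 4):
        a0 += xs[i]
        a1 += xs[i + 1]
        a2 += xs[i + 2]
        a3 += xs[i + 3]
    for i in range(n, len(xs)):  # leftover tail elements
        a0 += xs[i]
    return (a0 + a1) + (a2 + a3)

xs = list(range(1, 11))
print(reduce_naive(xs), reduce_unroll4(xs))  # 55.0 55.0
```

The transformation is result-preserving (up to floating-point reassociation); the measured win above comes from the extra memory-level parallelism on the GPU, which this CPU sketch does not itself demonstrate.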

@jeffdaily changed the title from "Unroll loads in global_reduce" to "[ROCm] Unroll loads in global_reduce" Aug 21, 2025
@pytorch-bot added the module: rocm and ciflow/rocm labels Aug 21, 2025
@jeffdaily added the ciflow/trunk, ciflow/inductor-rocm, ciflow/rocm-mi300, and ciflow/periodic-rocm-mi300 labels and removed the release notes: cuda label Aug 21, 2025
@pytorch-bot removed the ciflow/rocm-mi300, ciflow/inductor-rocm, ciflow/trunk, and ciflow/periodic-rocm-mi300 labels Aug 21, 2025
@pytorch deleted four comments from pytorch-bot Aug 21, 2025
@jeffdaily added the ciflow/periodic-rocm-mi300 label Aug 21, 2025
@pruthvistony
Collaborator

@amd-hhashemi, please check the lint failure.

@pytorch-bot removed the ciflow/trunk, ciflow/rocm, ciflow/inductor-rocm, ciflow/rocm-mi300, and ciflow/periodic-rocm-mi300 labels Aug 22, 2025
@jeffdaily added the ciflow/trunk, ciflow/rocm, ciflow/inductor-rocm, ciflow/rocm-mi300, and ciflow/periodic-rocm-mi300 labels Aug 22, 2025
@jeffdaily
Collaborator

> @amd-hhashemi, please check the lint failure.

I fixed it.

jerrymannil added a commit to ROCm/pytorch that referenced this pull request Aug 22, 2025
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Aug 22, 2025
dhonnappa-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 22, 2025
@jeffdaily
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Pull Request resolved: pytorch#161181
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony

Co-authored-by: Jeff Daily <jeff.daily@amd.com>

Labels

ciflow/inductor-rocm, ciflow/periodic-rocm-mi300, ciflow/rocm, ciflow/rocm-mi300, ciflow/trunk, Merged, module: rocm, open source, release notes: rocm


6 participants