[ROCm] No-fence global reduce by amd-hhashemi · Pull Request #161180 · pytorch/pytorch · GitHub

Conversation

@amd-hhashemi
Contributor

@amd-hhashemi amd-hhashemi commented Aug 21, 2025

This change removes the need for fences in global_reduce by converting the stores to reduce_buffer[] into atomics+return. This is crucial for performance on architectures with split caches (e.g. MI300), where fences are inherently costly.
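A minimal sketch of the idea follows (illustration only, not the actual Reduce.cuh change; store_committed and kIntsPerVal are made-up names, and it assumes the value type is at least int-sized): each store to the reduce buffer becomes one or more int-wide atomic exchanges, and the atomic-with-return is what commits the write to global memory, so no __threadfence() is needed afterwards.

```
// Illustrative sketch only (not the PR's exact Reduce.cuh code).
// A plain store followed by __threadfence() is replaced by int-wide
// atomic exchanges; an atomic with a return commits the write to
// global memory, so no fence is required afterwards.
template <typename T>
__device__ void store_committed(T* dst, T value) {
  constexpr int kIntsPerVal = sizeof(T) / sizeof(int);  // assumes sizeof(T) >= sizeof(int)
  int src[kIntsPerVal];
  memcpy(src, &value, sizeof(T));                       // view the value as raw ints
  int* out = reinterpret_cast<int*>(dst);
  #pragma unroll
  for (int i = 0; i < kIntsPerVal; ++i) {
    atomicExch(&out[i], src[i]);                        // atomic store-with-return
  }
  // No __threadfence() here: the atomic exchanges already committed the data.
}
```

The value returned by atomicExch is discarded; the point is that the store is expressed as an atomic with a return rather than an ordinary store plus a fence.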

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@pytorch-bot

pytorch-bot bot commented Aug 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161180

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 59333da with merge base 7376111:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jerrymannil
Contributor

This fix provides much better perf than the acquire/release fence solution in #160979.
The fence operation is much more heavyweight than atomics on AMD GPUs.

Reproducer:

```
import torch

shapes = [(2, 896, 59, 91),
]

dims = [(2, 3),
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)
    for _ in range(20):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(100):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    end_evt.record()
    torch.cuda.synchronize()
    print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")

Results (MI300X):

Before:
Avg time for shape (2, 896, 59, 91): 82.13 us

After:
Avg time for shape (2, 896, 59, 91): 44.36 us

jeffdaily previously approved these changes Aug 21, 2025
// Here we preempt need for fences by committing stores to global memory.
// We do so by converting the stores to atomics with a return.
int constexpr num_int_per_val = sizeof(value)/sizeof(int);
CUDA_KERNEL_ASSERT(num_int_per_val>=1);
Collaborator

Since num_int_per_val is a constexpr, we can use a static_assert here.

Suggested change
CUDA_KERNEL_ASSERT(num_int_per_val>=1);
static_assert(num_int_per_val>=1);

Contributor Author

Removed that assert; small value sizes are now handled instead.
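(For context, one hypothetical way to commit a value smaller than an int with an atomic-with-return is a CAS on the containing 32-bit word; this is only a sketch and not necessarily how the PR handles it.)

```
// Hypothetical sketch only; not necessarily how this PR handles sub-int
// value sizes. Commits a 16-bit value via atomicCAS on the aligned 32-bit
// word that contains it, so the neighbouring 16 bits are preserved.
#include <cstdint>

__device__ void store_committed_u16(uint16_t* dst, uint16_t bits) {
  uintptr_t addr = reinterpret_cast<uintptr_t>(dst);
  auto* word = reinterpret_cast<unsigned int*>(addr & ~uintptr_t{3});   // containing aligned word
  unsigned int shift = (addr & 2u) ? 16u : 0u;                          // position within the word
  unsigned int mask = 0xFFFFu << shift;
  unsigned int old = *word;
  unsigned int assumed;
  do {
    assumed = old;
    unsigned int updated = (assumed & ~mask) | (static_cast<unsigned int>(bits) << shift);
    old = atomicCAS(word, assumed, updated);                            // atomic with return
  } while (old != assumed);
}
```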

@jeffdaily jeffdaily added release notes: rocm mandatorylabel ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 and removed release notes: cuda release notes category labels Aug 21, 2025
@jeffdaily jeffdaily changed the title No-fence global reduce [ROCm] No-fence global reduce Aug 21, 2025
@pytorch-bot pytorch-bot bot added module: rocm AMD GPU support for Pytorch and removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Aug 21, 2025
@jeffdaily jeffdaily added ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Aug 22, 2025
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Aug 22, 2025
cherry-pick of pytorch#160979
Less-performant fix until pytorch#161180
is finalized

* The global reduction path in the reduction kernel currently has two threadfence operations
* The first threadfence is executed by all threads in all blocks, whereas the second threadfence is only run by threads in a single block
* For AMD GPUs, threadfence is a heavyweight operation, especially when run by all threads in the system (due to cross-XCD synchronization)
* Using fine-grained fences therefore gives a significant performance boost on AMD GPUs
* We do a release fence when threads write to the reduce buffer in global memory, and then an acquire fence when threads read from the reduce buffer

Co-author: @amd-hhashemi, @jeffdaily 

**Reproducer**:
```
import time
import torch

shapes = [(2, 896, 59, 91),
]

dims = [(2, 3),
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)
    for _ in range(20):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(100):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    end_evt.record()
    torch.cuda.synchronize()
    print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
```

Fixes SWDEV-545710
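For reference, the release/acquire scheme described in that cherry-picked commit can be sketched roughly as below, using libcu++ scoped fences; the actual pytorch#160979 patch may use different primitives, publish_partial/consume_partial are made-up names, and block-completion signalling is omitted.

```
// Rough sketch of the fine-grained fence alternative (assumes libcu++;
// not the actual pytorch#160979 code). A release fence follows the write
// of a block's partial result; an acquire fence precedes the read.
#include <cuda/atomic>

__device__ void publish_partial(float* reduce_buffer, int slot, float partial) {
  reduce_buffer[slot] = partial;
  // Release ordering at device scope instead of a full __threadfence().
  cuda::atomic_thread_fence(cuda::std::memory_order_release, cuda::thread_scope_device);
}

__device__ float consume_partial(const float* reduce_buffer, int slot) {
  // Acquire ordering so the read observes the released store.
  cuda::atomic_thread_fence(cuda::std::memory_order_acquire, cuda::thread_scope_device);
  return reduce_buffer[slot];
}
```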
dhonnappa-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 22, 2025
@amd-hhashemi amd-hhashemi reopened this Aug 25, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/rocm Trigger "default" config CI on ROCm label Aug 25, 2025
@pytorch-bot pytorch-bot bot dismissed jeffdaily’s stale review August 25, 2025 19:08

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@pytorch-bot pytorch-bot bot removed the ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 label Aug 25, 2025
@pytorchmergebot
Collaborator

Successfully rebased no_fnc_glb_rdc onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout no_fnc_glb_rdc && git pull --rebase)

@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300 labels Aug 26, 2025
@jeffdaily jeffdaily added ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300 labels Aug 26, 2025
@jerrymannil
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14)

Details for Dev Infra team Raised by workflow job

@jeffdaily
Collaborator

@pytorchbot merge -f "unrelated macos build failure; all other CI including ciflow/trunk is passing"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

jerrymannil added a commit to ROCm/pytorch that referenced this pull request Aug 26, 2025
This change removes need for fences in global_reduce by converting the
stores to reduce_buffer[] into atomics+return. This is crucial for perf
in architectures with split caches (e.g. MI300), where fences are
inherently costly.

cherry-pick of pytorch#161180
dhonnappa-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 27, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
This change removes need for fences in global_reduce by converting the stores to reduce_buffer[] into atomics+return. This is crucial for perf in architectures with split caches (e.g. MI300), where fences are inherently costly.

Pull Request resolved: pytorch#161180
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>

Labels

ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300 ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/trunk Trigger trunk jobs on your pull request Merged module: rocm AMD GPU support for Pytorch open source release notes: rocm mandatorylabel
