KEMBAR78
[ROCm] Improve reduction sum performance by jerrymannil · Pull Request #160466 · pytorch/pytorch · GitHub
Skip to content

Conversation

@jerrymannil
Copy link
Contributor

@jerrymannil jerrymannil commented Aug 12, 2025

  • Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

Reproducer:

import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")

Before (MI300X):
Avg time for shape (5079670, 128): 1629.99 us

After (MI300X)
Avg time for shape (5079670, 128): 1008.59 us

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

* Use input vectorization for reduction_on_fastest_striding_dimension when dim >= 0
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160466

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bff4a8a with merge base 4d419a7 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: rocm AMD GPU support for Pytorch release notes: cuda release notes category labels Aug 12, 2025
jerrymannil added a commit to ROCm/pytorch that referenced this pull request Aug 12, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 0

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466
Copy link
Contributor

@petrex petrex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@jeffdaily jeffdaily added release notes: rocm mandatorylabel ciflow/rocm Trigger "default" config CI on ROCm and removed release notes: cuda release notes category labels Aug 12, 2025
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Aug 13, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension
when dim0 >= 0

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466

Fixes SWDEV-546136
@jeffdaily
Copy link
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 13, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

jerrymannil added a commit to ROCm/pytorch that referenced this pull request Aug 13, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466

Fixes SWDEV-546136
@jerrymannil jerrymannil deleted the patch-1 branch August 13, 2025 18:49
dhonnappa-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 13, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension
when dim0 >= 0

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466

Fixes SWDEV-546136
chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

Pull Request resolved: #160466
Approved by: https://github.com/petrex, https://github.com/jeffdaily
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Aug 15, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466

Fixes SWDEV-546136
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

Pull Request resolved: #160466
Approved by: https://github.com/petrex, https://github.com/jeffdaily
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

Pull Request resolved: pytorch#160466
Approved by: https://github.com/petrex, https://github.com/jeffdaily
jerrymannil added a commit to ROCm/pytorch that referenced this pull request Sep 5, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466

Fixes SWDEV-546136
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

Pull Request resolved: pytorch#160466
Approved by: https://github.com/petrex, https://github.com/jeffdaily
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rocm Trigger "default" config CI on ROCm ciflow/trunk Trigger trunk jobs on your pull request Merged module: rocm AMD GPU support for Pytorch open source release notes: rocm mandatorylabel

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants