[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config by hongxiayang · Pull Request #135397 · pytorch/pytorch · GitHub

Conversation


@hongxiayang (Collaborator) commented on Sep 6, 2024

Fixes #132964

This change optimizes torch.sum() performance on the ROCm platform by increasing max_values_per_thread in setReduceConfig().
With more values per thread, the reduction launches fewer thread blocks, which improves performance for large tensors.
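To illustrate with hypothetical launch parameters: for the 2^30-element test tensor below, 512 threads per block and max_values_per_thread = 256 would yield roughly 2^30 / (512 × 256) = 8192 blocks, while a cap of 1024 values per thread cuts that to 2048 blocks, each streaming a longer contiguous chunk of memory.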

Test:
Tested on MI300X and H100. For the test case below, MI300X bandwidth improved from ~1690 GByte/s to 3205 GByte/s, slightly better than H100 (3136 GByte/s). Other tensor sizes were also tested and showed improvements.

```python
import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')

# Median kernel time in milliseconds for a full reduction.
ms = do_bench(lambda: x.sum(dim=-1))

# Total bytes the kernel must read, in (decimal) GBytes.
bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)

time_s = ms / 1000

# Effective memory bandwidth in GByte/s.
bw_per_second = bandwidth_gbyte / time_s

print(bw_per_second)
```
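For context, the change is confined to the reduction launch heuristics in aten/src/ATen/native/cuda/Reduce.cuh. A minimal sketch of the idea, with hypothetical values (the actual constants and surrounding code in setReduceConfig() may differ):

```cpp
// Sketch only: gate a larger per-thread workload behind the ROCm build flag,
// so each thread accumulates more elements and fewer blocks are launched.
#ifdef USE_ROCM
constexpr int max_values_per_thread = 1024;  // hypothetical ROCm value
#else
constexpr int max_values_per_thread = 256;   // hypothetical CUDA value
#endif
```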

Co-author: @carlobertolli

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo


pytorch-bot bot commented Sep 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135397

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 6cd7c04 with merge base de74aaf:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the labels ciflow/rocm (Trigger "default" config CI on ROCm), module: rocm (AMD GPU support for PyTorch), and release notes: cuda (release notes category) on Sep 6, 2024
@hongxiayang hongxiayang changed the title [ROCm] slow tensor sum optimization by increasing max_values_per_thread in reduce config [ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config Sep 6, 2024
@hongxiayang hongxiayang marked this pull request as ready for review September 6, 2024 22:52
@hongxiayang hongxiayang requested a review from malfet September 9, 2024 21:39
@hongxiayang (Collaborator, Author) commented

Hi @malfet, can you help merge this PR? The two test failures are unrelated. Thank you!

@jithunnair-amd added the rocm label (This tag is for PRs from ROCm team) on Sep 10, 2024
@malfet (Contributor) commented Sep 10, 2024

@pytorchbot merge -f "Lint is green"

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; instead, consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@jithunnair-amd (Collaborator) commented

@pytorchmergebot cherry-pick --help


pytorch-bot bot commented Sep 10, 2024

❌ 🤖 pytorchbot command failed:

```
@pytorchbot cherry-pick: error: the following arguments are required: --onto, -c/--classification

usage: @pytorchbot cherry-pick --onto ONTO [--fixes FIXES] -c
                               {regression,critical,fixnewfeature,docs,release}
```

Try @pytorchbot --help for more info.

@jithunnair-amd (Collaborator) commented

@pytorchmergebot cherry-pick --onto release/2.5 -c critical

pytorchbot pushed a commit that referenced this pull request Sep 10, 2024
[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (#135397)

Pull Request resolved: #135397
Approved by: https://github.com/eqy, https://github.com/malfet
(cherry picked from commit eb38ee2)
@pytorchbot (Collaborator) commented

Cherry picking #135397

The cherry pick PR is at #135624 and it is recommended to link a critical cherry pick PR with an issue.


yushangdi pushed a commit that referenced this pull request Sep 12, 2024
[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (#135397)
hongxiayang added a commit to ROCm/pytorch that referenced this pull request Sep 12, 2024
[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (pytorch#135397) (#1588)
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Sep 13, 2024
Follow-up to pytorch#135397.
AMD GPUs perform better with fewer thread blocks, so increase min_values_per_thread as well. This helped improve [CvT](https://github.com/facebookresearch/FAMBench/tree/main/benchmarks/cvt) benchmark performance on MI300X.

Co-author: @carlobertolli
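The follow-up touches the same launch heuristics. A hedged sketch with hypothetical values (the constant name comes from the commit message; the actual numbers in setReduceConfig() may differ):

```cpp
// Sketch only: raising the floor as well keeps small and medium reductions
// from launching more thread blocks than AMD GPUs handle efficiently.
#ifdef USE_ROCM
constexpr int min_values_per_thread = 32;  // hypothetical ROCm value
#else
constexpr int min_values_per_thread = 16;  // hypothetical CUDA value
#endif
```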
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (pytorch#135397)
@functionstackx (Contributor) commented

Thanks @hongxiayang! I can confirm that this fixes it.

jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Oct 23, 2024
[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (pytorch#135397) (#1588) (cherry picked from commit 4360582)
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Oct 23, 2024
Follow-up to pytorch#135397: increase min_values_per_thread as well. (cherry picked from commit c1b6f60)
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Mar 17, 2025
[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (pytorch#135397) (#1588)
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Mar 17, 2025
Follow-up to pytorch#135397: increase min_values_per_thread as well.

Labels

ciflow/rocm (Trigger "default" config CI on ROCm), Merged, module: rocm (AMD GPU support for PyTorch), open source, release notes: cuda (release notes category), rocm (This tag is for PRs from ROCm team)

Development

Successfully merging this pull request may close these issues:

ROCm MI300X sum() way slower than H100 (#132964)

7 participants