KEMBAR78
[MPS] Speedup `argmax`/`argmin` by malfet · Pull Request #159524 · pytorch/pytorch · GitHub
Skip to content

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Jul 30, 2025

Stack from ghstack (oldest at bottom):

By using efficient threadgroup_arg[max|min] primitives.

  • Fixed bug in simd_argmax when result of the simd_ballot were prematurely cast to ushort and adjusted unit test
  • Fixed nan handling in compiled argmax, but can't reliably test it as MPS(eager) implementaiton of argmax is buggy

Now according to bench_mps_ops.py max(x, dim=0) is reliably faster than eager implementaiton:

[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1     
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8     
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5     
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8     

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jul 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159524

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 72 Pending

As of commit adf1276 with merge base 25343b3 (image):
💚 Looks good so far! There are no failures yet. 💚

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor ciflow/mps Run MPS tests (subset of trunk) module: inductor labels Jul 30, 2025
malfet added a commit that referenced this pull request Jul 30, 2025
By using efficient `threadgroup_arg[max|min]` primitives

ghstack-source-id: 134819e
Pull Request resolved: #159524
@malfet malfet requested a review from dcci July 30, 2025 22:57
@malfet malfet added the topic: improvements topic category label Jul 30, 2025
@Skylion007 Skylion007 changed the title [MPS] Speedup argmax/armin [MPS] Speedup argmax/argmin Jul 31, 2025
[ghstack-poisoned]
@malfet malfet requested a review from kulinseth as a code owner July 31, 2025 15:43
malfet added a commit that referenced this pull request Jul 31, 2025
By using efficient `threadgroup_arg[max|min]` primitives

ghstack-source-id: bcdd0a1
Pull Request resolved: #159524
@malfet malfet added the release notes: mps Release notes category label Jul 31, 2025
@malfet
Copy link
Contributor Author

malfet commented Jul 31, 2025

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed bug in `simd_argmax` when result of the `simd_ballot` were prematurely cast to `ushort` and adjusted unit test
- Fixed nan handling in compiled argmax, but can't reliably test it as MPS(eager) implementaiton of argmax is buggy

Now according to `bench_mps_ops.py` `max(x, dim=0)` is reliably faster than eager implementaiton:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8

```

Pull Request resolved: #159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
@github-actions github-actions bot deleted the gh/malfet/462/head branch August 31, 2025 02:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/mps Run MPS tests (subset of trunk) Merged module: inductor release notes: mps Release notes category topic: improvements topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants