[MPS] Speedup argmax/argmin
#159524
@pytorchbot merge -f "Lint + MPS are green"
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed a bug in `simd_argmax` where the result of `simd_ballot` was prematurely cast to `ushort`, and adjusted the unit test accordingly
- Fixed NaN handling in compiled argmax, but this can't be tested reliably because the MPS (eager) implementation of argmax is buggy
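To illustrate why the premature `ushort` cast matters, here is a minimal Python model of a ballot-based argmax. It is a sketch, not the Metal kernel from this PR. It assumes a 32-lane simdgroup, a ballot mask with one bit per lane, and the convention that a NaN lane wins (as in PyTorch's CPU argmax). The helper names and the `truncate_to_ushort` flag are hypothetical.

```python
import math

SIMD_WIDTH = 32  # assumption: 32 lanes per simdgroup


def simd_ballot(predicates):
    """Pack one boolean per lane into an integer mask (bit i <=> lane i)."""
    mask = 0
    for lane, flag in enumerate(predicates):
        if flag:
            mask |= 1 << lane
    return mask


def ctz(mask):
    """Count trailing zeros, i.e. the index of the lowest set bit (mask != 0)."""
    return (mask & -mask).bit_length() - 1


def simd_argmax(values, truncate_to_ushort=False):
    """Lane index of the maximum; a NaN lane wins (assumed convention)."""
    assert len(values) == SIMD_WIDTH
    if any(math.isnan(v) for v in values):
        predicates = [math.isnan(v) for v in values]
    else:
        best = max(values)
        predicates = [v == best for v in values]
    mask = simd_ballot(predicates)
    if truncate_to_ushort:
        # Models the premature cast: only bits 0..15 survive,
        # so any winner in lanes 16..31 is silently dropped.
        mask &= 0xFFFF
    return ctz(mask) if mask else None


values = [float(i) for i in range(SIMD_WIDTH)]  # maximum lives in lane 31
print(simd_argmax(values))                           # full-width mask
print(simd_argmax(values, truncate_to_ushort=True))  # truncated mask
```

In this model the truncated mask is empty whenever the winning lane index is 16 or higher, so the result is wrong only for some inputs. A unit test whose maximum always falls in the first 16 lanes would not catch it.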
Now, according to `bench_mps_ops.py`, the compiled `max(x, dim=0)` is reliably faster than the eager implementation:
```
[--------------------------------------------------------------------------------------------- --------------------------------------------------------------------------------------------]
                    | eager-512x512 | compile-512x512 | eager-1024x1024 | compile-1024x1024 | eager-2048x2048 | compile-2048x2048 | eager-4096x4096 | compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
max (torch.float16) |         285.8 |           272.2 |           422.3 |             354.5 |           721.6 |             683.5 |          2224.0 |            1979.1
max (torch.float32) |         300.2 |           267.0 |           389.6 |             342.5 |           769.4 |             682.6 |          2995.7 |            2609.8
max (torch.int32)   |         299.6 |           275.4 |           390.0 |             361.7 |           758.7 |             686.1 |          3103.4 |            2646.5
max (torch.int64)   |         297.5 |           275.5 |           417.0 |             382.1 |           856.1 |             722.6 |          5467.7 |            3156.8
```
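To read the table in relative terms, the eager/compile ratio can be computed for each cell. The sketch below copies the timings from the table. The dictionary layout and the `speedups` helper are illustrative only and are not part of `bench_mps_ops.py`.

```python
# Timings copied from the benchmark table (eager, compile) per size,
# in whatever time unit the benchmark reported.
SIZES = ["512x512", "1024x1024", "2048x2048", "4096x4096"]
TIMINGS = {
    "float16": [(285.8, 272.2), (422.3, 354.5), (721.6, 683.5), (2224.0, 1979.1)],
    "float32": [(300.2, 267.0), (389.6, 342.5), (769.4, 682.6), (2995.7, 2609.8)],
    "int32": [(299.6, 275.4), (390.0, 361.7), (758.7, 686.1), (3103.4, 2646.5)],
    "int64": [(297.5, 275.5), (417.0, 382.1), (856.1, 722.6), (5467.7, 3156.8)],
}


def speedups(timings):
    """Return eager/compile ratios per dtype; a value above 1 means compile is faster."""
    return {
        dtype: [eager / compiled for eager, compiled in pairs]
        for dtype, pairs in timings.items()
    }


for dtype, ratios in speedups(TIMINGS).items():
    cells = ", ".join(f"{size}: {r:.2f}x" for size, r in zip(SIZES, ratios))
    print(f"max (torch.{dtype}): {cells}")
```

With these numbers every ratio is above 1. The largest gain is for `torch.int64` at 4096x4096, at roughly 1.7x.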
Pull Request resolved: #159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
Stack from ghstack (oldest at bottom):
- [MPS] Speedup argmax/argmin #159524

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben