[MPS] Extend addmm to integral types #160270
Conversation
Fixes #154901 [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160270
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 22 Pending, as of commit cdfaf5a with merge base 86eb65f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
  }
});
return output;
return output;
Double return? Surprised this didn't error
Yeah, I hoped some of the linters would be triggered by it, but it feels like this is fine...
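A likely explanation (an assumption about the build setup, not something stated in the PR): a statement after `return` is legal C++, and Clang only reports the unreachable second `return` with the opt-in `-Wunreachable-code` diagnostic, which `-Wall`/`-Wextra` do not enable. A minimal host-side sketch of the pattern:

```cpp
// Minimal repro of the reviewed pattern: a duplicated return statement.
// Clang accepts this silently under -Wall -Wextra; only the opt-in
// -Wunreachable-code diagnostic reports the second return as dead code.
int duplicated_return(int x) {
  int output = x * 2;
  return output;
  return output;  // unreachable, but not an error
}
```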
constant uint3& sizes [[buffer(6)]],
uint2 tid [[thread_position_in_threadgroup]],
uint2 thread_id [[thread_position_in_grid]]) {
  threadgroup T A_tile[TILE_DIM][TILE_DIM];
It's ugly, but can this be rewritten as a std::array too?
Threadgroups are a bit weird (i.e. this statement affects GPU occupancy). Let me give it a try in a separate PR and make sure it would not regress the perf...
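A minimal sketch of what that rewrite might look like, assuming `metal::array` from `<metal_array>` and placeholder values for `T` and `TILE_DIM` (both are stand-ins for the real kernel's template parameters). Whether the compiler keeps the same threadgroup memory footprint, and therefore the same occupancy, is exactly what would need to be measured in that follow-up PR:

```cpp
#include <metal_array>
#include <metal_stdlib>
using namespace metal;

// Placeholders standing in for the real kernel's parameters (assumptions).
#define TILE_DIM 16
using T = int;

kernel void tile_sketch(uint2 tid [[thread_position_in_threadgroup]]) {
  // The reviewer's suggestion: a nested metal::array instead of a raw C array.
  threadgroup array<array<T, TILE_DIM>, TILE_DIM> A_tile;
  A_tile[tid.y][tid.x] = T(0);  // same indexing as the original A_tile[y][x]
}
```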
@pytorchbot merge -f "Lint + MPS are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
By adding an `addmm` kernel, which is a logical continuation of the `mm` one. The only tricky part is how the alpha and beta constants are handled: they are passed as `optmath_t`, i.e. they could be int64, int32, or float.

Unified all MM flavor instantiations through `INSTANTIATE_MM_OPS` and tested that the `addmm` Metal kernel works as expected for floating types as well by running it via
```
PYTORCH_MPS_PREFER_METAL=1 python test/test_mps.py -v -k test_output_match_addmm_mps_
```
Fixes pytorch#154901

Pull Request resolved: pytorch#160270
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: pytorch#160228, pytorch#160234
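For readers unfamiliar with the op: the kernel implements the standard `addmm` contract, `out = beta * self + alpha * (mat1 @ mat2)`, now also for integral dtypes. A naive CPU-side reference sketch of those semantics (illustrative only; this is not the PR's Metal kernel, and broadcasting of `self` is ignored for brevity):

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of addmm: out = beta * self + alpha * (mat1 @ mat2).
// self is treated as an M x N matrix; mat1 is M x K, mat2 is K x N.
template <typename T>
std::vector<T> addmm_ref(const std::vector<T>& self,
                         const std::vector<T>& mat1,
                         const std::vector<T>& mat2,
                         int64_t M, int64_t K, int64_t N,
                         T alpha, T beta) {
  std::vector<T> out(static_cast<size_t>(M * N));
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      T acc = 0;
      for (int64_t k = 0; k < K; ++k) {
        acc += mat1[m * K + k] * mat2[k * N + n];
      }
      out[m * N + n] = beta * self[m * N + n] + alpha * acc;
    }
  }
  return out;
}
```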