[MPS] Fix bfloat to complex casts #137070
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137070
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit ebf87c7 with merge base dfe1d45.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Please commit the suggested changes from pytorch's linter.
if (dstComplex) {
  // TODO: Document why explicit cast is needed only for bfloat types
  if (dtypeSrc == "bfloat") {
    return dtypeDst + "(float(x), 0.0)";
Suggested change (linter formatting; the before/after versions differ only in whitespace):
    return dtypeDst + "(float(x), 0.0)";
| return "bfloat(x)"; | ||
| } | ||
| return "(x)"; | ||
| return dtypeSrc == "bfloat" ? dtypeDst + "(x)" : "(x)"; |
Suggested change (linter formatting; the before/after versions differ only in whitespace):
    return dtypeSrc == "bfloat" ? dtypeDst + "(x)" : "(x)";
For Metal cast ops to compile, one needs to explicitly cast to/from `bfloat`, unlike for other dtypes. Tested in #136987. [ghstack-poisoned]
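As a rough illustration of the constraint described above (a hypothetical snippet, not taken from the PR; the kernel name and buffer layout are made up), the generated Metal source only compiles when the `bfloat` value is widened through `float` first:

// C++ sketch holding a hypothetical generated Metal kernel as a raw string.
// Per this PR, the commented-out line is rejected because Metal will not
// implicitly convert bfloat when constructing a float2; the explicit
// float(x) cast is required.
constexpr const char* kCastBfloatToCfloat = R"METAL(
kernel void cast_bfloat_to_cfloat(device bfloat* input [[buffer(0)]],
                                  device float2* output [[buffer(1)]],
                                  uint idx [[thread_position_in_grid]]) {
  bfloat x = input[idx];
  // output[idx] = float2(x, 0.0);      // rejected: implicit bfloat cast
  output[idx] = float2(float(x), 0.0);  // accepted: explicit cast via float
}
)METAL";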
const bool srcComplex = dtypeSrc[dtypeSrc.size() - 1] == '2';
const bool dstComplex = dtypeDst[dtypeDst.size() - 1] == '2';
if (dstComplex) {
  // TODO: Document why explicit cast is needed only for bfloat types
I guess this should be done before merging?
I just copy-n-pasted the TODO from another line. I guess I'll need help from @kulinseth, because I could not find anything in https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf that talks about it.
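Piecing the diff hunks together, the helper plausibly reads as the sketch below. This is a reconstruction, not the exact PyTorch source: the function name and the branches not visible in the hunks (the non-bfloat complex constructor and the complex-source path) are assumptions.

#include <string>

// Sketch of the cast-expression generator assembled from the hunks above.
// dtypeSrc/dtypeDst are Metal type names; complex types end in '2'
// (e.g. "float2", "half2"), and `x` names the scalar being cast.
static std::string castExpr(const std::string& dtypeSrc, const std::string& dtypeDst) {
  const bool srcComplex = dtypeSrc[dtypeSrc.size() - 1] == '2';
  const bool dstComplex = dtypeDst[dtypeDst.size() - 1] == '2';
  if (dstComplex) {
    // TODO: Document why explicit cast is needed only for bfloat types
    if (dtypeSrc == "bfloat") {
      return dtypeDst + "(float(x), 0.0)";  // widen bfloat before the complex ctor
    }
    return dtypeDst + "(x, 0.0)";  // assumed branch: other real source types
  }
  if (srcComplex) {
    return "(x.x)";  // assumed branch: drop the imaginary part
  }
  if (dtypeDst == "bfloat") {
    return "bfloat(x)";  // explicit cast into bfloat
  }
  return dtypeSrc == "bfloat" ? dtypeDst + "(x)" : "(x)";
}

Under this reading, bfloat → float2 emits float2(float(x), 0.0), bfloat → half emits half(x), and float → float stays a bare (x).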
@pytorchbot merge -f "Lint + MPS tests are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use `-f` as last resort and instead consider `-i/--ignore-current` to continue the merge ignoring current failures. Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
By even further reducing precision of imprecise FP16 ops, introducing a new BF16_LOW_PRECISION_OPS category, and marking BF16 tests as xfail for `divfloor_rounding`, `floor_divide` and `remainder`. I guess the nature of the low-precision results is that MPSGraph, unlike the rest of PyTorch, does not do accumulation over fp32 for reduction operations.
Pull Request resolved: #136987
Approved by: https://github.com/albanD
ghstack dependencies: #137070
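To illustrate the accumulation point (a standalone toy, not PyTorch or MPSGraph code): rounding the running sum to bfloat16 after every step drifts far more than keeping an fp32 accumulator and rounding once at the end.

#include <cstdio>
#include <cstring>

// Toy bfloat16 rounding: truncate a float to the top 16 bits of its
// bit pattern (real bf16 rounds to nearest, but truncation is enough
// to show the effect).
static float to_bf16(float v) {
  unsigned int bits;
  std::memcpy(&bits, &v, sizeof(bits));
  bits &= 0xFFFF0000u;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  const int n = 10000;
  const float x = to_bf16(0.001f);
  float acc_low = 0.0f;  // accumulator rounded to bf16 after every add
  float acc_f32 = 0.0f;  // fp32 accumulator, rounded once at the end
  for (int i = 0; i < n; ++i) {
    acc_low = to_bf16(acc_low + x);
    acc_f32 += x;
  }
  // The per-step-rounded sum stalls once the addend drops below one ulp
  // of the running total (around 0.25 here); the fp32 path stays near 10.
  std::printf("bf16 per-step: %f\n", acc_low);
  std::printf("fp32 then round: %f\n", to_bf16(acc_f32));
}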
For Metal cast ops to compile, one needs to explicitly cast to/from `bfloat`, unlike for other dtypes. Tested in pytorch#136987
Pull Request resolved: pytorch#137070
Approved by: https://github.com/Skylion007
Stack from ghstack (oldest at bottom):
For Metal cast ops to compile, one needs to explicitly cast to/from `bfloat`, unlike for other dtypes. Tested in #136987