[MPS] Fix bfloat to complex casts by malfet · Pull Request #137070 · pytorch/pytorch · GitHub

Conversation

@malfet
Contributor

@malfet malfet commented Oct 1, 2024

Stack from ghstack (oldest at bottom):

For Metal cast ops to complex dtypes, one needs to explicitly cast to/from `bfloat`, unlike for other dtypes

Tested in #136987
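For context, here is a minimal Metal Shading Language sketch of the cast this fix emits; the kernel name, buffer layout, and the MSL 3.1+ requirement for `bfloat` availability are assumptions for illustration, not part of this PR:

// Illustrative Metal sketch, assuming MSL 3.1+ where bfloat exists;
// the kernel name and buffer layout are made up for this example.
#include <metal_stdlib>
using namespace metal;

kernel void bfloat_to_cfloat(constant bfloat* src [[buffer(0)]],
                             device float2* dst [[buffer(1)]],
                             uint idx [[thread_position_in_grid]]) {
  bfloat x = src[idx];
  // dst[idx] = float2(x, 0.0);      // per this PR, rejected for bfloat sources
  dst[idx] = float2(float(x), 0.0);  // widen explicitly, then build the complex value
}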

@malfet malfet requested a review from kulinseth as a code owner October 1, 2024 01:49
@pytorch-bot

pytorch-bot bot commented Oct 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137070

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ebf87c7 with merge base dfe1d45:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps (Run MPS tests, subset of trunk) and release notes: mps (Release notes category) labels Oct 1, 2024
@malfet malfet added the topic: improvements (topic category) label Oct 1, 2024
@malfet malfet requested review from Skylion007 and albanD October 1, 2024 01:51
Contributor

@github-actions github-actions bot left a comment


Please commit the suggested changes from pytorch's linter.

if (dstComplex) {
  // TODO: Document why explicit cast is needed only for bfloat types
  if (dtypeSrc == "bfloat") {
    return dtypeDst + "(float(x), 0.0)";
Suggested change (from pytorch's linter; before and after render identically here, so the change is presumably whitespace-only):
    return dtypeDst + "(float(x), 0.0)";

return "bfloat(x)";
}
return "(x)";
return dtypeSrc == "bfloat" ? dtypeDst + "(x)" : "(x)";
Suggested change (again from the linter; before and after render identically here, presumably whitespace-only):
    return dtypeSrc == "bfloat" ? dtypeDst + "(x)" : "(x)";
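Putting the fragments quoted in this review together, the helper being edited plausibly ends up as below. This is a hedged reconstruction: the name getCastExpr and the branches not quoted above are assumptions, not the literal PyTorch source.

#include <string>

// Hedged reconstruction from the diff fragments above; getCastExpr and the
// un-quoted branches are assumptions. "x" names the source value in the
// generated Metal stencil; complex dtypes are spelled "float2"/"half2",
// hence the trailing-'2' checks.
static std::string getCastExpr(const std::string& dtypeSrc, const std::string& dtypeDst) {
  const bool srcComplex = dtypeSrc[dtypeSrc.size() - 1] == '2';
  const bool dstComplex = dtypeDst[dtypeDst.size() - 1] == '2';
  if (dstComplex) {
    // TODO: Document why explicit cast is needed only for bfloat types
    if (dtypeSrc == "bfloat") {
      return dtypeDst + "(float(x), 0.0)"; // widen bfloat before constructing the complex value
    }
    return dtypeDst + "(x, 0.0)"; // assumed: other scalar sources convert implicitly
  }
  if (srcComplex) {
    return dtypeDst + "(x.x)"; // assumed: keep the real part when casting complex to real
  }
  if (dtypeDst == "bfloat") {
    return "bfloat(x)"; // explicit narrowing into bfloat
  }
  return dtypeSrc == "bfloat" ? dtypeDst + "(x)" : "(x)";
}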


For Metal cast ops to complex dtypes, one needs to explicitly cast to/from `bfloat`, unlike for other dtypes

Tested in #136987

[ghstack-poisoned]
const bool srcComplex = dtypeSrc[dtypeSrc.size() - 1] == '2';
const bool dstComplex = dtypeDst[dtypeDst.size() - 1] == '2';
if (dstComplex) {
  // TODO: Document why explicit cast is needed only for bfloat types
Collaborator

I guess this should be done before merging?

Contributor Author

I just copy-n-pasted the TODO from another line. I guess I'll need help from @kulinseth, because I could not find anything in https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf that talks about it

@malfet
Contributor Author

malfet commented Oct 1, 2024

@pytorchbot merge -f "Lint + MPS tests are green"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Oct 1, 2024
By reducing the precision expectations for imprecise FP16 ops even further, introducing a new BF16_LOW_PRECISION_OPS category, and marking BF16 tests as xfail for `divfloor_rounding`, `floor_divide` and `remainder`.
I guess the nature of the low-precision results is that MPSGraph, unlike the rest of PyTorch, does not do accumulation over fp32 for reduction operations

Pull Request resolved: #136987
Approved by: https://github.com/albanD
ghstack dependencies: #137070
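To make the fp32-accumulation point above concrete, here is a small standalone C++ sketch (not PyTorch code; the bf16 rounding helper is a hand-rolled emulation) showing how a reduction that accumulates in bf16 saturates while an fp32 accumulator stays exact:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulate bf16 round-to-nearest-even by keeping the top 16 bits of an
// IEEE-754 float (illustration only; NaN handling omitted).
static float to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1); // round to nearest, ties to even
  bits &= 0xFFFF0000u;                 // drop the low mantissa bits
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float acc_fp32 = 0.0f;
  float acc_bf16 = 0.0f;
  for (int i = 0; i < 4096; ++i) {
    acc_fp32 += 1.0f;                    // fp32 accumulator: exact, reaches 4096
    acc_bf16 = to_bf16(acc_bf16 + 1.0f); // bf16 accumulator: stalls at 256,
  }                                      // because 257 is not representable in bf16
  std::printf("fp32 sum: %g, bf16 sum: %g\n", acc_fp32, acc_bf16);
  return 0;
}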
AnantGulati pushed a commit to AnantGulati/pytorch that referenced this pull request Oct 2, 2024
Pull Request resolved: pytorch#137070
Approved by: https://github.com/Skylion007
AnantGulati pushed a commit to AnantGulati/pytorch that referenced this pull request Oct 2, 2024

Pull Request resolved: pytorch#136987
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#137070
@github-actions github-actions bot deleted the gh/malfet/29/head branch November 3, 2024 02:13