KEMBAR78
[MPS] Restrict MSELoss to floating types by malfet · Pull Request #139960 · pytorch/pytorch · GitHub
Skip to content

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Nov 7, 2024

Stack from ghstack (oldest at bottom):

Becuase if invoked with long type it crahses deep in MPSGraph framework and to keep parity with CPU

Add test that validates that if dtype is not floating, both CPU and MPS implementations will error out
Fix function name for mse_loss_out_mps as __func__ for any structured op implementation is impl

Fixes #139723

[ghstack-poisoned]
@malfet malfet requested a review from kulinseth as a code owner November 7, 2024 02:02
@pytorch-bot
Copy link

pytorch-bot bot commented Nov 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139960

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1cad70b with merge base 59cf4bc (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Nov 7, 2024
@malfet malfet requested a review from manuelcandales November 7, 2024 02:03
@malfet malfet added the topic: bug fixes topic category label Nov 7, 2024
[ghstack-poisoned]
malfet added a commit that referenced this pull request Nov 7, 2024
Becuase if invoked with long type it crahses deep in MPSGraph framework and to keep parity with CPU

Add test that validates that if dtype is not floating, both CPU and MPS implementations will error out
Fix function name for `mse_loss_out_mps` as `__func__` for any structured op implementation is `impl`

Fixes #139723

ghstack-source-id: 6cd80ca
Pull Request resolved: #139960
@malfet
Copy link
Contributor Author

malfet commented Nov 8, 2024

@pytorchbot merge -f "MPS tests are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Becuase if invoked with long type it crahses deep in MPSGraph framework and to keep parity with CPU

Add test that validates that if dtype is not floating, both CPU and MPS implementations will error out
Fix function name for `mse_loss_out_mps` as `__func__` for any structured op implementation is `impl`

Fixes pytorch#139723
Pull Request resolved: pytorch#139960
Approved by: https://github.com/kimishpatel
ghstack dependencies: pytorch#139961, pytorch#139959
@github-actions github-actions bot deleted the gh/malfet/52/head branch December 8, 2024 02:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged release notes: mps Release notes category topic: bug fixes topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants