KEMBAR78
[EZ][MPS] Extend `arange` to bfloat16 by malfet · Pull Request #136754 · pytorch/pytorch · GitHub
Skip to content

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Sep 26, 2024

Stack from ghstack (oldest at bottom):

RangeFactories class is the only one that uses AT_DISPATCH_MPS_TYPES

Fixes #136624

RangeFactories class is the only one that uses `AT_DISPATCH_MPS_TYPES`

Fixes #136624

[ghstack-poisoned]
@malfet malfet requested a review from kulinseth as a code owner September 26, 2024 13:46
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136754

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5f01c05 with merge base 3c542ce (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Sep 26, 2024
malfet added a commit that referenced this pull request Sep 26, 2024
RangeFactories class is the only one that uses `AT_DISPATCH_MPS_TYPES`

Fixes #136624

ghstack-source-id: bd377f8
Pull Request resolved: #136754
@malfet malfet requested review from Skylion007 and albanD September 26, 2024 13:47
@malfet
Copy link
Contributor Author

malfet commented Sep 26, 2024

@pytorchbot merge -f "MPS tests + lint are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2024
As minimal supported OS has been rasied to MacOS 13, some basic complex operations  should be supported, and the rest could be `xfailed`
Pull Request resolved: #136755
Approved by: https://github.com/Skylion007
ghstack dependencies: #136754
pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2024
`norm` and `masked.normalize` still have to stay in the list
Pull Request resolved: #136821
Approved by: https://github.com/Skylion007
ghstack dependencies: #136754, #136755
pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2024
This was a stupid cast error that caused MPSGraph to crash with the following exception
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %3 = "mps.multiply"(%2, %arg1) : (tensor<1x3x9x9xf16>, tensor<1xf32>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %3 = "mps.multiply"(%2, %arg1) : (tensor<1x3x9x9xf16>, tensor<1xf32>) -> tensor<*xf32>
/AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
```
Pull Request resolved: #136822
Approved by: https://github.com/Skylion007
ghstack dependencies: #136754, #136755, #136821
pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2024
Before that attempt to run something like
```
% python -c "import torch;dev,dt='mps',torch.int; print(torch.normal(mean=torch.arange(1., 11., device=dev, dtype=dt), std=torch.arange(10, 0, -1, device=dev, dtype=dt)))"
```
Resulted in hard error
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32>
/AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
```
After the change, it raises a nice type error
Pull Request resolved: #136863
Approved by: https://github.com/Skylion007
ghstack dependencies: #136754, #136755, #136821, #136822
@Vargol
Copy link

Vargol commented Oct 12, 2024

Is there any chance of getting this in 2.5.0 or 2.5.1. if its to late for 2.5.0 ?

@Vargol
Copy link

Vargol commented Oct 23, 2024

Sorry to ask again, but there's a 2.5.1 tracker (#138613) now, any chance of getting this cherry picked for 2.5.1.

@malfet
Copy link
Contributor Author

malfet commented Oct 23, 2024

Sorry to ask again, but there's a 2.5.1 tracker (#138613) now, any chance of getting this cherry picked for 2.5.1.

2.5.1 is an emergency release that accepts only regressions, and despite me wanting to have it in the release, it does not meet the criteria. (as there were no bfloat16 support for arange in 2.4)

@github-actions github-actions bot deleted the gh/malfet/15/head branch November 23, 2024 02:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged release notes: mps Release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants