- 
                Notifications
    You must be signed in to change notification settings 
- Fork 25.7k
          [EZ][BE] Add torch.complex to MPS_DTYPES
          #136755
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
As minimal supported OS has been rasied to MacOS 13, complex datatypes should be supported [ghstack-poisoned]
| 🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136755
 Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit b81da02 with merge base 3c542ce ( FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
 
 This comment was automatically generated by Dr. CI and updates every 15 minutes. | 
As minimal supported OS has been rasied to MacOS 13, complex datatypes should be supported [ghstack-poisoned]
As minimal supported OS has been rasied to MacOS 13, some basic complex operations should be supported, and the rest could be `xfailed` [ghstack-poisoned]
As minimal supported OS has been rasied to MacOS 13, some basic complex operations should be supported, and the rest could be `xfailed` [ghstack-poisoned]
| @pytorchbot merge -f *MPS tests are green" | 
| ❌ 🤖 pytorchbot command failed:  | 
| @pytorchbot merge -f "MPS tests are green" | 
| Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).  Please use  Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team | 
`norm` and `masked.normalize` still have to stay in the list Pull Request resolved: #136821 Approved by: https://github.com/Skylion007 ghstack dependencies: #136754, #136755
This was a stupid cast error that caused MPSGraph to crash with the following exception ``` (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %3 = "mps.multiply"(%2, %arg1) : (tensor<1x3x9x9xf16>, tensor<1xf32>) -> tensor<*xf32> (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %3 = "mps.multiply"(%2, %arg1) : (tensor<1x3x9x9xf16>, tensor<1xf32>) -> tensor<*xf32> /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification' ``` Pull Request resolved: #136822 Approved by: https://github.com/Skylion007 ghstack dependencies: #136754, #136755, #136821
Before that attempt to run something like ``` % python -c "import torch;dev,dt='mps',torch.int; print(torch.normal(mean=torch.arange(1., 11., device=dev, dtype=dt), std=torch.arange(10, 0, -1, device=dev, dtype=dt)))" ``` Resulted in hard error ``` (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32> (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32> /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification' ``` After the change, it raises a nice type error Pull Request resolved: #136863 Approved by: https://github.com/Skylion007 ghstack dependencies: #136754, #136755, #136821, #136822
Stack from ghstack (oldest at bottom):
torch.complexto MPS_DTYPES #136755arangeto bfloat16 #136754As minimal supported OS has been rasied to MacOS 13, some basic complex operations should be supported, and the rest could be
xfailed