KEMBAR78
[MPS] Implement `mul` operation for complex types by malfet · Pull Request #108395 · pytorch/pytorch · GitHub
Skip to content

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Sep 1, 2023

Stack from ghstack (oldest at bottom):

Using existing BinaryKernel template

Add mul as well as kron and outer to list of MPS ops that support complex types

This should add all the missing ops mentioned in #105665

Using existing BinaryKernel template

Add `mul` as well as `kron` and `outer` to list of MPS ops that support complex types

This should add all the missing ops mentioned in #105665

[ghstack-poisoned]
@malfet malfet requested a review from kulinseth as a code owner September 1, 2023 03:42
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108395

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 0629aba with merge base eafc058 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Sep 1, 2023
malfet added a commit that referenced this pull request Sep 1, 2023
Using existing BinaryKernel template

Add `mul` as well as `kron` and `outer` to list of MPS ops that support complex types

This should add all the missing ops mentioned in #105665

ghstack-source-id: 80a2635
Pull Request resolved: #108395
@malfet malfet requested review from a team and albanD September 1, 2023 03:42
Using existing BinaryKernel template

Add `mul` as well as `kron` and `outer` to list of MPS ops that support complex types

This should add all the missing ops mentioned in #105665

[ghstack-poisoned]
malfet added a commit that referenced this pull request Sep 2, 2023
Using existing BinaryKernel template

Add `mul` as well as `kron` and `outer` to list of MPS ops that support complex types

This should add all the missing ops mentioned in #105665

ghstack-source-id: 3f37037
Pull Request resolved: #108395
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

@malfet
Copy link
Contributor Author

malfet commented Sep 10, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 10, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/malfet/51/head branch September 13, 2023 14:24
@gautierronan
Copy link
Contributor

With this PR, does this mean pytorch supports all complex operations on MPS devices? e.g. tensor multiplications, additions, and various linear algebra operations?

@dbl001
Copy link

dbl001 commented Oct 2, 2023

(AI-Feynman) davidlaxer@bluediamond BIMT % ipython
Python 3.10.13 (main, Sep 11 2023, 08:21:04) [Clang 14.0.6 ]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.15.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

   ...: # Create complex tensors on the MPS device
   ...: a = torch.randn(3, 3, dtype=torch.complex32, device=device)
   ...: b = torch.randn(3, 3, dtype=torch.complex32, device=device)
   ...: 
   ...: # Matrix multiplication
   ...: matmul_result = torch.mm(a, b)
   ...: 
   ...: # Matrix inversion
   ...: inverse_result = torch.inverse(a)
   ...: 
   ...: # Eigenvalue decomposition
   ...: eigenvalues, eigenvectors = torch.linalg.eig(a)
   ...: 
   ...: # Singular value decomposition
   ...: U, S, V = torch.linalg.svd(a)
   ...: 
   ...: print("Tensor a:\n", a)
   ...: print("Tensor b:\n", b)
   ...: print("Matrix multiplication of a and b:\n", matmul_result)
   ...: print("Inverse of a:\n", inverse_result)
   ...: print("Eigenvalues of a:\n", eigenvalues)
   ...: print("Eigenvectors of a:\n", eigenvectors)
   ...: print("Singular values of a:\n", S)
   ...: print("U from SVD of a:\n", U)
   ...: print("V from SVD of a:\n", V)
   ...: 
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 5
      2 device = torch.device("mps")
      4 # Create complex tensors on the MPS device
----> 5 a = torch.randn(3, 3, dtype=torch.complex32, device=device)
      6 b = torch.randn(3, 3, dtype=torch.complex32, device=device)
      8 # Matrix multiplication

TypeError: Trying to convert ComplexHalf to the MPS backend but it does not have support for that dtype.

   ...: 
   ...: # Create complex tensors on the MPS device
   ...: a = torch.randn(3, 3, dtype=torch.complex64, device=device)
   ...: b = torch.randn(3, 3, dtype=torch.complex64, device=device)
   ...: 
   ...: # Matrix multiplication
   ...: matmul_result = torch.mm(a, b)
   ...: 
   ...: # Matrix inversion
   ...: inverse_result = torch.inverse(a)
   ...: 
   ...: # Eigenvalue decomposition
   ...: eigenvalues, eigenvectors = torch.linalg.eig(a)
   ...: 
   ...: # Singular value decomposition
   ...: U, S, V = torch.linalg.svd(a)
   ...: 
   ...: print("Tensor a:\n", a)
   ...: print("Tensor b:\n", b)
   ...: print("Matrix multiplication of a and b:\n", matmul_result)
   ...: print("Inverse of a:\n", inverse_result)
   ...: print("Eigenvalues of a:\n", eigenvalues)
   ...: print("Eigenvectors of a:\n", eigenvectors)
   ...: print("Singular values of a:\n", S)
   ...: print("U from SVD of a:\n", U)
   ...: print("V from SVD of a:\n", V)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 5
      2 device = torch.device("mps")
      4 # Create complex tensors on the MPS device
----> 5 a = torch.randn(3, 3, dtype=torch.complex64, device=device)
      6 b = torch.randn(3, 3, dtype=torch.complex64, device=device)
      8 # Matrix multiplication

TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.

In [5]: real_tensor = torch.tensor([1.0], device='mps')

In [6]: imag_tensor = torch.tensor([2.0], device='mps')

In [7]: a = torch.complex(real_tensor, imag_tensor)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[7], line 1
----> 1 a = torch.complex(real_tensor, imag_tensor)

NotImplementedError: The operator 'aten::complex.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [8]: print(torch.__version__)
2.2.0a0+gitac3190c
Successfully installed torch-2.1.0.dev20230729 torchaudio-2.2.0.dev20231002 torchvision-0.17.0.dev20231002
(ai) davidlaxer@bluediamond top2vec % ipython
Python 3.8.13 (default, Oct 19 2022, 17:54:22) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.12.2 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

   ...: a = torch.randn(3, 3, dtype=torch.complex64, device=device)
   ...: b = torch.randn(3, 3, dtype=torch.complex64, device=device)
   ...: 
   ...: # Matrix multiplication
   ...: matmul_result = torch.mm(a, b)
   ...: 
   ...: # Matrix inversion
   ...: inverse_result = torch.inverse(a)
   ...: 
   ...: # Eigenvalue decomposition
   ...: eigenvalues, eigenvectors = torch.linalg.eig(a)
   ...: 
   ...: # Singular value decomposition
   ...: U, S, V = torch.linalg.svd(a)
   ...: 
   ...: print("Tensor a:\n", a)
   ...: print("Tensor b:\n", b)
   ...: print("Matrix multiplication of a and b:\n", matmul_result)
   ...: print("Inverse of a:\n", inverse_result)
   ...: print("Eigenvalues of a:\n", eigenvalues)
   ...: print("Eigenvectors of a:\n", eigenvectors)
   ...: print("Singular values of a:\n", S)
   ...: print("U from SVD of a:\n", U)
   ...: print("V from SVD of a:\n", V)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 5
      2 device = torch.device("mps")
      4 # Create complex tensors on the MPS device
----> 5 a = torch.randn(3, 3, dtype=torch.complex64, device=device)
      6 b = torch.randn(3, 3, dtype=torch.complex64, device=device)
      8 # Matrix multiplication

TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.

In [3]: print(torch.__version__)
2.2.0.dev20231002


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: mps Release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants