[MPS] index_add inconsistency on complex tensors · Issue #160845 · pytorch/pytorch · GitHub

@Isalia20

Description

🐛 Describe the bug

Reproducer:

import torch

torch.manual_seed(0)
device = "mps"

shape = (400, )
idx = torch.arange(400, dtype=torch.long)
dim = 0


t_mps = torch.zeros(shape, dtype=torch.complex64, device=device)
t_cpu = torch.zeros(shape, dtype=torch.complex64, device="cpu")
trailing = shape[dim+1:]
src_shape = (len(idx),) + trailing
src_imag = torch.randn(src_shape, dtype=torch.float32, device=device)
src_real = torch.zeros_like(src_imag)
src = torch.complex(src_real, src_imag)
t_mps.index_add_(dim, idx.to(device), src)
t_cpu.index_add_(dim, idx.cpu(), src.cpu())

print("MPS imag sum:", t_mps.imag.abs().sum().item())
print("CPU imag sum:", t_cpu.imag.abs().sum().item())
print("max abs diff:", (t_mps.cpu() - t_cpu).abs().max().item())

Output (on MPS the imaginary part of src is lost entirely, while CPU accumulates it as expected):
MPS imag sum: 0.0
CPU imag sum: 321.1763916015625
max abs diff: 2.923485279083252
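For comparison, the expected result can be computed without PyTorch at all. The sketch below is a hypothetical pure-Python helper (index_add_reference is not a PyTorch API) that covers only the 1-D, dim=0 case and assumes the documented semantics self[index[i]] += alpha * src[i]. Since plain complex addition keeps both the real and imaginary parts, a purely imaginary src must leave a nonzero imaginary sum, which matches the CPU output rather than the MPS one.

```python
import random


def index_add_reference(self_vals, index, src, alpha=1):
    """Hypothetical 1-D reference for index_add_: out[index[i]] += alpha * src[i]."""
    if len(index) != len(src):
        raise ValueError("index and src must have the same length")
    out = list(self_vals)
    for i, j in enumerate(index):
        # Python complex addition keeps both the real and imaginary parts.
        out[j] += alpha * src[i]
    return out


if __name__ == "__main__":
    random.seed(0)
    n = 400
    # Purely imaginary source, mirroring torch.complex(zeros, randn) in the reproducer.
    src = [complex(0.0, random.gauss(0.0, 1.0)) for _ in range(n)]
    result = index_add_reference([0j] * n, list(range(n)), src)
    # Expected to be nonzero, like the CPU "imag sum" above.
    print("reference imag sum:", sum(abs(z.imag) for z in result))
```

Duplicate indices are accumulated, not overwritten, which is the part of index_add_ semantics that distinguishes it from plain indexed assignment.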

Discovered while implementing PR #160839.
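Until the MPS kernel is fixed, one possible workaround is to run index_add_ on the real-valued views of the complex tensors, so that each complex element becomes a trailing (real, imag) pair of float32 values. This is a sketch under assumptions: complex_index_add_ is a hypothetical helper, and it assumes the real-dtype index_add_ path on MPS is unaffected, which this report does not verify. torch.view_as_real returns a view sharing storage with its input, so the in-place update lands in the original complex tensor.

```python
import torch


def complex_index_add_(t, dim, index, src):
    """Hypothetical helper: in-place index_add_ on complex t via float views.

    Assumes the real-dtype index_add_ kernel is correct on the target device.
    """
    if not (t.is_complex() and src.is_complex()):
        raise TypeError("expected complex tensors")
    # view_as_real appends a trailing size-2 (real, imag) dim, so normalize
    # negative dims against the complex tensor's rank before reusing them.
    dim = dim % t.dim()
    # Both views share storage with t and src; the in-place add updates t.
    torch.view_as_real(t).index_add_(dim, index, torch.view_as_real(src))
    return t


if __name__ == "__main__":
    # Fall back to CPU so the sketch still runs where MPS is unavailable.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    torch.manual_seed(0)
    n = 400
    idx = torch.arange(n, dtype=torch.long)
    # Purely imaginary source, as in the reproducer above.
    src = torch.complex(torch.zeros(n), torch.randn(n))
    out = torch.zeros(n, dtype=torch.complex64, device=device)
    complex_index_add_(out, 0, idx.to(device), src.to(device))
    ref = torch.zeros(n, dtype=torch.complex64).index_add_(0, idx, src)
    print("max abs diff vs CPU:", (out.cpu() - ref).abs().max().item())
```

Going through view_as_real avoids any copy, so the helper keeps the in-place contract of index_add_; the trade-off is that it relies on the float32 kernel, which should itself be checked against CPU on the affected machine.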

Versions

PyTorch version: 2.9.0a0+git33396a6
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 3.31.2
Libc version: N/A

Python version: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.16.0
[pip3] mypy_extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.20.1
[pip3] optree==0.13.0
[pip3] torch==2.9.0a0+git33396a6
[pip3] torchvision==0.23.0a0+ee6104d
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-lightning 2.5.1.post0 pypi_0 pypi
[conda] pytorch-metric-learning 2.8.1 pypi_0 pypi
[conda] torch 2.9.0a0+git33396a6 dev_0
[conda] torchmetrics 1.7.1 pypi_0 pypi
[conda] torchvision 0.20.1 pypi_0 pypi

cc @ezyang @anjali411 @dylanbespalko @mruberry @nikitaved @amjames @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Labels:
module: complex (related to complex number support in PyTorch)
module: correctness (silent) (issue that returns an incorrect result silently)
module: mps (related to the Apple Metal Performance Shaders framework)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)