🐛 Describe the bug
#154696 adds MPS support for ConvTranspose3D for FP32 and Complex64 types. However, amp.autocast tries to use half-precision floats for this operation, which is not supported:
```python
import torch
from torch import nn

device = torch.device('mps')
with torch.amp.autocast(device_type=device.type):
    m = nn.ConvTranspose3d(16, 33, 3, stride=2)
    m.to(device)
    x = torch.randn(20, 16, 10, 50, 100).to(device)
    u = m(x)
```
```
Traceback (most recent call last):
  File "/Users/REDACTED/test_amp.py", line 10, in <module>
    u = m(x)
        ^^^^
  File "/Users/REDACTED/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/REDACTED/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/REDACTED/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 1344, in forward
    return F.conv_transpose3d(
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: ConvTranspose 3D with BF16 or FP16 types is not supported on MPS
```
And of course, when the dtype is manually specified as FP32, autocast disables itself (understandably):
```python
with torch.amp.autocast(device_type=device.type, dtype=torch.float32):
    ...
```
```
UserWarning: In MPS autocast, but the target dtype is not supported. Disabling autocast.
MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.
```
I don't know much about how AMP is implemented, but it seems like it should be possible to mark ConvTranspose3D as requiring FP32 for autocast purposes when running on MPS, at least until the half-precision variants are supported. Is this feasible?
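In the meantime, a user-level workaround is to locally disable autocast around the unsupported op and cast its input back to FP32. A minimal sketch of that pattern (the device-selection fallback to CPU is illustrative so the snippet also runs where MPS is unavailable; the report itself targets MPS):

```python
import torch
from torch import nn

# Illustrative fallback: use CPU when MPS is unavailable.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

m = nn.ConvTranspose3d(16, 33, 3, stride=2).to(device)
x = torch.randn(2, 16, 5, 8, 8).to(device)  # smaller than the repro, for speed

with torch.amp.autocast(device_type=device.type):
    # ... other ops run under autocast as usual ...
    # Locally disable autocast and force FP32 for the unsupported op.
    with torch.amp.autocast(device_type=device.type, enabled=False):
        u = m(x.float())

assert u.dtype == torch.float32
```

This keeps the rest of the model under autocast while pinning only conv_transpose3d to FP32, which is roughly what an fp32 cast policy for the op would do automatically.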
Versions
```
PyTorch version: 2.9.0.dev20250811
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.6 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 3.31.6
Libc version: N/A

Python version: 3.11.11 (main, Dec 3 2024, 17:20:40) [Clang 16.0.0 (clang-1600.0.26.4)] (64-bit runtime)
Python platform: macOS-15.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M3 Max

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] torch==2.9.0.dev20250811
[pip3] torchaudio==2.8.0.dev20250811
[pip3] torchinfo==1.8.0
[pip3] torchio==0.20.17
[pip3] torchmetrics==1.7.3
[pip3] torchvision==0.24.0.dev20250811
[conda] Could not collect
```
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen