
Use FP32 for ConvTranspose3D when using autocast on MPS #160332

@csparker247

🐛 Describe the bug

#154696 adds MPS support for ConvTranspose3D for the FP32 and Complex64 types. However, amp.autocast tries to use half-precision floats for this operation, which is not supported:

import torch
from torch import nn

device = torch.device('mps')

with torch.amp.autocast(device_type=device.type):
    m = nn.ConvTranspose3d(16, 33, 3, stride=2)
    m.to(device)
    x = torch.randn(20, 16, 10, 50, 100).to(device)
    u = m(x)

Traceback (most recent call last):
  File "/Users/REDACTED/test_amp.py", line 10, in <module>
    u = m(x)
        ^^^^
  File "/Users/REDACTED/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/REDACTED/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/REDACTED/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 1344, in forward
    return F.conv_transpose3d(
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: ConvTranspose 3D with BF16 or FP16 types is not supported on MPS

And of course, when the dtype is manually specified as FP32, autocast disables itself (understandably):

with torch.amp.autocast(device_type=device.type, dtype=torch.float32):

UserWarning: In MPS autocast, but the target dtype is not supported. Disabling autocast.
MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.

I don't know much about how AMP is implemented, but it seems like it should be possible to mark ConvTranspose3D as requiring FP32 for autocast purposes when running on MPS, at least until the half-precision variants are supported. Is this feasible?
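In the meantime, a user-level workaround is to locally disable autocast around the unsupported op and keep it in FP32. Below is a minimal sketch of what "mark this op as FP32" could look like from the Python side; the FP32ConvTranspose3d name is made up for illustration, and it assumes torch.is_autocast_enabled accepts a device_type string (true on recent versions):

import torch
from torch import nn

class FP32ConvTranspose3d(nn.ConvTranspose3d):
    # Hypothetical wrapper: force FP32 execution when MPS autocast is active.
    def forward(self, x, output_size=None):
        if x.device.type == "mps" and torch.is_autocast_enabled("mps"):
            # Leave the autocast region and cast the (possibly half-precision)
            # input back to FP32; the module's weights are already FP32.
            with torch.amp.autocast(device_type="mps", enabled=False):
                return super().forward(x.float(), output_size)
        return super().forward(x, output_size)

Swapping this in for nn.ConvTranspose3d in the repro above should avoid the RuntimeError while leaving the rest of the autocast region in half precision; a proper fix would presumably register the op with the fp32 autocast policy instead.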

Versions

PyTorch version: 2.9.0.dev20250811
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.6 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 3.31.6
Libc version: N/A

Python version: 3.11.11 (main, Dec 3 2024, 17:20:40) [Clang 16.0.0 (clang-1600.0.26.4)] (64-bit runtime)
Python platform: macOS-15.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Max

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] torch==2.9.0.dev20250811
[pip3] torchaudio==2.8.0.dev20250811
[pip3] torchinfo==1.8.0
[pip3] torchio==0.20.17
[pip3] torchmetrics==1.7.3
[pip3] torchvision==0.24.0.dev20250811
[conda] Could not collect

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Labels

actionable, module: amp (automated mixed precision), autocast, module: mps, triaged
