Closed
Labels
module: bfloat16, module: interpolation, triaged
Description
🐛 Describe the bug
Nearest-neighbor upsampling with torch.nn.functional.interpolate does not work with bfloat16 tensors. Minimal code to reproduce:
import torch
import torch.nn.functional as F
image = torch.randn(1, 4, 32, 32).to(device="cuda", dtype=torch.bfloat16)
out = F.interpolate(image, size=(64, 64), mode="nearest")
This throws the following error:
File ~/.pyenv/versions/3.9.14/envs/diffusers-env/lib/python3.9/site-packages/torch/nn/functional.py:3910, in interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
3908 return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
3909 if input.dim() == 4 and mode == "nearest":
-> 3910 return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
3911 if input.dim() == 5 and mode == "nearest":
3912 return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'BFloat16'
F.interpolate with nearest mode is used heavily in UNets, which are the backbone of diffusion models such as Stable Diffusion. Because of this, it is currently not possible to use Stable Diffusion with bfloat16 without manual casting. cf. huggingface/diffusers#792
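For reference, the manual casting mentioned above could look roughly like the sketch below. The helper name interpolate_nearest_via_float32 is hypothetical (it is not part of PyTorch or diffusers), and the sketch assumes the float32 kernel is available on the target device. Since bfloat16 to float32 conversion is exact and nearest mode only copies values, the round trip should not change any numbers.

```python
import torch
import torch.nn.functional as F


def interpolate_nearest_via_float32(x, size):
    """Hypothetical helper: nearest upsampling with a temporary float32 cast."""
    if x.dtype == torch.bfloat16:
        # bfloat16 -> float32 is exact, and nearest mode only copies values,
        # so casting the result back to bfloat16 loses no information.
        upsampled = F.interpolate(x.float(), size=size, mode="nearest")
        return upsampled.to(torch.bfloat16)
    # Other dtypes go straight through to F.interpolate.
    return F.interpolate(x, size=size, mode="nearest")


# Fall back to CPU so the sketch also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
image = torch.randn(1, 4, 32, 32).to(device=device, dtype=torch.bfloat16)
out = interpolate_nearest_via_float32(image, size=(64, 64))
```

This costs an extra float32 copy of the input and output, so it is only a stopgap until the bfloat16 kernel is implemented.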
Versions
PyTorch version: 1.12.1+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.13.4
Libc version: glibc-2.28
Python version: 3.9.14 (main, Sep 22 2022, 15:50:51) [GCC 8.3.0] (64-bit runtime)
Python platform: Linux-4.19.0-22-cloud-amd64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: 11.0.221
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.3
[pip3] torch==1.12.1+cu116
[pip3] torchaudio==0.12.1+cu116
[pip3] torchvision==0.13.1+cu116
[conda] numpy 1.19.5 py37h3e96413_3 conda-forge