🐛 Describe the bug
While updating to a torch nightly, we noticed that some internal tests had vastly lower precision. I've determined that stable torch uses ieee precision by default, but since 19aa8eb the nightlies have defaulted to tf32 precision. Is this intended? My understanding is that fp32 work should be computed in full fp32 precision by default, with tf32 math as an explicit opt-in.
```python
import torch
print(torch.__version__)
import torch.nn.attention.flex_attention

torch.compile(torch.nn.attention.flex_attention.flex_attention)
print(torch._inductor.kernel.flex_attention.get_float32_precision())
```

On stable, this prints:

```
2.8.0+cu128
'ieee'
```
However, starting with the 20250729 nightly, it prints tf32:

```
2.9.0.dev20250729+cu128
'tf32'
```
The previous nightly (20250728) still printed ieee, along with a deprecation warning:

```
2.9.0.dev20250728+cu128
/home/ubuntu/.local/lib/python3.10/site-packages/torch/__init__.py:1546: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
return _C._get_float32_matmul_precision()
'ieee'
```
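To illustrate why the ieee-vs-tf32 default matters for test tolerances: tf32 keeps only 10 explicit mantissa bits, versus 23 for ieee fp32, so small fp32-representable increments are lost. A minimal, torch-free sketch that simulates the mantissa reduction by truncating the low 13 bits of an fp32 value (real tf32 hardware rounds rather than truncates, so this is only an approximation of the effect):

```python
import struct

def simulate_tf32(x: float) -> float:
    """Approximate tf32 by zeroing the low 13 mantissa bits of a float32.

    fp32 has 23 explicit mantissa bits; tf32 keeps 10, so the bottom
    13 bits are dropped. Hardware tf32 rounds to nearest, while this
    sketch simply truncates, but the magnitude of the lost precision
    is the same.
    """
    # Reinterpret the value as its 32-bit float bit pattern.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Clear the 13 least-significant mantissa bits.
    bits &= ~((1 << 13) - 1)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# 2**-20 is representable in fp32 relative to 1.0, but falls below
# tf32's relative resolution of 2**-10 and is silently discarded.
print(simulate_tf32(1.0 + 2**-20))  # → 1.0
# 2**-10 sits within tf32's 10 mantissa bits and survives.
print(simulate_tf32(1.0 + 2**-10) == 1.0 + 2**-10)  # → True
```

A perturbation of about one part in a thousand per multiply is enough to push tight fp32 test tolerances past their thresholds, which is consistent with the precision drop we observed.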
Versions
My comparison is between `pip install torch==2.9.0.dev20250728+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128` and `pip install torch==2.9.0.dev20250729+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128`.
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng