FlexAttention default precision decreased from ieee to tf32 since stable · Issue #161022 · pytorch/pytorch

@leijurv

🐛 Describe the bug

While updating to the torch nightly, we noticed some internal tests had vastly lower precision. I've determined that stable torch uses ieee precision by default, but since commit 19aa8eb the nightlies have defaulted to tf32. I don't think this is intended: fp32 inputs are usually computed in full fp32 (ieee) precision by default, with tf32 math as an explicit opt-in.

import torch
print(torch.__version__)
import torch.nn.attention.flex_attention
# Compiling triggers the import of the inductor flex attention kernel module.
torch.compile(torch.nn.attention.flex_attention.flex_attention)
# Inductor helper that decides which float32 precision the kernel uses.
print(torch._inductor.kernel.flex_attention.get_float32_precision())

On stable, this prints:

2.8.0+cu128
'ieee'

However, starting with the 20250729 nightly, it prints tf32:

2.9.0.dev20250729+cu128
'tf32'

The 20250728 nightly, one day earlier, still printed ieee, along with a deprecation warning about the TF32 settings API:

2.9.0.dev20250728+cu128
/home/ubuntu/.local/lib/python3.10/site-packages/torch/__init__.py:1546: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
  return _C._get_float32_matmul_precision()
'ieee'
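For context on why test precision drops: tf32 keeps fp32's 8-bit exponent but only 10 mantissa bits, versus fp32's 23, so each product can lose the low 13 bits of mantissa (a relative error on the order of 1e-4). A rough sketch of the effect, using simple mantissa truncation (real hardware rounds, and this is not how PyTorch implements it):

```python
import struct

def to_tf32(x: float) -> float:
    """Crude TF32 emulation: zero the low 13 mantissa bits of the fp32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= ~((1 << 13) - 1)  # keep sign, 8 exponent bits, top 10 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

x = 1.0 + 2**-20          # exactly representable in fp32
print(to_tf32(x) == 1.0)  # True: the low mantissa bits are discarded
```

This is why accumulated matmul results, and hence attention outputs, can diverge noticeably between the ieee and tf32 defaults.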

Versions

My comparison is between:

pip install torch==2.9.0.dev20250728+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128
pip install torch==2.9.0.dev20250729+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128
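Until this is resolved, a possible workaround (my assumption, not confirmed by the maintainers) is to request full-precision fp32 matmuls globally, since the inductor precision helper appears to key off that setting:

```python
import torch

# Ask PyTorch for full fp32 ("ieee") matmul precision globally; on Ampere+
# GPUs the nightly default otherwise lets matmuls fall back to tf32.
torch.set_float32_matmul_precision("highest")

print(torch.get_float32_matmul_precision())  # highest
```

Whether the compiled flex attention kernel actually picks this up on the affected nightlies is an assumption on my part and would need to be verified against get_float32_precision() as shown in the repro.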

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng


    Labels

    module: flex attention
    module: higher order operators (torch.cond and similar)
    module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op)
    oncall: pt2
    triaged (this issue has been looked at by a team member and triaged into an appropriate module)