Describe the bug
We noticed NaNs being generated by the cuDNN SDPA backward in our training runs. After some digging, we narrowed it down to a minimal repro case; it originally happened with much larger tensors, but I believe this case is indicative of the problem we are running into.
Expected behavior
We expected no NaNs in the gradients. Flash Attention and torch SDPA compute correct outputs for the same inputs.
System Environment (please complete the following information):
- cudnn_frontend version: 1.14 (installed via uv pip install "nvidia-cudnn-frontend>=1.14")
- cudnn_backend version: 9.12.0.46 (installed via uv pip install "nvidia-cudnn-cu12>=9.12" --force-reinstall)
- GPU arch: H100 SXM
- cuda runtime version: host has 12.2 / 12.3; also tried 12.4 and the torch-bundled runtimes (12.6, 12.8)
- cuda driver version: 535.183.01, also tried with 550.163.01
- host compiler: gcc/g++ 9.4
- OS: Ubuntu 20.04, Ubuntu 22.04
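
If it helps, the versions above can also be checked at runtime. A minimal sketch (assuming the frontend module exposes __version__ and backend_version(); those names may differ between releases):

import cudnn
import torch

# Hypothetical quick version check; attribute names are assumptions and
# may differ between frontend releases.
print("frontend:", getattr(cudnn, "__version__", "n/a"))
print("backend:", cudnn.backend_version())        # e.g. 91200 for 9.12.x
print("torch CUDA runtime:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))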
API logs
Please attach API logs for both cudnn_frontend and cudnn_backend.
// For cudnn_frontend
export CUDNN_FRONTEND_LOG_FILE=fe.log
export CUDNN_FRONTEND_LOG_INFO=1
// For cudnn_backend
export CUDNN_LOGLEVEL_DBG=3
export CUDNN_LOGDEST_DBG=be.log
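
The same variables can also be set from inside the repro script itself; a minimal sketch (assuming both loggers read the environment when the libraries are first imported/initialized, so the variables must be set before that point):

import os

# Set logging variables before cudnn / torch load the backend library,
# so the frontend and backend loggers pick them up at initialization.
os.environ["CUDNN_FRONTEND_LOG_FILE"] = "fe.log"
os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1"
os.environ["CUDNN_LOGLEVEL_DBG"] = "3"
os.environ["CUDNN_LOGDEST_DBG"] = "be.log"

import cudnn  # imported only after the environment is configured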
To Reproduce
Run the following script:
import cudnn
import torch
import math
if __name__ == "__main__":
    torch.manual_seed(42)
    handle = cudnn.create_handle()

    b, h, s, d = 1, 1, 2, 8
    dims = (b, h, s, d)
    strides = (s * h * d, d, h * d, 1)
    attn_scale = 1.0 / math.sqrt(d)

    q_gpu = torch.tensor([[[[-1.7109,  0.7227, -1.5234, -2.1250,  2.4844,  0.4277,  2.4844,
                              0.3008],
                            [-0.9023,  1.1094, -1.6094, -1.8359, -4.2500,  1.4531,  1.3281,
                              1.1172]]]], device='cuda:0', dtype=torch.bfloat16)
    k_gpu = torch.tensor([[[[-0.6367,  0.3945, -0.8867, -0.6250, 70.0000,  1.2109,  1.2969,
                             -0.1934],
                            [-0.7031,  1.2031, -0.7383, -0.7188, 70.5000,  1.9297,  0.7891,
                              1.2500]]]], device='cuda:0', dtype=torch.bfloat16)
    v_gpu = torch.tensor([[[[ 0.7422,  0.1875, -0.5664,  0.2383, -0.5195,  0.3789,  0.1260,
                             -1.3047],
                            [-0.7031, -0.5859,  0.3613, -0.2852,  0.1729,  0.4023,  1.0312,
                             -1.0000]]]], device='cuda:0', dtype=torch.bfloat16)
    dO_gpu = torch.tensor([[[[ 3.8126e-09, -1.3970e-08,  5.7335e-09, -1.0303e-08, -4.9477e-09,
                               1.3388e-08, -1.1292e-08, -1.4115e-09],
                             [-1.3504e-08, -1.0594e-08,  1.2282e-08,  1.4144e-08,  5.0932e-09,
                               1.3853e-08,  1.4727e-08,  2.0606e-08]]]], device='cuda:0',
                           dtype=torch.bfloat16)

    # Output buffers for the forward and backward graphs
    o_gpu = torch.empty(b * s * h * d).to(q_gpu.dtype).cuda().as_strided(dims, strides)
    stats_gpu = torch.empty(b, h, s, 1).float().cuda()
    dQ_gpu = torch.empty_like(q_gpu)
    dK_gpu = torch.empty_like(k_gpu)
    dV_gpu = torch.empty_like(v_gpu)

    # Forward graph
    graph_forward = cudnn.pygraph(
        io_data_type=cudnn.data_type.BFLOAT16,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q_forward = graph_forward.tensor_like(q_gpu.detach())
    k_forward = graph_forward.tensor_like(k_gpu.detach())
    v_forward = graph_forward.tensor_like(v_gpu.detach())

    # training mode is enabled with generate_stats=True
    # causal mask is disabled
    o_forward, stats_forward = graph_forward.sdpa(
        name="sdpa",
        q=q_forward,
        k=k_forward,
        v=v_forward,
        generate_stats=True,
        attn_scale=attn_scale,
        use_causal_mask=False,
    )
    o_forward.set_output(True).set_dim(o_gpu.size()).set_stride(o_gpu.stride())
    stats_forward.set_output(True).set_dim(stats_gpu.size()).set_stride(stats_gpu.stride())
    stats_forward.set_data_type(cudnn.data_type.FLOAT)

    graph_forward.validate()
    graph_forward.build_operation_graph()
    graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_forward.check_support()
    graph_forward.build_plans()

    # Backward graph
    graph_backward = cudnn.pygraph(
        io_data_type=cudnn.data_type.BFLOAT16,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q_backward = graph_backward.tensor_like(q_gpu)
    k_backward = graph_backward.tensor_like(k_gpu)
    v_backward = graph_backward.tensor_like(v_gpu)
    o_backward = graph_backward.tensor_like(o_gpu)
    dO_backward = graph_backward.tensor_like(dO_gpu)
    stats_backward = graph_backward.tensor_like(stats_gpu)

    dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
        name="sdpa_backward",
        q=q_backward,
        k=k_backward,
        v=v_backward,
        o=o_backward,
        dO=dO_backward,
        stats=stats_backward,
        attn_scale=attn_scale,
        use_causal_mask=False,
    )
    dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
    dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
    dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())

    graph_backward.validate()
    graph_backward.build_operation_graph()
    graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_backward.check_support()
    graph_backward.build_plans()

    # Shared workspace sized for whichever graph needs more
    workspace_size = max(
        graph_forward.get_workspace_size(),
        graph_backward.get_workspace_size(),
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    # Execute forward, then backward
    variant_pack_forward = {
        q_forward: q_gpu,
        k_forward: k_gpu,
        v_forward: v_gpu,
        o_forward: o_gpu,
        stats_forward: stats_gpu,
    }
    graph_forward.execute(variant_pack_forward, workspace)
    torch.cuda.synchronize()

    variant_pack_backward = {
        q_backward: q_gpu,
        k_backward: k_gpu,
        v_backward: v_gpu,
        o_backward: o_gpu,
        dO_backward: dO_gpu,
        stats_backward: stats_gpu,
        dQ_backward: dQ_gpu,
        dK_backward: dK_gpu,
        dV_backward: dV_gpu,
    }
    graph_backward.execute(variant_pack_backward, workspace)
    torch.cuda.synchronize()

    # Reference: torch SDPA forward + autograd backward
    q_ref = q_gpu.detach().requires_grad_()
    k_ref = k_gpu.detach().requires_grad_()
    v_ref = v_gpu.detach().requires_grad_()
    dO_ref = dO_gpu.detach()
    o_ref = torch.nn.functional.scaled_dot_product_attention(
        q_ref, k_ref, v_ref, is_causal=False, scale=attn_scale
    )

    # This works perfectly - forward passes are identical also with FA
    torch.testing.assert_close(o_ref, o_gpu, atol=5e-3, rtol=3e-3)

    dQ_ref, dK_ref, dV_ref = torch.autograd.grad(
        outputs=[o_ref], inputs=[q_ref, k_ref, v_ref], grad_outputs=[dO_ref]
    )

    # Here dQ_gpu has a row of nans that dQ_ref does not have
    # ipdb> dQ_ref.v
    # tensor[1, 1, 2, 8] bf16 n=16 x∈[-2.4301698e-09, 6.9267116e-09] μ=9.2040864e-10 σ=2.2264430e-09 cuda:0
    # tensor([[[[-1.2392e-11,  1.5098e-10,  2.7740e-11, -1.7508e-11,  9.3223e-11,
    #             1.3370e-10, -9.4587e-11,  2.6921e-10],
    #           [-3.1832e-10,  3.8708e-09,  7.1304e-10, -4.4929e-10,  2.4011e-09,
    #             3.4488e-09, -2.4302e-09,  6.9267e-09]]]], device='cuda:0',
    #        dtype=torch.bfloat16)
    # ipdb> dQ_gpu.v
    # tensor[1, 1, 2, 8] bf16 n=16 x∈[-9.6406438e-11, 2.6739144e-10] μ=5.2281734e-11 σ=1.1997352e-10 NaN! cuda:0
    # tensor([[[[-1.1141e-11,  1.4916e-10,  2.9104e-11, -1.6257e-11, -3.4561e-11,
    #             1.3097e-10, -9.6406e-11,  2.6739e-10],
    #           [        nan,         nan,         nan,         nan,         nan,
    #                    nan,         nan,         nan]]]], device='cuda:0',
    #        dtype=torch.bfloat16)
    torch.testing.assert_close(dQ_ref, dQ_gpu, atol=5e-3, rtol=3e-3)
    torch.testing.assert_close(dK_ref, dK_gpu, atol=5e-3, rtol=3e-3)
    torch.testing.assert_close(dV_ref, dV_gpu, atol=5e-3, rtol=3e-3)
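
For reference, the gradients being compared follow the standard softmax-attention backward, which can be written out directly in fp32. This is a minimal sketch for illustration only (it is not the cuDNN kernel's actual algorithm) and should agree with the torch autograd reference up to rounding:

# Standard softmax-attention backward, computed in fp32:
#   S = attn_scale * Q @ K^T,  P = softmax(S),  O = P @ V
#   dV = P^T @ dO
#   dP = dO @ V^T
#   dS = P * (dP - rowsum(dP * P))
#   dQ = attn_scale * dS @ K,  dK = attn_scale * dS^T @ Q
def sdpa_backward_reference(q, k, v, dO, scale):
    q, k, v, dO = (t.float() for t in (q, k, v, dO))
    s_mat = scale * q @ k.transpose(-2, -1)                   # [b, h, s_q, s_k]
    p = torch.softmax(s_mat, dim=-1)                          # attention probabilities
    dV = p.transpose(-2, -1) @ dO
    dP = dO @ v.transpose(-2, -1)
    dS = p * (dP - (dP * p).sum(dim=-1, keepdim=True))
    dQ = scale * dS @ k
    dK = scale * dS.transpose(-2, -1) @ q
    return dQ, dK, dV

# Example usage against the inputs above:
# dQ_manual, dK_manual, dV_manual = sdpa_backward_reference(q_gpu, k_gpu, v_gpu, dO_gpu, attn_scale)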