
CUDNN SDPA backward results in NaN outputs #160

@thomasneff

Description

Describe the bug
We noticed NaNs being generated by the cuDNN SDPA backward pass in our training runs. After some digging, we narrowed it down to a minimal repro case; it originally happened with much larger tensors, but I believe this repro is indicative of the problem we are running into.

Expected behavior
We expected no NaNs in the gradients. Flash Attention and torch SDPA compute correct outputs and gradients for the same inputs.
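
For reference, a minimal sketch of how such a cross-check against Flash Attention can be appended to the repro script below (this assumes the flash_attn package is installed; flash_attn_func expects (batch, seqlen, heads, dim) layout, so the (b, h, s, d) tensors from the script are transposed first):

from flash_attn import flash_attn_func

# Hypothetical cross-check, reusing q_gpu/k_gpu/v_gpu/dO_gpu/attn_scale from the repro script below.
q_fa = q_gpu.detach().transpose(1, 2).contiguous().requires_grad_()
k_fa = k_gpu.detach().transpose(1, 2).contiguous().requires_grad_()
v_fa = v_gpu.detach().transpose(1, 2).contiguous().requires_grad_()

o_fa = flash_attn_func(q_fa, k_fa, v_fa, causal=False, softmax_scale=attn_scale)
dQ_fa, dK_fa, dV_fa = torch.autograd.grad(
    outputs=[o_fa],
    inputs=[q_fa, k_fa, v_fa],
    grad_outputs=[dO_gpu.detach().transpose(1, 2).contiguous()],
)
# Expected to hold, given that Flash Attention produces correct gradients for these inputs.
assert not dQ_fa.isnan().any()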

System Environment (please complete the following information):

  • cudnn_frontend version: 1.14 (installed via uv pip install "nvidia-cudnn-frontend>=1.14")
  • cudnn_backend version: 9.12.0.46 (installed via uv pip install "nvidia-cudnn-cu12>=9.12" --force-reinstall )
  • GPU arch: H100 SXM
  • cuda runtime version: host is on 12.2 / 12.3, also tried with 12.4 and with torch runtime (12.6, 12.8)
  • cuda driver version: 535.183.01, also tried with 550.163.01
  • host compiler: gcc/g++ 9.4
  • OS: ubuntu20.04, ubuntu22.04
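
The versions above can be confirmed at runtime with a short snippet like the following (a sketch; cudnn.backend_version() is part of the frontend Python bindings, and the __version__ attribute is assumed to be present in this build):

import cudnn
import torch

# Print the frontend/backend/toolkit versions in use (illustrative only).
print("cudnn_frontend:", getattr(cudnn, "__version__", "n/a"))
print("cudnn_backend :", cudnn.backend_version())
print("torch         :", torch.__version__, "(CUDA", torch.version.cuda, ")")
print("GPU           :", torch.cuda.get_device_name(0))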

API logs
API logs for both cudnn_frontend and cudnn_backend are attached below; they were collected with:

// For cudnn_frontend
export CUDNN_FRONTEND_LOG_FILE=fe.log
export CUDNN_FRONTEND_LOG_INFO=1

// For cudnn_backend
export CUDNN_LOGLEVEL_DBG=3
export CUDNN_LOGDEST_DBG=be.log

be.log

fe.log

To Reproduce
Run the following script:


import cudnn
import torch
import math

if __name__ == "__main__":
    torch.manual_seed(42)

    handle = cudnn.create_handle()
    
    b, h, s, d = 1, 1, 2, 8
    dims = (b, h, s, d)
    strides = (s * h * d, d, h * d, 1)
    attn_scale = 1.0 / math.sqrt(d)
    
    q_gpu = torch.tensor([[[[-1.7109,  0.7227, -1.5234, -2.1250,  2.4844,  0.4277,  2.4844,
            0.3008],
          [-0.9023,  1.1094, -1.6094, -1.8359, -4.2500,  1.4531,  1.3281,
            1.1172]]]], device='cuda:0', dtype=torch.bfloat16)
    
    k_gpu = torch.tensor([[[[-0.6367,  0.3945, -0.8867, -0.6250, 70.0000,  1.2109,  1.2969,
           -0.1934],
          [-0.7031,  1.2031, -0.7383, -0.7188, 70.5000,  1.9297,  0.7891,
            1.2500]]]], device='cuda:0', dtype=torch.bfloat16)
    
    v_gpu = torch.tensor([[[[ 0.7422,  0.1875, -0.5664,  0.2383, -0.5195,  0.3789,  0.1260,
           -1.3047],
          [-0.7031, -0.5859,  0.3613, -0.2852,  0.1729,  0.4023,  1.0312,
           -1.0000]]]], device='cuda:0', dtype=torch.bfloat16)
    
    dO_gpu = torch.tensor([[[[ 3.8126e-09, -1.3970e-08,  5.7335e-09, -1.0303e-08, -4.9477e-09,
            1.3388e-08, -1.1292e-08, -1.4115e-09],
          [-1.3504e-08, -1.0594e-08,  1.2282e-08,  1.4144e-08,  5.0932e-09,
            1.3853e-08,  1.4727e-08,  2.0606e-08]]]], device='cuda:0',
       dtype=torch.bfloat16)
   
    o_gpu = torch.empty(b * s * h * d).to(q_gpu.dtype).cuda().as_strided(dims, strides)
    stats_gpu = torch.empty(b, h, s, 1).float().cuda()
    
    dQ_gpu = torch.empty_like(q_gpu)
    dK_gpu = torch.empty_like(k_gpu)
    dV_gpu = torch.empty_like(v_gpu)
    
    
    graph_forward = cudnn.pygraph(
        io_data_type=cudnn.data_type.BFLOAT16,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    q_forward = graph_forward.tensor_like(q_gpu.detach())
    k_forward = graph_forward.tensor_like(k_gpu.detach())
    v_forward = graph_forward.tensor_like(v_gpu.detach())

    # training mode is enabled with generate_stats=True
    # causal mask is disabled (use_causal_mask=False)
    o_forward, stats_forward = graph_forward.sdpa(
        name="sdpa",
        q=q_forward,
        k=k_forward,
        v=v_forward,
        generate_stats=True,
        attn_scale=attn_scale,
        use_causal_mask=False,
    )

    o_forward.set_output(True).set_dim(o_gpu.size()).set_stride(o_gpu.stride())
    stats_forward.set_output(True).set_dim(stats_gpu.size()).set_stride(stats_gpu.stride())
    stats_forward.set_data_type(cudnn.data_type.FLOAT)

    graph_forward.validate()
    graph_forward.build_operation_graph()
    graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_forward.check_support()
    graph_forward.build_plans()
            
    graph_backward = cudnn.pygraph(
        io_data_type=cudnn.data_type.BFLOAT16,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    q_backward = graph_backward.tensor_like(q_gpu)
    k_backward = graph_backward.tensor_like(k_gpu)
    v_backward = graph_backward.tensor_like(v_gpu)
    o_backward = graph_backward.tensor_like(o_gpu)
    dO_backward = graph_backward.tensor_like(dO_gpu)
    stats_backward = graph_backward.tensor_like(stats_gpu)

    dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
        name="sdpa_backward",
        q=q_backward,
        k=k_backward,
        v=v_backward,
        o=o_backward,
        dO=dO_backward,
        stats=stats_backward,
        attn_scale=attn_scale,
        use_causal_mask=False,
    )

    dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
    dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
    dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())

    graph_backward.validate()
    graph_backward.build_operation_graph()
    graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_backward.check_support()
    graph_backward.build_plans()
    
    workspace_size = max(
        graph_forward.get_workspace_size(),
        graph_backward.get_workspace_size(),
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
    
    variant_pack_forward = {
        q_forward: q_gpu,
        k_forward: k_gpu,
        v_forward: v_gpu,
        o_forward: o_gpu,
        stats_forward: stats_gpu,
    }

    graph_forward.execute(variant_pack_forward, workspace)
    torch.cuda.synchronize()
    
    
    variant_pack_backward = {
        q_backward: q_gpu,
        k_backward: k_gpu,
        v_backward: v_gpu,
        o_backward: o_gpu,
        dO_backward: dO_gpu,
        stats_backward: stats_gpu,
        dQ_backward: dQ_gpu,
        dK_backward: dK_gpu,
        dV_backward: dV_gpu,
    }

    graph_backward.execute(variant_pack_backward, workspace)
    torch.cuda.synchronize()
    
    
    q_ref = q_gpu.detach().requires_grad_()
    k_ref = k_gpu.detach().requires_grad_()
    v_ref = v_gpu.detach().requires_grad_()
    dO_ref = dO_gpu.detach()

    o_ref = torch.nn.functional.scaled_dot_product_attention(
        q_ref, k_ref, v_ref, is_causal=False, scale=attn_scale
    )
    
    # This works perfectly - forward passes are identical also with FA
    torch.testing.assert_close(o_ref, o_gpu, atol=5e-3, rtol=3e-3)

    dQ_ref, dK_ref, dV_ref = torch.autograd.grad(
        outputs=[o_ref], inputs=[q_ref, k_ref, v_ref], grad_outputs=[dO_ref]
    )
    
    # Here dQ_gpu has a row of nans that dQ_ref does not have
    # ipdb> dQ_ref.v
    # tensor[1, 1, 2, 8] bf16 n=16 x∈[-2.4301698e-09, 6.9267116e-09] μ=9.2040864e-10 σ=2.2264430e-09 cuda:0
    # tensor([[[[-1.2392e-11,  1.5098e-10,  2.7740e-11, -1.7508e-11,  9.3223e-11,
    #             1.3370e-10, -9.4587e-11,  2.6921e-10],
    #         [-3.1832e-10,  3.8708e-09,  7.1304e-10, -4.4929e-10,  2.4011e-09,
    #             3.4488e-09, -2.4302e-09,  6.9267e-09]]]], device='cuda:0',
    #     dtype=torch.bfloat16)
    # ipdb> dQ_gpu.v
    # tensor[1, 1, 2, 8] bf16 n=16 x∈[-9.6406438e-11, 2.6739144e-10] μ=5.2281734e-11 σ=1.1997352e-10 NaN! cuda:0
    # tensor([[[[-1.1141e-11,  1.4916e-10,  2.9104e-11, -1.6257e-11, -3.4561e-11,
    #             1.3097e-10, -9.6406e-11,  2.6739e-10],
    #         [        nan,         nan,         nan,         nan,         nan,
    #                 nan,         nan,         nan]]]], device='cuda:0',
    #     dtype=torch.bfloat16)
    torch.testing.assert_close(dQ_ref, dQ_gpu, atol=5e-3, rtol=3e-3)
    torch.testing.assert_close(dK_ref, dK_gpu, atol=5e-3, rtol=3e-3)
    torch.testing.assert_close(dV_ref, dV_gpu, atol=5e-3, rtol=3e-3) 
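
For what it's worth, a small diagnostic that can be appended after the two execute() calls, under the assumption that stats_gpu holds the row-wise logsumexp of the scaled logits (the quantity a flash-attention style backward uses to recompute the softmax):

# Recompute the probabilities the backward pass would reconstruct from stats (sketch, not the cuDNN kernel).
S = attn_scale * (q_gpu.float() @ k_gpu.float().transpose(-2, -1))  # [b, h, s, s] scaled logits
P = torch.exp(S - stats_gpu)                                        # softmax rows recomputed from stats
print("recomputed P finite:", P.isfinite().all().item())
print("P:", P)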
