SDPA: CUDNN backend error w/ q_seq_len = 1 · Issue #138529 · pytorch/pytorch · GitHub

@drisspg

Description

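Summary

Running `F.scaled_dot_product_attention` with the cuDNN backend on a query of sequence length 1 (against 2**16 keys/values in bfloat16) fails in the backward pass with `No execution plans support the graph`.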

Repro script

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Shapes are (batch, num_heads, seq_len, head_dim):
# a single query token (q_seq_len = 1) attending over 2**16 keys/values.
q = torch.randn(1, 16, 1, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 16, 2**16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 16, 2**16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Restrict dispatch to the cuDNN backend so no other kernel is used.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
    out.backward(torch.ones_like(out))  # the failure surfaces in backward
```

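Error (the traceback comes from a variant of the script whose line 15 calls `out.sum().backward()` instead of `out.backward(torch.ones_like(out))`):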

```
/home/drisspg/meta/pytorch/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported. Materializing a contiguous tensor which will increase memory usage... (Triggered internally at /home/drisspg/meta/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:664.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/drisspg/meta/scripts/sdpa/repro_gqa.py", line 15, in <module>
    out.sum().backward()
  File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 624, in backward
    torch.autograd.backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph.
```
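The UserWarning says cuDNN received a `grad_out` whose innermost stride is 0. That matches the `out.sum().backward()` call in the traceback: the gradient of `sum()` is a scalar expanded to `out`'s shape, so every stride is 0, whereas `torch.ones_like(out)` is an ordinary dense tensor. The CPU-only sketch below illustrates the difference; `x * 2` is a hypothetical stand-in for the SDPA output and `captured_grad` is a hypothetical helper. It does not reproduce the cuDNN failure itself.

```python
import torch


def captured_grad(backward_fn):
    """Return the gradient that reaches a non-leaf `out` during backward."""
    x = torch.randn(1, 2, 1, 4, requires_grad=True)
    out = x * 2  # non-leaf stand-in for the SDPA output
    seen = []
    out.register_hook(seen.append)  # the hook receives d(loss)/d(out) as-is
    backward_fn(out)
    return seen[0]


# Gradient of sum(): a scalar 1.0 expanded to out's shape.
g_sum = captured_grad(lambda out: out.sum().backward())
# Explicit dense gradient, as in the repro script above.
g_ones = captured_grad(lambda out: out.backward(torch.ones_like(out)))

print(g_sum.stride())               # → (0, 0, 0, 0)
print(g_ones.stride())              # → (8, 4, 4, 1)
print(g_sum.contiguous().stride())  # → (8, 4, 4, 1), i.e. "materializing"
```

Both gradients hold the same values; only the memory layout differs, which is why the backend can recover by copying to a contiguous tensor at the cost of extra memory.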

cc @ptrblck @msaroufim @eqy @mikaylagawarecki

Metadata

Assignees: No one assigned

Labels: module: cuda, module: sdpa, triaged