
[cuDNN][cuDNN V8 API] cuDNN Flash-Attention Upstreaming RFC/tracking issue #113713

@eqy

Description

🚀 The feature, motivation and pitch

cuDNN has been brewing its own flavor of flash attention (v2) for some time now, and it appears to be ready for integration into upstream PyTorch.

The current implementation uses the new graph API provided by the yet-to-be-finalized cuDNN v1.0 frontend. It also differs from previous cuDNN frontend integrations, such as the one for convolution (see #58414), in that the flash-attention implementation requires some (lightweight) runtime compilation; the compiled kernels are cached after the initial compilation. Additionally, the support matrix is not as exhaustive as that of convolution: inputs must be contiguous with stride 1 in the embedding dimension, inputs must have dtype bfloat16 or float16, and only hardware with compute capability 8.0 or newer is supported.
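For illustration, here is a minimal sketch of what a pre-dispatch eligibility check against these constraints could look like. `cudnn_mha_eligible` is a hypothetical helper for this issue, not an existing PyTorch API:

```python
import torch

def cudnn_mha_eligible(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> bool:
    # Hypothetical pre-dispatch check mirroring the support matrix above.
    # Inputs must live on a CUDA device.
    if not (q.is_cuda and k.is_cuda and v.is_cuda):
        return False
    # dtype: only half-precision inputs (bfloat16 / float16) are supported.
    if q.dtype not in (torch.float16, torch.bfloat16):
        return False
    # Hardware: compute capability 8.0 (Ampere) or newer.
    major, _minor = torch.cuda.get_device_capability(q.device)
    if major < 8:
        return False
    # Layout: contiguous inputs with stride 1 in the embedding dimension.
    return all(t.is_contiguous() and t.stride(-1) == 1 for t in (q, k, v))
```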
 
I've outlined the following steps to be considered for integration, including potential issues that could surface along the way.
Early performance benchmarks (bfloat16, H100):
https://docs.google.com/spreadsheets/d/17fckdkCB3JRjugevDTcSfP4cBCmaGS1qNckiuJXBm-k

  • Step 0: Initial forward pass proof-of-concept. Here, we consider the forward pass to be experimentally functional, though not all shapes and compute capabilities are well tested. Note that the cuDNN frontend 1.0 interface would need to be finalized and released at this point. Prototype implementation: main...eqy:pytorch:cudnnmha3. At this point, the implementation can be tried by setting TORCH_CUDNN_MHA_ENABLED=1 (see the sketch after this list). The existing context managers for selecting SDP backends would also apply to the cuDNN implementation at this point.
  • Step 1: Forward pass + backward pass and bells and whistles like dropout
  • Step 2: (WE ARE HERE) Robustness, including shape checks before dispatching into cuDNN MHA (along the lines of the eligibility sketch above), handling of workspace-size limitations, etc.
  • Step 3: The opt-in flag TORCH_CUDNN_MHA_ENABLED=1 is replaced by an opt-out flag TORCH_CUDNN_MHA_DISABLED, and the cuDNN backend is enabled by default alongside the existing SDP backends.
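As a quick illustration of trying the Step 0 prototype, the following is a minimal sketch. It assumes the prototype branch is built and that the TORCH_CUDNN_MHA_ENABLED flag needs to be set before torch is imported (an assumption; when the flag is consumed is not specified above):

```python
import os
# Opt in to the prototype cuDNN SDPA path. Set before importing torch in
# case the flag is read at initialization time (an assumption).
os.environ["TORCH_CUDNN_MHA_ENABLED"] = "1"

import torch
import torch.nn.functional as F

# Contiguous bfloat16 inputs on an SM 8.0+ GPU, matching the support
# matrix above; (batch, heads, seq_len, head_dim) layout.
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```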

CC @drisspg @ptrblck @malfet

Alternatives

No response

Additional context

No response

cc @csarofeen @ptrblck @xwang233 @msaroufim @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki
