🚀 The feature, motivation and pitch
cuDNN has been brewing its own flavor of flash attention (v2) for some time now, and it appears to be ready for integration into upstream PyTorch.
The current implementation uses the new graph API provided by the yet-to-be-finalized cuDNN v1.0 frontend. It also differs from previous cuDNN frontend integrations, such as for convolution (see #58414), in that the flash-attention implementation requires some (lightweight) runtime compilation. Of course, the compiled kernels will be cached after the initial compilation. Additionally, the support matrix is not as exhaustive as that of convolution: there are requirements on the strides (inputs must be contiguous, with stride 1 in the embedding dimension), inputs must have dtype bfloat16 or float16, and only hardware with compute capability 8.0 or newer is supported, etc.
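For illustration, here is a minimal sketch of what a pre-dispatch eligibility check for these constraints might look like. The function name `can_use_cudnn_attention` and the exact set of checks are assumptions for this example, not the actual dispatcher code:

```python
import torch

def can_use_cudnn_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> bool:
    """Hypothetical pre-dispatch check mirroring the constraints described above."""
    # Only half-precision dtypes are supported.
    if any(t.dtype not in (torch.float16, torch.bfloat16) for t in (q, k, v)):
        return False
    # Inputs must have stride 1 in the embedding (last) dimension.
    if any(t.stride(-1) != 1 for t in (q, k, v)):
        return False
    # Only compute capability 8.0 (Ampere) and newer is supported.
    major, _minor = torch.cuda.get_device_capability(q.device)
    return major >= 8
```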
I've outlined the following steps to be considered for integration, including potential issues that could surface along the way.
Early performance benchmarks (bfloat16, H100):
https://docs.google.com/spreadsheets/d/17fckdkCB3JRjugevDTcSfP4cBCmaGS1qNckiuJXBm-k
- Step 0: Initial forward pass proof-of-concept. Here, we consider the forward pass to be experimentally functional, though not all shapes and compute capabilities are well-tested. Note that we would require the cuDNN frontend 1.0 interface to be finalized and released at this point. Prototype implementation: main...eqy:pytorch:cudnnmha3. The implementation can be tried by setting `TORCH_CUDNN_MHA_ENABLED=1` (see the usage sketch after this list); the existing context managers for selecting SDP backends would also apply to the cuDNN implementation at this point.
- Step 1: Forward pass + backward pass, plus bells and whistles like dropout.
- Step 2: (WE ARE HERE) Robustness, including shape checks before dispatching into cuDNN MHA, handling of workspace size limitations, etc.
- Step 3: `TORCH_CUDNN_MHA_ENABLED=1` becomes `TORCH_CUDNN_MHA_DISABLED=0`, and the cuDNN backend is enabled by default alongside the existing SDP backends.
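For concreteness, a minimal usage sketch of the Step 0 prototype follows. Only the environment variable comes from the prototype itself; the shapes are illustrative, and whether a given call actually dispatches to the cuDNN kernel depends on the support matrix above:

```python
# Run with the prototype enabled:  TORCH_CUDNN_MHA_ENABLED=1 python example.py
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim), half precision,
# contiguous with stride 1 in the embedding dimension per the support matrix.
q = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Forward pass; eligible to dispatch to the cuDNN flash-attention kernel.
out = F.scaled_dot_product_attention(q, k, v)

# Backward pass exercises the Step 1 work (backward kernel, and eventually dropout).
out.sum().backward()
```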
Alternatives
No response
Additional context
No response
cc @csarofeen @ptrblck @xwang233 @msaroufim @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki