[cuDNN] cuDNN SDPA (Flash Attention) Backward #122510
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122510
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures
As of commit b53102b with merge base 91d565d:
NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
aten/src/ATen/native/cudnn/MHA.cpp
Outdated
auto [mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] = graph_and_tensors_backward_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
    // inputs
    {Q, q.data_ptr()},
Shouldn't some of these be `const_data_ptr`?
Checked this again and it currently seems to be a limitation of cuDNN's variant pack, which only accepts `void *` pointers.
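For illustration only, a minimal sketch of the workaround being discussed, assuming the variant pack keeps requiring non-const `void *` (the `const_cast` and the exact entries are hypothetical, not this PR's final code):

```cpp
// Sketch: keep the const-correct ATen accessor for read-only inputs and only
// drop constness at the cuDNN boundary, since the frontend's variant pack is
// an unordered_map<shared_ptr<Tensor_attributes>, void*> with no const overload.
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
    // inputs: logically read-only, so prefer const_data_ptr() and cast at the boundary
    {Q, const_cast<void*>(q.const_data_ptr())},
    {K, const_cast<void*>(k.const_data_ptr())},
    {V, const_cast<void*>(v.const_data_ptr())},
    // outputs: written by the kernel, so the mutable accessor is appropriate
    {Dq, dq.data_ptr()},
    {Dk, dk.data_ptr()},
    {Dv, dv.data_ptr()}};
```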
aten/src/ATen/native/cudnn/MHA.cpp
Outdated
AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
return std::make_tuple(
    mha_graph, Q, K, V, attn_scale, Seed, Offset, O, DO, STATS, DQ, DK, DV);
A lot of these should use `std::move()`.
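A sketch of what that could look like for the hunk above (assuming the handles are `std::shared_ptr`s, as in the structured binding earlier in this review; moving them avoids a round of refcount bumps when the tuple is built):

```cpp
// Sketch: move the shared_ptr graph/tensor handles into the returned tuple
// instead of copying them; each copy would otherwise increment and then drop
// the refcount when the local handles go out of scope.
return std::make_tuple(
    std::move(mha_graph),
    std::move(Q), std::move(K), std::move(V),
    std::move(attn_scale),
    std::move(Seed), std::move(Offset),
    std::move(O), std::move(DO), std::move(STATS),
    std::move(DQ), std::move(DK), std::move(DV));
```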
tags: nondeterministic_seeded

- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
Why is this a special function? @jbschlosser and @andrewor14 have spent literal months removing these cudnn variants for conv and batchnorm. I really don't think we should be doing the same again with sdpa...
I think in this case it's because cuDNN is just another backend, in the same way that flash, mem-efficient, and math are separate backends: the rules for dispatching between them aren't expected to overlap (e.g., cuDNN SDPA isn't expected to exactly cover the support matrix of flash or mem-efficient attention).
@albanD and I had a long discussion about this on Friday. Historically there has been a similar pattern: aten op bloat, followed by the need to unify a large number of backends behind composite explicit ops. The two exemplars of this are torch's cuDNN conv ops and batch norm, for which consolidation is still ongoing by @andrewor14.
The conclusion was that, for now, we are okay with paying the backend tech debt (added ops and added coverage surface) for the sake of velocity.
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased (d382f3d to 8bbdfa0)
const Tensor& dropoutoffset,
cudnnHandle_t& handle,
MHAParams& params) {
  auto dtype = fe::DataType_t::HALF;
does this support fp32?
According to these docs: https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-bprop not at the moment, but I'll ask the cuDNN team about the roadmap.
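Given that limitation, a hedged sketch of the kind of dtype guard the graph builder could apply (the half/bfloat16 mapping reflects the linked docs; treat the exact check as an assumption rather than what this PR ships):

```cpp
// Sketch: cuDNN's fused flash-attention bprop documents half-precision I/O only,
// so map the query dtype to the frontend enum and reject anything else.
fe::DataType_t dtype;
if (q.scalar_type() == at::kHalf) {
  dtype = fe::DataType_t::HALF;
} else if (q.scalar_type() == at::kBFloat16) {
  dtype = fe::DataType_t::BFLOAT16;
} else {
  TORCH_CHECK(false,
      "cuDNN SDPA backward currently supports only fp16/bf16 inputs, got ",
      q.scalar_type());
}
```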
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(dtype)
    .set_intermediate_data_type(fe::DataType_t::FLOAT)
Nit: maybe leave a comment at the top with some of the choices made about the graph construction
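As an illustration of that nit, a hypothetical header comment summarizing the choices that are visible in the hunks of this review (the wording is a suggestion, not the final code):

```cpp
// Graph construction choices:
//  * I/O tensors use the input dtype (fp16/bf16 only for now; cuDNN's fused
//    flash-attention bprop does not yet support fp32 I/O).
//  * Intermediate and compute types are fp32 for numerical accuracy.
//  * Softmax stats (row-wise logsumexp of the attention scores) are kept in
//    fp32 and consumed by the backward graph.
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(dtype)
    .set_intermediate_data_type(fe::DataType_t::FLOAT)
    .set_compute_data_type(fe::DataType_t::FLOAT);
```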
aten/src/ATen/native/cudnn/MHA.cpp
Outdated
        std::vector<int64_t>(v.strides().begin(), v.strides().end())));
auto attn_scale =
    mha_graph->tensor(fe::graph::Tensor_attributes()
                          .set_name("attn_scale")
does this support arbitrary attention bias tensors?
I think newer versions (of cuDNN) do, and I will add that in a follow-up. Currently we might be limited by existing builds that use e.g., cuDNN 8.9.2.
aten/src/ATen/native/cudnn/MHA.cpp
Outdated
        std::vector<int64_t>(o.strides().begin(), o.strides().end())));
auto STATS = mha_graph->tensor(
    fe::graph::Tensor_attributes()
        .set_name("stats")
Nit: the softmax stats is the logsumexp of the attention scores, right? Maybe `softmaxstats` is cuDNN's nomenclature, and that's fine, but it would be nice to align the name with the other kernels.
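For reference, the quantity being discussed, stated in the usual flash-attention form with the default $1/\sqrt{d}$ scaling (whether cuDNN's `stats` matches this exactly is the question raised above):

$$\mathrm{stats}_i = \log \sum_{j} \exp\!\left(\frac{q_i \cdot k_j}{\sqrt{d}}\right)$$

i.e., the row-wise logsumexp of the scaled attention scores, which the backward pass reuses to recompute the softmax without materializing the full attention matrix.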
aten/src/ATen/native/cudnn/MHA.cpp
Outdated
    .set_stride(
        std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
                                 .set_name("flash_attention_backward")
Maybe use a different name, to remove some confusion with the existing flash impl?
aten/src/ATen/native/cudnn/MHA.cpp
Outdated
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
    c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_INTERNAL_ASSERT(!workspace_size || workspace_ptr.get());
FYI, that has bitten some people: TORCH_INTERNAL_ASSERT is only active when building with debug. This first assert probably makes sense to be internal, but do you think the execution of the graph should be a normal TORCH_CHECK?
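A hedged sketch of what a user-facing check on graph execution could look like (this assumes the cudnn_frontend graph's `execute` returns an error object exposing `is_good()`/`get_message()`; the call site is illustrative, not the PR's final code):

```cpp
// Keep the workspace-allocation sanity check internal, but surface graph
// execution failures to users even in release builds via TORCH_CHECK.
auto status = mha_graph->execute(handle, variant_pack, workspace_ptr.get());
TORCH_CHECK(status.is_good(),
    "cuDNN SDPA backward graph execution failed: ", status.get_message());
```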
    dv /*Tensor& dV*/,
    philox_seed /*Tensor& dropoutseed*/,
    philox_offset /*Tensor& dropoutoffset*/);
return std::make_tuple(dq, dk, dv);
std::move?
Looks good, mostly just small nits, and if you have any perf numbers that would be great!
Let me see if I can just rerun my existing forward script on forward + backward now...
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>
Successfully rebased (021fc2b to b53102b)
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
pytorch#113713
currently passing trivial smoke tests but I just totally pattern-matched bits and pieces of the autograd defs
Will also collect benchmark data, CC @drisspg
Pull Request resolved: pytorch#122510
Approved by: https://github.com/drisspg
This reverts commit 64af899. Reverted pytorch#122510 on behalf of https://github.com/jeanschmidt due to Breaking amd gpu builds ([comment](pytorch#122510 (comment)))
#113713
currently passing trivial smoke tests but I just totally pattern-matched bits and pieces of the autograd defs
Will also collect benchmark data, CC @drisspg
Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>
Pull Request resolved: #122510
Approved by: https://github.com/drisspg
#113713
currently passing trivial smoke tests but I just totally pattern-matched bits and pieces of the autograd defs
Will also collect benchmark data,
CC @drisspg
cc @csarofeen @ptrblck @xwang233