[cuDNN] cuDNN SDPA (Flash Attention) Backward by eqy · Pull Request #122510 · pytorch/pytorch · GitHub

Conversation

eqy
Collaborator

@eqy eqy commented Mar 22, 2024

#113713
Currently passing trivial smoke tests, but I just totally pattern-matched bits and pieces of the autograd defs.

Will also collect benchmark data.

CC @drisspg

cc @csarofeen @ptrblck @xwang233

@eqy eqy added module: cudnn, module: cuda, open source, ciflow/trunk, topic: not user facing, and module: multi-headed-attention labels Mar 22, 2024
@eqy eqy requested review from albanD and soulitzer as code owners March 22, 2024 18:25
@pytorch-bot

pytorch-bot bot commented Mar 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122510

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit b53102b with merge base 91d565d:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy eqy changed the title from [WIP][cuDNN] cuDNN SDPA Backward to [WIP][cuDNN] cuDNN SDPA (Flash Attention) Backward Mar 22, 2024
auto [mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] = graph_and_tensors_backward_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
// inputs
{Q, q.data_ptr()},
Collaborator

shouldn't some of these be const_data_ptr?

Collaborator Author

Checked this again and it currently seems to be a limitation of cuDNN's variant pack, which only accepts void * pointers.
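For illustration, a minimal self-contained sketch (the TensorAttributes struct is a stand-in for fe::graph::Tensor_attributes, not the real cudnn_frontend type) of why a const_data_ptr()-style pointer would need a const_cast at this boundary: the variant pack's value type is a plain void *.

#include <memory>
#include <unordered_map>

// Stand-in for fe::graph::Tensor_attributes; only the shape of the map matters here.
struct TensorAttributes {};

int main() {
  auto Q = std::make_shared<TensorAttributes>();
  const float q_storage[4] = {0.f, 1.f, 2.f, 3.f};
  const void* q_const_ptr = q_storage;  // what a const_data_ptr() would hand us

  // The variant pack maps tensor attributes to mutable void*, so the const has
  // to be cast away at the API boundary even if cuDNN only reads the inputs.
  std::unordered_map<std::shared_ptr<TensorAttributes>, void*> variant_pack;
  variant_pack[Q] = const_cast<void*>(q_const_ptr);
  return 0;
}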

AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
return std::make_tuple(
mha_graph, Q, K, V, attn_scale, Seed, Offset, O, DO, STATS, DQ, DK, DV);
Collaborator

A lot of these should be std::move()'d
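For context, a toy sketch (Graph and TensorAttr below are placeholders, not the actual MHA graph types) of what moving the shared_ptrs into the returned tuple looks like; each avoided copy saves an atomic refcount increment/decrement.

#include <memory>
#include <tuple>
#include <utility>

struct Graph {};       // placeholder for fe::graph::Graph
struct TensorAttr {};  // placeholder for fe::graph::Tensor_attributes

std::tuple<std::shared_ptr<Graph>, std::shared_ptr<TensorAttr>> build_graph() {
  auto mha_graph = std::make_shared<Graph>();
  auto Q = std::make_shared<TensorAttr>();
  // std::move transfers ownership into the tuple instead of copying each shared_ptr.
  return std::make_tuple(std::move(mha_graph), std::move(Q));
}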

tags: nondeterministic_seeded

- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
Collaborator

Why is this a special function? @jbschlosser and @andrewor14 have spent literal months removing these cudnn variants for conv and batchnorm. I really don't think we should be doing the same again with sdpa...

Collaborator Author

I think in this case it's because cuDNN is just a backend, in the same way that flash, mem-efficient, and math are separate backends: the rules for dispatching between them aren't expected to overlap (e.g., cuDNN SDPA isn't expected to exactly cover the support matrix of flash or mem-efficient attention).

Contributor

@albanD and I had a long discussion about this on Friday. Historically there has been a similar pattern of aten op bloat followed by the need to unify a large number of backends behind composite explicit ops. The two exemplars of this are torch's cuDNN ops and batch norm, for which consolidation is still ongoing by @andrewor14.

The conclusion was that for now we are okay with paying the backend tech debt (added ops and added coverage surface) for the sake of velocity.

@eqy eqy force-pushed the cudnn_sdp_backward branch from 68b1513 to d382f3d April 5, 2024 00:11
@eqy
Collaborator Author

eqy commented Apr 8, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased cudnn_sdp_backward onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnn_sdp_backward && git pull --rebase)

const Tensor& dropoutoffset,
cudnnHandle_t& handle,
MHAParams& params) {
auto dtype = fe::DataType_t::HALF;
Contributor

does this support fp32?

Collaborator Author

According to these docs: https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-bprop not at the moment, but I'll ask the cuDNN team about the roadmap.
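A hedged sketch of what the resulting dtype gating could look like; the enums below are stand-ins for at::ScalarType and fe::DataType_t so the snippet is self-contained, and it assumes only fp16/bf16 I/O is accepted by the fused graph for now.

#include <stdexcept>

// Stand-ins for at::ScalarType and fe::DataType_t.
enum class ScalarType { Half, BFloat16, Float };
enum class IoDataType { HALF, BFLOAT16 };

IoDataType to_cudnn_io_dtype(ScalarType t) {
  switch (t) {
    case ScalarType::Half:
      return IoDataType::HALF;
    case ScalarType::BFloat16:
      return IoDataType::BFLOAT16;
    default:
      // Per the cuDNN docs linked above, fp32 I/O isn't supported by the fused
      // flash-attention backward graph today, so reject it up front.
      throw std::runtime_error("cuDNN SDPA backward: unsupported dtype");
  }
}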

}
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
Contributor

Nit: maybe leave a comment at the top with some of the choices made about the graph construction
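For illustration, the kind of preamble comment that could go at the top of the builder; the bullet points are assumptions about the design (inferred from the surrounding snippets), not statements taken from the PR, and the code lines simply restate the graph setup shown above.

// Graph construction choices (sketch):
// - I/O tensors (Q, K, V, O, dO, dQ, dK, dV) use the input dtype (fp16 here),
//   while intermediate values are kept in fp32 for accuracy.
// - The philox seed/offset from the forward pass are fed back in so dropout is
//   replayed identically in the backward graph.
// - Graphs are keyed by MHAParams so repeated calls with the same problem
//   configuration can reuse a built plan.
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(dtype)
    .set_intermediate_data_type(fe::DataType_t::FLOAT);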

std::vector<int64_t>(v.strides().begin(), v.strides().end())));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
Contributor

does this support arbitrary attention bias tensors?

Collaborator Author

I think newer versions (of cuDNN) do, and I will add that in a follow-up. Currently we might be limited by existing builds that use e.g., cuDNN 8.9.2.
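If that follow-up lands, the gating would presumably be a runtime cuDNN version check along these lines (cudnn_sdpa_supports_attn_bias is a hypothetical helper and the 8903 threshold is a placeholder, not a confirmed minimum version for attention-bias support):

#include <cudnn.h>

// For cuDNN 8, cudnnGetVersion() encodes 8.9.2 as 8902, so anything at or
// above the placeholder threshold below (and all of cuDNN 9) would pass.
bool cudnn_sdpa_supports_attn_bias() {
  return cudnnGetVersion() >= 8903;
}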

std::vector<int64_t>(o.strides().begin(), o.strides().end())));
auto STATS = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("stats")
Contributor

nit: The softmaxstats is the logsumexp of the attention scores, right? If softmaxstats is cuDNN's nomenclature that's fine, but you may want to align the naming with the other kernels.
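For reference, assuming cuDNN's softmax stats follows the usual flash-attention convention, it is the row-wise logsumexp of the scaled attention scores, which the backward pass uses to re-materialize the softmax without storing the full score matrix:

stats_i = log(sum_j exp(s * q_i . k_j)) = logsumexp_j(s * q_i . k_j)

where s is the attn_scale applied to the query-key dot products.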

.set_stride(
std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("flash_attention_backward")
Contributor

Maybe use a different name, to avoid confusion with the existing flash impl?

auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_INTERNAL_ASSERT(!workspace_size || workspace_ptr.get());
Contributor

FYI, this has bitten some people: TORCH_INTERNAL_ASSERT is only active when building with debug. This first assert probably makes sense as internal, but do you think the execution of the graph should be a normal TORCH_CHECK?
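A sketch of the suggestion (the helper and message text are illustrative, not from the PR): TORCH_CHECK fires in release builds too, which fits conditions driven by runtime state such as allocation success or graph execution status.

#include <c10/util/Exception.h>
#include <cstddef>

// Illustrative only: TORCH_CHECK raises c10::Error with the given message in
// all build types, unlike debug-only assertions.
void check_workspace(std::size_t workspace_size, void* workspace_ptr) {
  TORCH_CHECK(
      workspace_size == 0 || workspace_ptr != nullptr,
      "cuDNN SDPA backward: failed to allocate ",
      workspace_size,
      " bytes of workspace");
}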

dv/*Tensor& dV*/,
philox_seed/*Tensor& dropoutseed*/,
philox_offset/*Tensor& dropoutoffset*/);
return std::make_tuple(dq, dk, dv);
Contributor

std::move?

Contributor

@drisspg drisspg left a comment

Looks good, mostly just small nits, and if you have any perf numbers that would be great!

@eqy
Collaborator Author

eqy commented Apr 17, 2024

Let me see if I can just rerun my existing forward script on forward + backward now...

@eqy
Collaborator Author

eqy commented Apr 18, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased cudnn_sdp_backward onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnn_sdp_backward && git pull --rebase)

@eqy
Collaborator Author

eqy commented Apr 27, 2024

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@kit1980 kit1980 removed the Reverted label Apr 29, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
Pull Request resolved: pytorch#122510
Approved by: https://github.com/drisspg
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>
Pull Request resolved: #122510
Approved by: https://github.com/drisspg

Labels

ciflow/periodic, ciflow/trunk, Merged, module: cuda, module: cudnn, open source, topic: not user facing


8 participants