[cuDNN][SDPA] Match `query`'s memory layout ordering for `output` in cuDNN SDPA by eqy · Pull Request #138354 · pytorch/pytorch

Conversation

@eqy
Collaborator

@eqy eqy commented Oct 18, 2024

For #138340

We might consider more sophisticated logic here, but the corresponding logic in other backends doesn't seem to do anything fancy for non-BSHD/BHSD cases:

```cpp
res = at::empty({B, M, num_heads, Kv}, query.options());
```

Ended up going with a more general approach that handles more or less arbitrary layouts, sketched below.
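For illustration, a minimal sketch of what "matching the query's layout ordering" means: allocate the output contiguously in the query's physical dimension order, then permute back to the logical order. The helper name `empty_matching_layout` and the Python rendering are hypothetical; the actual implementation lives in the C++ backend.

```python
import torch

def empty_matching_layout(query: torch.Tensor, out_sizes) -> torch.Tensor:
    # Hypothetical helper, not the PR's actual code: order dimensions from
    # largest to smallest stride to recover the query's physical layout...
    order = sorted(range(query.dim()), key=lambda d: query.stride(d), reverse=True)
    physical_sizes = [out_sizes[d] for d in order]
    # ...allocate contiguously in that physical order...
    buf = torch.empty(physical_sizes, device=query.device, dtype=query.dtype)
    # ...and permute back so logical dim d lands where the query has it.
    inverse = [order.index(d) for d in range(query.dim())]
    return buf.permute(inverse)

# e.g. a BHSD-logical query stored BSHD-physically yields a BSHD-physical output:
q = torch.empty(2, 8, 128, 64).transpose(1, 2).contiguous().transpose(1, 2)
out = empty_matching_layout(q, q.shape)
assert out.stride() == q.stride()
```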

cc @csarofeen @ptrblck @xwang233 @msaroufim @drisspg @mikaylagawarecki

@eqy eqy added module: cudnn Related to torch.backends.cudnn, and CuDNN support module: cuda Related to torch.cuda, and CUDA support in general open source topic: bug fixes topic category module: multi-headed-attention labels Oct 18, 2024
@eqy eqy added this to the 2.5.1 milestone Oct 18, 2024
@eqy eqy requested a review from drisspg October 18, 2024 18:54
@eqy eqy requested a review from syed-ahmed as a code owner October 18, 2024 18:54
@pytorch-bot

pytorch-bot bot commented Oct 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138354

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 27360a9 with merge base 2ce2e4d:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy eqy added the topic: not user facing topic category label Oct 18, 2024
@eqy
Collaborator Author

eqy commented Oct 18, 2024

CC @ngimel @Skylion007
When discussing with @drisspg we realized that this fix might cause the grad output stride to no longer match the output's stride in a common cases. Normally this is not an issue but current cuDNN >= v9.5.0 has a bug where the grad output stride is incorrectly assumed to be the same as output stride, and the workaround for this means that if we fix this it may incur an extra .contiguous in the backward until we upgrade to the cuDNN release with this fix. (It's done and should be released soon)

@Skylion007
Collaborator

> CC @ngimel @Skylion007 When discussing with @drisspg we realized that this fix might cause the grad output stride to no longer match the output's stride in common cases. Normally this is not an issue, but current cuDNN >= v9.5.0 has a bug where the grad output stride is incorrectly assumed to be the same as the output stride, and the workaround for this means that if we land this fix it may incur an extra `.contiguous` call in the backward until we upgrade to the cuDNN release that fixes the bug. (That fix is done and should be released soon.)

Not a problem for the backport as we use a much lower version of CUDNN though?

@eqy eqy changed the title [cuDNN][SDPA] Prefer BSHD by default for packed/non-contig in BHSD query [cuDNN][SDPA] Match query's memory layout ordering for output in cuDNN SDPA Oct 18, 2024
@eqy eqy added ciflow/trunk Trigger trunk jobs on your pull request ciflow/inductor ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Oct 19, 2024
drisspg added a commit that referenced this pull request Oct 22, 2024
# Summary
Currently we have a `cudnn_order` that says, on H100 with a new enough cuDNN backend (we ship a 9.1 version in OSS), to try to run cuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above we are going to make the cuDNN backend opt-in by default.

This can be done easily with the context manager for choosing backends, i.e.:
```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the cuDNN backend at the lowest precedence in the backend list, meaning that the Math backend will always be chosen ahead of it unless disabled (which is what the context manager does).
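Besides the context manager, recent PyTorch also exposes global toggles for the same choice (a sketch; the selection behavior around these calls is as described above):

```python
import torch

# Globally disable (or re-enable) the cuDNN SDPA backend.
torch.backends.cuda.enable_cudnn_sdp(False)
print(torch.backends.cuda.cudnn_sdp_enabled())  # False
```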


Cc atalman

cc mikaylagawarecki

[ghstack-poisoned]
pytorchbot pushed a commit that referenced this pull request Oct 22, 2024
Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)
@pytorchmergebot
Collaborator

Successfully rebased `defaultbshd` onto `refs/remotes/origin/viable/strict`, please pull locally before adding more changes (for example, via `git checkout defaultbshd && git pull --rebase`)

@ngimel
Collaborator

ngimel commented Nov 4, 2024

> this fix might cause the grad output stride to no longer match the output's stride in common cases

Out of curiosity, what would be a common case where gradOutput stride doesn't match output? This happens almost always today, because gradOutput would typically be permuted, and output is contiguous.
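For context, a small repro sketch of why gradOutput is typically permuted: the usual transformer epilogue transposes the SDPA output and merges heads before the output projection, so the gradient arriving at SDPA's backward is a transposed view (shapes below are arbitrary).

```python
import torch

B, H, S, D = 2, 8, 16, 8
out = torch.randn(B, H, S, D, requires_grad=True)  # stand-in for SDPA output (BHSD)
out.register_hook(lambda g: print(g.stride(), g.is_contiguous()))
# Typical epilogue: (B, H, S, D) -> (B, S, H*D) for the output projection.
loss = out.transpose(1, 2).reshape(B, S, H * D).sum()
loss.backward()  # hook prints non-contiguous, BSHD-ordered strides
```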

@eqy
Collaborator Author

eqy commented Nov 4, 2024

The potential copy due to the gradOutput vs. output stride issue should be resolved once 9.5.1 is shipped with the wheels, and we can gate that behind a cuDNN version check
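A sketch of how such a gate could look from the Python side (hypothetical; the actual check would live next to the C++ workaround, and the 9.5.1 threshold comes from the discussion above):

```python
import torch

# torch.backends.cudnn.version() encodes e.g. cuDNN 9.5.1 as 90501,
# or returns None when cuDNN is unavailable.
cudnn_version = torch.backends.cudnn.version()
needs_grad_out_restride = cudnn_version is not None and cudnn_version < 90501
```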

@eqy
Collaborator Author

eqy commented Nov 4, 2024

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@drisspg
Contributor

drisspg commented Nov 5, 2024

@StrongerXi how?

@StrongerXi
Contributor

> @StrongerXi how?

I have no clue; I saw it on my PR, which was odd, and then I saw that it's also failing on main (the above link).

@Skylion007
Collaborator

Skylion007 commented Nov 20, 2024

> The potential copy due to the gradOutput vs. output stride issue should be resolved once 9.5.1 is shipped with the wheels, and we can gate that behind a cuDNN version check

Okay, the cuDNN upgrade is available in the CUDA 12.6 binaries. Feel free to add the gate in a new PR.

@drisspg drisspg mentioned this pull request Nov 20, 2024
pytorchmergebot pushed a commit that referenced this pull request Mar 13, 2025
Update `cuDNN SDPA` meta registration to match the memory layout behavior in #138354

Pull Request resolved: #148921
Approved by: https://github.com/drisspg, https://github.com/jbschlosser