[cuDNN][SDPA] Update cuDNN grad output layout check by eqy · Pull Request #141147 · pytorch/pytorch · GitHub

eqy commented Nov 20, 2024

Thanks to #137978 from @Skylion007, which bumps cuDNN to 9.5.1, the broken assumption that dO strides == O strides is fixed.

Note that one restriction remains: the innermost stride of the grad output must be 1 (this is almost always guaranteed, because the same condition is required of the input tensors). The main exception is test code that calls, e.g., `.sum().backward()`, which yields grad output tensors with strides `[0, 0, 0, 0]`.
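To illustrate the two conditions above, here is a minimal sketch in plain Python that models strides as tuples. It does not call PyTorch or cuDNN; the helper names and the example shape are hypothetical and chosen only for illustration.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (contiguous) strides, in elements, for the given shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def expanded_scalar_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Strides of a single value broadcast to `shape`: every stride is 0.

    This models the grad output described above for `.sum().backward()`.
    """
    return tuple(0 for _ in shape)


def innermost_stride_is_one(strides: Sequence[int]) -> bool:
    """Model of the remaining restriction: the last stride must be 1."""
    return len(strides) > 0 and strides[-1] == 1


# Hypothetical (batch, heads, seq_len, head_dim) shape.
shape = (2, 8, 128, 64)

# Assumed layout for O: memory ordered (batch, seq_len, heads, head_dim),
# viewed as (batch, heads, seq_len, head_dim).
o_strides = (128 * 8 * 64, 64, 8 * 64, 1)

# Assumed layout for dO: fully contiguous in the viewed shape.
do_strides = contiguous_strides(shape)

print(o_strides != do_strides)              # True: layouts differ
print(innermost_stride_is_one(do_strides))  # True: still acceptable

# The broadcast gradient has all-zero strides and fails the check.
print(innermost_stride_is_one(expanded_scalar_strides(shape)))  # False
```

In this model, dO and O may disagree on their outer strides, which is the case the description says is now handled, while the all-zero layout is the one that still trips the innermost-stride check.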

CC @drisspg

cc @csarofeen @ptrblck @xwang233 @drisspg @mikaylagawarecki

@eqy added the module: cudnn, open source, topic: not user facing, and module: multi-headed-attention labels Nov 20, 2024
@eqy requested a review from syed-ahmed as a code owner November 20, 2024 18:49

pytorch-bot bot commented Nov 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141147

✅ You can merge normally! (3 Unrelated Failures)

As of commit 183acff with merge base 78491d6:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

const auto innermost_dO_stride = dO.strides()[dO.strides().size() - 1];
if (innermost_dO_stride != 1) {
  TORCH_WARN_ONCE(
      "cuDNN SDPA backward got grad_output with an innermost stride != 1 "
Contributor

I wonder if we should even warn once here, since a lot of people still use this for testing. The warning just spams their logs, and there's nothing more efficient they can do.

Collaborator

I mean, we can tell them to update to a version of PyTorch that supports cuDNN 9.5.1.

eqy commented Nov 21, 2024

For this check, it's a known limitation due to the way the kernel is architected, and there are no plans to fix it.

@eqy added the ciflow/trunk and ciflow/periodic labels Nov 21, 2024

eqy commented Nov 25, 2024

@pytorchmergebot rebase

@pytorchmergebot

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot

Successfully rebased cudnn951dO onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnn951dO && git pull --rebase)


eqy commented Nov 25, 2024

@pytorchmergebot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot

Merge failed

Reason: 1 jobs have failed, first few of them are: periodic / linux-focal-cuda11.8-py3.9-gcc9 / test (multigpu, 1, 1, lf.linux.g5.12xlarge.nvidia.gpu, oncall:distributed)

eqy commented Nov 25, 2024

@pytorchmergebot rebase

@pytorchmergebot

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot

Successfully rebased cudnn951dO onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnn951dO && git pull --rebase)


eqy commented Nov 26, 2024

@pytorchmergebot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#141147
Approved by: https://github.com/drisspg

Labels

ciflow/periodic, ciflow/trunk, Merged, module: cudnn, open source, topic: not user facing
