KEMBAR78
[BE][CUDA][Bugfix]: Enable extended MMA shapes in CUTLASS. by Skylion007 · Pull Request #133686 · pytorch/pytorch · GitHub
Skip to content

Conversation

@Skylion007
Copy link
Collaborator

@Skylion007 Skylion007 commented Aug 16, 2024

  • This fixes a major CMake/Bazel configuration bug where we were leaving CUTLASS performance on the table, especially with FlashAttention. This now enables using MMA instructions on SM90+, which should close the gap between SDPA and the external FA2. Note these operations only affect H100 and newer GPUs. Thankfully, this seems to have been updated recently into being a noop on the CUTLASS side. Still better set the CMake variable properly.
  • Also enables additional new shape kernels added in the recent CUTLASS 3.5.1+ update. This was the original motivatin of the PR before I realized the basic MMA kernels were accidentally disabled since we didn't go through the submodule's CMake/Bazels.
  • Adds a bit to compile time and code size, but well worth it considering it speeds up our internal flash attention significantly on H100s at the cost of some minor additional compile time.
  • These kernels and settings will be needed for Flash Attention 3 whenever we add that too.

Fixes #133695

cc @ptrblck @msaroufim @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133686

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 7a2ba33 with merge base 375921b (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Skylion007 Skylion007 added module: cuda Related to torch.cuda, and CUDA support in general better-engineering Relatively self-contained tasks for better engineering contributors release notes: cuda release notes category labels Aug 16, 2024
@Skylion007 Skylion007 force-pushed the skylion007/cutlass-enable-extended-mma-shapes branch from f174975 to 8b35d14 Compare August 16, 2024 14:46
@Skylion007 Skylion007 marked this pull request as ready for review August 16, 2024 14:54
@Skylion007 Skylion007 force-pushed the skylion007/cutlass-enable-extended-mma-shapes branch from 8b35d14 to 13f1918 Compare August 16, 2024 15:10
@Skylion007 Skylion007 changed the title [BE]: Enable extended MMA shapes CUTLASS (3.5+ feature) [BE][CUDA][Bugfix]: Enable MMA shapes in CUTLASS. Aug 16, 2024
@Skylion007 Skylion007 changed the title [BE][CUDA][Bugfix]: Enable MMA shapes in CUTLASS. [BE][CUDA][Bugfix]: Enable extended MMA shapes in CUTLASS. Aug 16, 2024
@stas00
Copy link
Contributor

stas00 commented Aug 16, 2024

s/SPDA/SDPA/ in the OP?

@Skylion007
Copy link
Collaborator Author

s/SPDA/SDPA/ in the OP?

Fixed

@Skylion007 Skylion007 force-pushed the skylion007/cutlass-enable-extended-mma-shapes branch from d3d86c3 to 13f1918 Compare August 16, 2024 18:16
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 16, 2024
@Skylion007
Copy link
Collaborator Author

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 28, 2024
@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased skylion007/cutlass-enable-extended-mma-shapes onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout skylion007/cutlass-enable-extended-mma-shapes && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the skylion007/cutlass-enable-extended-mma-shapes branch from 13f1918 to 7a2ba33 Compare September 28, 2024 16:02
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor-periodic / cuda12.1-py3.10-gcc9-sm80 / test (inductor_torchbench_smoketest_perf, 1, 1, linux.gcp.a100)

Details for Dev Infra team Raised by workflow job

@ezyang
Copy link
Contributor

ezyang commented Sep 28, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

better-engineering Relatively self-contained tasks for better engineering contributors ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: cuda Related to torch.cuda, and CUDA support in general module: inductor open source release notes: cuda release notes category topic: bug fixes topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

CUTLASS does not build with extended MMA shape kernels in PyTorch

6 participants