[ROCm] Tune flex-attention and decode to num_stages=1 by jataylo · Pull Request #139883 · pytorch/pytorch

Conversation

@jataylo (Collaborator) commented Nov 6, 2024

Fixes #139755 #139621

The new stream pipeliner in the AMD Triton backend enables num_stages to behave equivalently to the NVIDIA backend. This upgrade in Triton 3.2 causes OOM issues in flex attention under the num_stages=3 setting, so we have tuned this to num_stages=1, which is the best setting for flash attention kernels and avoids the shared-memory (shmem) issues.

We will follow up this PR with further config tuning on the AMD backend.
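
For context, a minimal sketch of the kind of override this PR applies, assuming Inductor-style config tuples of (BLOCK_M, BLOCK_N, num_warps, num_stages); the tile values below are illustrative, not the exact PyTorch defaults:

  import torch

  # Candidate kernel configs: (BLOCK_M, BLOCK_N, num_warps, num_stages).
  # num_stages > 1 enables software pipelining, which keeps extra tile
  # buffers in shared memory; with Triton 3.2's stream pipeliner this can
  # exceed the shared-memory budget on AMD GPUs.
  configs = [
      (128, 64, 4, 3),
      (64, 128, 4, 3),
      (128, 128, 8, 2),
  ]

  # torch.version.hip is a version string on ROCm builds and None on CUDA
  # builds, so the override only fires on AMD.
  if torch.version.hip:
      configs = [(c[0], c[1], c[2], 1) for c in configs]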

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @hongxiayang @naromero77amd @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@jataylo jataylo requested review from drisspg and yanboliang November 6, 2024 13:08
@pytorch-bot bot commented Nov 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139883

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 48340e6 with merge base 314aa26:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the ciflow/inductor, ciflow/rocm, module: inductor, and module: rocm labels Nov 6, 2024
@jataylo jataylo requested a review from bertmaher November 6, 2024 13:09
@jataylo added the release notes: rocm and topic: not user facing labels Nov 6, 2024
@jataylo (Collaborator, Author) commented Nov 7, 2024

@pytorchbot rebase

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.

@pytorchmergebot (Collaborator):

Successfully rebased flex-num-stages onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout flex-num-stages && git pull --rebase)

Comment on lines 873 to 875
# On ROCm convert num_stages to 1 to avoid shmem issues
configs = [(c[0], c[1], c[2], 1) for c in configs]

Contributor:

Where is this config change guarded to only apply to ROCm?

Collaborator (Author):

Ah, good find. I guarded it in flex_decode but missed it out here... it should be this:

  if torch.version.hip:
      configs = [(c[0], c[1], c[2], 1) for c in configs]
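
For illustration only, a hypothetical sketch of how such a config tuple might be consumed downstream (to_kernel_options is an assumed name, not the actual PyTorch helper):

  def to_kernel_options(config):
      # Unpack an Inductor-style config tuple into the keyword arguments
      # a Triton template launch would consume.
      block_m, block_n, num_warps, num_stages = config
      return {
          "BLOCK_M": block_m,
          "BLOCK_N": block_n,
          "num_warps": num_warps,
          "num_stages": num_stages,  # pinned to 1 on ROCm by the guard above
      }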

@bertmaher (Contributor) left a comment:

lgtm!

@bertmaher (Contributor):

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label Nov 7, 2024
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


pytorchmergebot pushed a commit that referenced this pull request Nov 15, 2024
Fixes #139755 #139621

Follow-up fix to #139883, which made the bulk of the required changes, but a logic error resulted in ROCm still using H100 configurations.

Pull Request resolved: #140270
Approved by: https://github.com/bertmaher
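
As a hedged illustration of the kind of logic error described above (the shape is hypothetical, not the actual #140270 diff; _h100_default_config and _rocm_default_config are assumed names):

  import torch

  # Buggy shape: an early return hands back the H100 defaults before the
  # backend check, so the ROCm branch is unreachable on AMD.
  def get_default_config(head_dim):
      if head_dim <= 128:
          return _h100_default_config  # also taken on ROCm: the bug
      if torch.version.hip:
          return _rocm_default_config
      return _h100_default_config

  # Fixed shape: decide on the backend first, then fall through to the
  # NVIDIA defaults.
  def get_default_config_fixed(head_dim):
      if torch.version.hip:
          return _rocm_default_config
      return _h100_default_config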

Labels

ciflow/inductor, ciflow/rocm, ciflow/trunk, Merged, module: inductor, module: rocm, open source, release notes: rocm, topic: not user facing


Development

Successfully merging this pull request may close these issues.

[ROCm] [Triton 3.2] OOM shmem issues on Inductor tests with new SW pipelining
