-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[ROCm] Tune flex-attention and decode to num_stages=1 #139883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139883
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 48340e6 with merge base 314aa26 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
Successfully rebased |
c932ec7
to
7f6b3c4
Compare
# On ROCm convert num_stages to 1 to avoid shmem issues | ||
configs = [(c[0], c[1], c[2], 1) for c in configs] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is this config change guarded to only apply to ROCm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah good find, I guarded it in flex_decode but I missed it out here... should be this:
if torch.version.hip:
configs = [(c[0], c[1], 1) for c in configs]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Fixes #139755 #139621 Follow up fix to #139883 which made the bulk of the changes required but a logic error resulted in ROCm still using h100 configurations. Pull Request resolved: #140270 Approved by: https://github.com/bertmaher
Fixes pytorch#139755 pytorch#139621 The new stream pipeliner on AMD triton backend enables num_stages to function equivalent to NV backend. This upgrade in triton 3.2 will cause OOM issues in flex attention due to num_stages=3 setting, we have tuned this to num_stages=1 which is the best setting for flash attention kernels and avoids the shmem issues. We will follow up this PR with some config tuning on AMD backend. Pull Request resolved: pytorch#139883 Approved by: https://github.com/bertmaher
…#140270) Fixes pytorch#139755 pytorch#139621 Follow up fix to pytorch#139883 which made the bulk of the changes required but a logic error resulted in ROCm still using h100 configurations. Pull Request resolved: pytorch#140270 Approved by: https://github.com/bertmaher
Fixes #139755 #139621
The new stream pipeliner on AMD triton backend enables num_stages to function equivalent to NV backend. This upgrade in triton 3.2 will cause OOM issues in flex attention due to num_stages=3 setting, we have tuned this to num_stages=1 which is the best setting for flash attention kernels and avoids the shmem issues.
We will follow up this PR with some config tuning on AMD backend.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @hongxiayang @naromero77amd @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov