[ROCm] Bug fix for flex attention configs avoiding ROCm path #140270
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140270
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit 1b16483 with merge base 62eea62.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Adding priority label to enable testing for triton 3.2/pytorch 2.6
```diff
     default_config = None

-    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
+    if head_dim <= 256 and torch.version.hip:
```
I think it might be cleaner to instead create a `_get_config_rocm` and branch at the call site, even if there is some overlap.
TBH this flow is already hard to grok.
We still need the conditional on head_dim at least.
Do you have any specific recommendations, @drisspg? I'd like to get this bugfix in ASAP for triton 3.2 debugging, but the config selection in this file definitely does need refactoring.
I think we should copy and paste the entire function and create two versions, one for ROCm and one for CUDA. It feels to me that although you need branching on head_dims, the branching should likely be different given the reduced smem constraints and increased warp size on AMD GPUs.
I think let's just make two funcs and copy and paste as much as you need, until the dust settles some more.
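To make the suggestion above concrete, here is a minimal sketch of the two-function split with call-site dispatch. The function names, head_dim cutoffs, and config tuples are placeholders chosen for illustration, not the PR's actual values:

```python
import torch

# Hypothetical names and placeholder values, not the PR's actual code.
def _get_fwd_config_rocm(head_dim: int, dtype: torch.dtype):
    # ROCm branching can differ from CUDA: 64-wide wavefronts and different
    # shared-memory limits generally favor different block sizes / num_warps.
    if head_dim <= 128:
        return (128, 64, 8, 1)  # (BLOCK_M, BLOCK_N, num_warps, num_stages)
    return (64, 64, 4, 1)

def _get_fwd_config_cuda(head_dim: int, dtype: torch.dtype):
    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        return (128, 64, 4, 3)
    return (64, 64, 4, 2)

def _get_fwd_config(head_dim: int, dtype: torch.dtype):
    # Branch once at the call site instead of mixing HIP checks into the CUDA path.
    if torch.version.hip:
        return _get_fwd_config_rocm(head_dim, dtype)
    return _get_fwd_config_cuda(head_dim, dtype)
```

Keeping the dispatch at a single point means later tuning of the ROCm block sizes would not risk perturbing the H100/A100 paths.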
Looks like this has addressed @drisspg's concerns so I'm good to merge
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
There are associated failures, likely due to the new structure on the NV side... let me sanity check, and when we get a green signal I'll merge. cc: @bertmaher
Rebasing failure seems unrelated
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased: cb09132 to 1b16483
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…#140270)
Fixes pytorch#139755 pytorch#139621
Follow-up fix to pytorch#139883, which made the bulk of the changes required, but a logic error resulted in ROCm still using H100 configurations.
Pull Request resolved: pytorch#140270
Approved by: https://github.com/bertmaher
Fixes #139755 #139621
Follow-up fix to #139883, which made the bulk of the changes required, but a logic error resulted in ROCm still using H100 configurations.
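For context, a hedged illustration of the logic error described above (placeholder code, not the file's actual implementation). Under ROCm, `torch.cuda.get_device_capability()` reports the AMD gfx major/minor, which on Instinct-class GPUs can satisfy >= (9, 0), so a guard written for H100 also fires on HIP unless `torch.version.hip` is checked first:

```python
import torch

head_dim = 128  # example value

# Buggy shape: on a ROCm build whose device reports capability >= (9, 0),
# this branch still selects the H100-tuned configs.
if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
    config = "h100_config"  # placeholder

# Fixed shape: test for HIP first so ROCm never falls into the H100 branch.
if head_dim <= 256 and torch.version.hip:
    config = "rocm_config"  # placeholder
elif head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
    config = "h100_config"  # placeholder
```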
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @hongxiayang @naromero77amd @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov