[FlexAttention][TF32] Handle uninitialized torch.backends.cuda.matmul.fp32_precision
#161102
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161102
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4b6b15c with merge base 24e7f3c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Flex attention precision was always forced to TF32 after pytorch#161102. The Python ternary needs to be surrounded with parentheses.
This change doesn't work. I'm pretty sure it's the root cause of #161022, and it's also why the ROCm CI for MI200 runners started timing out: it forces the default to TF32 for the flex attention tests, and TF32 isn't supported on MI200. test_flex_attention.py used to take less than 10 minutes, but this mistake makes it take 30+ minutes, so the shard times out. Issuing revert. Please rework the logic; the new ternary operator needs to be surrounded with parentheses. A forward fix is submitted in #161465.
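For context, a minimal sketch of the precedence pitfall being described. The names (`is_fp32`, `prec`, `allow_tf32`) and the exact shape of the expression are illustrative, not the actual identifiers or line in flex_attention.py; the point is only that a Python conditional expression binds more loosely than `and`, so an unparenthesized ternary can change which branch the dtype check applies to.

```python
# Hypothetical stand-ins, not the actual flex attention code:
is_fp32 = False        # e.g. the query dtype is not torch.float32
prec = "none"          # torch.backends.cuda.matmul.fp32_precision left uninitialized
allow_tf32 = True      # legacy torch.backends.cuda.matmul.allow_tf32 flag

# Unparenthesized: the conditional expression has lower precedence than `and`, so
# this parses as `(is_fp32 and prec == "tf32") if prec != "none" else allow_tf32`,
# dropping the dtype guard from the fallback branch.
buggy = is_fp32 and prec == "tf32" if prec != "none" else allow_tf32

# Parenthesized (the shape of the fix in #161465): the ternary only replaces the
# precision check, and the dtype guard always applies.
fixed = is_fp32 and (prec == "tf32" if prec != "none" else allow_tf32)

print(buggy, fixed)  # True False -- the two groupings disagree for these inputs
```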
@jeffdaily thanks for the forward fix
PR #161102 caused TF32 to be the default precision for flex attention. This PR forward-fixes the broken logic and restores the ROCm MI200 CI flex attention tests.
Pull Request resolved: #161465
Approved by: https://github.com/jeffdaily, https://github.com/eqy
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
For #161022
The warning says the old API will be deprecated in 2.9+ anyway, leaving it up to the author of #125888 to decide on the initialization behavior then.
cc @ptrblck @msaroufim @jerryzh168 @zasdfgbnm @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Chillee @drisspg @yanboliang @BoyuanFeng
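To make the description above concrete, here is a minimal sketch of the kind of decision this PR describes: prefer the new `torch.backends.cuda.matmul.fp32_precision` setting when it has been initialized, and fall back to the legacy `allow_tf32` flag otherwise. It assumes the uninitialized setting reads as `"none"`; the helper name and structure are illustrative, not the actual flex attention code.

```python
import torch

def flex_attention_uses_tf32(dtype: torch.dtype) -> bool:
    """Illustrative helper (not the actual flex attention implementation).

    Decide whether fp32 matmuls in the generated kernel may use TF32,
    assuming fp32_precision reads "none" while uninitialized.
    """
    if dtype != torch.float32:
        return False
    prec = torch.backends.cuda.matmul.fp32_precision
    if prec == "none":
        # Uninitialized: defer to the legacy flag (slated for deprecation in 2.9+).
        return torch.backends.cuda.matmul.allow_tf32
    return prec == "tf32"
```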