Address NaNs if SDPA is called with all values masked from query #157727
Conversation
… values the query can attend to. Regression test for the issue.
I see what you mean. Yeah, this version would end up catching both negative and positive infinities instead of just negative ones, so I'll need to change that. I'll push an update today to fix that.
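For context, a rough Python-level analogy of that distinction (the actual check presumably lives in the MPS backend rather than at the Python level, so this is illustration only):

```python
# Illustration only: distinguishing "any inf" from "-inf only".
# Catching both signs of infinity would be too broad; only the -inf entries
# correspond to masked-out attention scores.
import torch

scores = torch.tensor([float("-inf"), float("-inf"), float("inf")])
print(torch.isinf(scores))     # tensor([True, True, True])  -> both signs
print(torch.isneginf(scores))  # tensor([True, True, False]) -> only the masked entries
```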
LGTM, but I wonder what the performance implications of adding those checks are?
@malfet Yeah, that's a fair question. I ran a couple of naive benchmarks to get an idea of the effect this would have. For smaller shapes it does cause a measurable difference, but with more relevant sizes it disappears within the measurement variability quite quickly, so overall I'd estimate the check doesn't cost much for the average use case. Here's the script I used with an M2 Max device, the results for two sets of shapes, and what I'm getting with the GPT2 train benchmark from torchbench:
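A minimal sketch of such a micro-benchmark, with assumed shapes and iteration counts rather than the exact ones used in this comment, could look like the following (it times `scaled_dot_product_attention` with a boolean mask that leaves a few query rows fully masked):

```python
# Hypothetical micro-benchmark sketch (not the exact script from this comment):
# time F.scaled_dot_product_attention with a boolean attention mask that leaves
# some query rows fully masked, mimicking left-padded batches.
import time
import torch
import torch.nn.functional as F

device = "mps" if torch.backends.mps.is_available() else "cpu"

def bench(batch, heads, seq, head_dim, iters=50, use_mask=True):
    q = torch.randn(batch, heads, seq, head_dim, device=device)
    k = torch.randn(batch, heads, seq, head_dim, device=device)
    v = torch.randn(batch, heads, seq, head_dim, device=device)
    mask = None
    if use_mask:
        mask = torch.ones(batch, 1, seq, seq, dtype=torch.bool, device=device)
        mask[:, :, :2, :] = False  # first two query positions attend to nothing
    for _ in range(5):  # warm-up
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    if device == "mps":
        torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    if device == "mps":
        torch.mps.synchronize()
    return (time.perf_counter() - start) / iters

for shape in [(1, 8, 128, 64), (8, 16, 1024, 64)]:  # assumed shapes
    print(shape, f"{bench(*shape) * 1e3:.3f} ms/iter")
```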
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Fixes #156707
Detect whether all values along the softmax axis are -inf and overwrite the outputs of those rows with zeros before the final matmul. The behavior should be aligned with the CPU implementation.
These cases, where all values along a dimension of the attention mask are false and the softmax output is therefore undefined, occur with left-padded batches during generation in HF transformers, according to the original issue.
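A small reproduction sketch of this failure mode, with arbitrary shapes and hedging on MPS availability (the actual regression test added by this PR may differ):

```python
# Illustrative reproduction of the case described above (shapes are arbitrary):
# a boolean attention mask whose row for one query position is entirely False
# turns every score in that row into -inf, so the softmax over the row is
# undefined. Per this PR, such rows are overwritten with zeros before the
# final matmul, aligning the MPS path with the CPU implementation.
import torch
import torch.nn.functional as F

def masked_sdpa(device):
    q = torch.randn(1, 2, 4, 8, device=device)
    k = torch.randn(1, 2, 4, 8, device=device)
    v = torch.randn(1, 2, 4, 8, device=device)
    mask = torch.ones(1, 1, 4, 4, dtype=torch.bool, device=device)
    mask[:, :, 0, :] = False  # query position 0 cannot attend to anything
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print("cpu produces NaNs:", masked_sdpa("cpu").isnan().any().item())
if torch.backends.mps.is_available():
    # Before this change the MPS path produced NaNs for the fully masked row;
    # with the fix this should print False.
    print("mps produces NaNs:", masked_sdpa("mps").isnan().any().item())
```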