Address NaNs if SDPA is called with all values masked from query by jhavukainen · Pull Request #157727 · pytorch/pytorch · GitHub

Conversation

@jhavukainen
Collaborator

Fixes #156707

Detect whether all values along the softmax axis are -inf and overwrite the outputs of those computations with zeros before the final matmul. The behavior should be aligned with the CPU implementation.

According to the original issue, these cases, where all values along a dimension of the attention mask are false and the softmax outputs become undefined, occur with left-padded batches during generation in HF transformers.
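A minimal repro sketch of the failure mode (illustration only, not code from this PR; assumes a machine with the MPS backend available):

import torch

# Hypothetical repro: one query position is masked out from attending to every key.
q = torch.randn(1, 1, 4, 8, device="mps")
k = torch.randn(1, 1, 4, 8, device="mps")
v = torch.randn(1, 1, 4, 8, device="mps")

mask = torch.ones(1, 1, 4, 4, dtype=torch.bool, device="mps")
mask[..., 0, :] = False  # first query row attends to nothing -> its scores are all -inf

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])  # NaNs before this change; zeros (matching the CPU backend) after it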

@pytorch-bot

pytorch-bot bot commented Jul 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157727

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7394609 with merge base 8a47f9d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/mps and release notes: mps labels Jul 7, 2025
@soulitzer soulitzer requested review from drisspg and jbschlosser July 7, 2025 19:21
@soulitzer soulitzer added the triaged label Jul 7, 2025
@drisspg
Contributor

drisspg commented Jul 7, 2025

@malfet I dont really know enough about the mps implementation to tell if this is the correct way to do this, but this seems to treat any Inf as the same? This is how I updated the existing flash implementations: #131863

@jhavukainen
Collaborator Author

I see what you mean. Yes, this version would end up catching both negative and positive infinities instead of just the negative ones, so I'll need to change that. I'll push an update today with the fix.
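For illustration (a hypothetical sketch, not the actual kernel code in this PR), the difference between the two checks on the pre-softmax scores:

import torch

scores = torch.tensor([[float("-inf"), float("-inf")],  # fully masked row -> should be zeroed
                       [float("inf"), float("inf")]])   # +inf values (e.g. overflow), not a masked row

too_broad = torch.isinf(scores).all(dim=-1)     # tensor([True, True]): flags both rows
neg_only = torch.isneginf(scores).all(dim=-1)   # tensor([True, False]): only the fully masked row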

@malfet
Contributor

@malfet left a comment


LGTM, but I wonder what the performance implications of adding those checks are?

@jhavukainen
Collaborator Author

@malfet Yeah, that's a fair question. I ran a couple of naive benchmarks to get an idea of the effect this would have. For smaller shapes it does cause a measurable difference, but with more realistic sizes the difference disappears into measurement variability quite quickly, so overall I'd estimate the check doesn't cost too much for the average use case. Here's the script I used on an M2 Max device:

import torch
import numpy as np
import time

L=2048   # query sequence length
S=512    # key/value sequence length
NH=128   # number of heads
HS=1024  # head size
dtype=torch.float32

q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")

op = torch.nn.functional.scaled_dot_product_attention

num_warmup = 5
n_stats = 50

# Warm up to exclude one-time graph/kernel compilation overhead from the timings.
for _ in range(num_warmup):
    out = op(q, k, v)
    torch.mps.synchronize()

# Time synchronized iterations so the GPU work is included in each measurement.
stats = []
for i in range(n_stats):
    start = time.perf_counter_ns()
    out = op(q, k, v)
    torch.mps.synchronize()
    end = time.perf_counter_ns()
    stats.append((end-start)/1e3)

print(f"Took {np.mean(stats):0.2f}\u00B1{np.std(stats):0.2f} \u03bcs per iter")

And the results for two sets of shapes:

L=1, S=72, NH=32, HS=128
# Before change:
Took 310.55±37.74 μs per iter
# After change:
Took 368.98±35.13 μs per iter
L=2048, S=512, NH=128, HS=1024
# Before change:
Took 52839.62±147.31 μs per iter
# After change:
Took 52825.06±132.55 μs per iter

With the GPT2 train benchmark from torchbench I'm getting

Before change:
----------------------------------------------- benchmark 'hub': 1 tests -----------------------------------------------
Name (time in ms)                Min       Max      Mean  StdDev    Median     IQR  Outliers     OPS  Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------
test_train[hf_GPT2-mps]     398.4452  407.1690  400.6670  3.6971  398.8504  3.3065       1;1  2.4958       5           1
------------------------------------------------------------------------------------------------------------------------

After change:
----------------------------------------------- benchmark 'hub': 1 tests -----------------------------------------------
Name (time in ms)                Min       Max      Mean  StdDev    Median     IQR  Outliers     OPS  Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------
test_train[hf_GPT2-mps]     395.7955  400.3153  398.8589  1.7881  399.2772  1.8033       1;0  2.5072       5           1
------------------------------------------------------------------------------------------------------------------------

@jhavukainen
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 14, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here


Labels

ciflow/mps, ciflow/trunk, Merged, open source, release notes: mps, triaged


Development

Successfully merging this pull request may close these issues.

MPS SDPA returns NaN when attention mask blocks all rows

6 participants