KEMBAR78
Make onnx export SDPA match aten behavior by IlyasMoutawwakil · Pull Request #159973 · pytorch/pytorch · GitHub
Skip to content

Conversation

@IlyasMoutawwakil
Copy link
Contributor

@IlyasMoutawwakil IlyasMoutawwakil commented Aug 6, 2025

This PR makes onnx sdpa export match the behavior of aten sdpa when boolean mask is used.
@justinchuby

import onnxruntime as ort
import torch


class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)


model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    dynamo=false, # or True,
)
ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")

np_inputs = {"query": query.numpy(), "key": key.numpy(), "value": value.numpy(), "attn_mask": attn_mask.numpy()}
onnx_outputs = ort_session.run(None, np_inputs)[0]

torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)

fails the assertion because the ort model outputs nans.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Aug 6, 2025
@linux-foundation-easycla
Copy link

linux-foundation-easycla bot commented Aug 6, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: IlyasMoutawwakil / name: Ilyas Moutawwakil (36fde91, 8f7d6ce)

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159973

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 8f7d6ce with merge base 4c01991 (image):

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Collaborator

@titaiwangms titaiwangms left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you sign CLA?

@titaiwangms titaiwangms added the topic: bug fixes topic category label Aug 6, 2025
@titaiwangms titaiwangms requested a review from xadupre August 6, 2025 18:43
@titaiwangms
Copy link
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 7, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
This PR makes onnx sdpa export match the behavior of aten sdpa when boolean mask is used.
@justinchuby

```python
import onnxruntime as ort
import torch

class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    dynamo=false, # or True,
)
ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")

np_inputs = {"query": query.numpy(), "key": key.numpy(), "value": value.numpy(), "attn_mask": attn_mask.numpy()}
onnx_outputs = ort_session.run(None, np_inputs)[0]

torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)
```
fails the assertion because the ort model outputs nans.
Pull Request resolved: pytorch#159973
Approved by: https://github.com/xadupre, https://github.com/titaiwangms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: onnx torch.onnx related changes that should show up in the release notes topic: bug fixes topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants