-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[Export] Fix SDPA decomposition #135297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Export] Fix SDPA decomposition #135297
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135297
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 6b0f990 with merge base 183c32f ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
1055205 to
10238c8
Compare
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
10238c8 to
a86a8ba
Compare
Summary: Pull Request resolved: pytorch#135297 Updated stride from D62009189 which aligns strides of the output with `aten._scaled_dot_product_attention_math.default`, which makes `t.permute().continuous().permute()` no longer necessary, and is now removed. Before removal, this decomp was causing issues with model export. Test Plan: CI Differential Revision: D62278378
a86a8ba to
86c21ea
Compare
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
86c21ea to
2420162
Compare
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
2420162 to
ef89466
Compare
ef89466 to
066ea59
Compare
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
1 similar comment
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
Summary: Pull Request resolved: pytorch#135297 Updated stride from D62009189 which changes the output tensor stride flash attention. The decomposition needs to be updated have an output tensor with the same shape/stride in order for the subsequent view to be valid during torch.export Test Plan: CI Differential Revision: D62278378
066ea59 to
3396f61
Compare
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
3396f61 to
a550ed3
Compare
Summary: Pull Request resolved: pytorch#135297 Updated stride from D62009189 which changes the output tensor stride flash attention. The decomposition needs to be updated have an output tensor with the same shape/stride in order for the subsequent view to be valid during torch.export Test Plan: CI Differential Revision: D62278378
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
1 similar comment
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
302be07 to
85a47fe
Compare
|
Updated exsting test_sdpa in test_decomp.py with to make sure the output's stride is as expected |
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
Summary: Pull Request resolved: pytorch#135297 Updated stride from D62009189 which changes the output tensor stride flash attention. The decomposition needs to be updated have an output tensor with the same shape/stride in order for the subsequent view to be valid during torch.export Test Plan: Set up buck target for test_decomp.py ``` buck2 test mode/{opt,inplace} //caffe2/test:test_decomp -- test_sdpa ``` {F1857936880} ---- CI signals Reviewed By: StellarrZ Differential Revision: D62278378
85a47fe to
50076bb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good
50076bb to
cd99c02
Compare
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
Summary: Pull Request resolved: pytorch#135297 Updated stride from D62009189 which changes the output tensor stride flash attention. The decomposition needs to be updated have an output tensor with the same shape/stride in order for the subsequent view to be valid during torch.export Test Plan: Set up buck target for test_decomp.py ``` buck2 test mode/{opt,inplace} //caffe2/test:test_decomp -- test_sdpa ``` {F1857936880} ---- CI signals Reviewed By: StellarrZ Differential Revision: D62278378
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
cd99c02 to
c2cbed0
Compare
Summary: Pull Request resolved: pytorch#135297 Updated stride from D62009189 which changes the output tensor stride flash attention. The decomposition needs to be updated have an output tensor with the same shape/stride in order for the subsequent view to be valid during torch.export Test Plan: Set up buck target for test_decomp.py ``` buck2 test mode/{opt,inplace} //caffe2/test:test_decomp -- test_sdpa ``` {F1857936880} ---- CI signals Reviewed By: StellarrZ Differential Revision: D62278378
Summary: Pull Request resolved: pytorch#135297 Updated stride from D62009189 which changes the output tensor stride flash attention. The decomposition needs to be updated have an output tensor with the same shape/stride in order for the subsequent view to be valid during torch.export Test Plan: Set up buck target for test_decomp.py ``` buck2 test mode/{opt,inplace} //caffe2/test:test_decomp -- test_sdpa ``` {F1857936880} ---- CI signals Reviewed By: StellarrZ Differential Revision: D62278378
|
This pull request was exported from Phabricator. Differential Revision: D62278378 |
c2cbed0 to
6b0f990
Compare
|
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally) |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary: Update SDPA decomposition to match updated stride from D62009189 which aligns strides with the `aten._scaled_dot_product_attention_math.default`, which makes `t.permute().continuous().permute()` no longer necessary. Test Plan: CI Differential Revision: D62278378 Pull Request resolved: pytorch#135297 Approved by: https://github.com/drisspg
Summary: Update SDPA decomposition to match updated stride from D62009189 which aligns strides with the
aten._scaled_dot_product_attention_math.default, which makest.permute().continuous().permute()no longer necessary.Test Plan: CI
Differential Revision: D62278378