[Export] Fix SDPA decomposition by sidt-meta · Pull Request #135297 · pytorch/pytorch · GitHub

@sidt-meta
Contributor

Summary: Update the SDPA decomposition to match the updated stride from D62009189, which aligns strides with aten._scaled_dot_product_attention_math.default and makes t.permute().contiguous().permute() no longer necessary.

Test Plan: CI

Differential Revision: D62278378
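
For readers unfamiliar with the `t.permute().contiguous().permute()` idiom mentioned in the summary, here is a minimal sketch of what it does to a tensor's strides. It models shapes and strides with plain tuples so it runs without PyTorch. The helper names, the example sizes, and the dimension order `(0, 2, 1, 3)` are assumptions chosen for illustration; the PR text does not state which permutation the decomposition used.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, strides, dims):
    """Reorder dimensions without moving data: shape and strides move together."""
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)


def permute_contiguous_permute(shape, dims):
    """Model t.permute(dims).contiguous().permute(inverse_dims).

    The logical shape is unchanged, but the data is re-laid-out so that it is
    packed in the permuted dimension order.
    """
    permuted_shape = tuple(shape[d] for d in dims)
    packed = contiguous_strides(permuted_shape)  # effect of .contiguous()
    inverse = [0] * len(dims)
    for position, dim in enumerate(dims):
        inverse[dim] = position
    return permute(permuted_shape, packed, inverse)


# Hypothetical sizes: batch, heads, sequence length, head dimension.
B, H, L, E = 2, 4, 8, 16
shape = (B, H, L, E)

plain = contiguous_strides(shape)
shape_after, relaid = permute_contiguous_permute(shape, (0, 2, 1, 3))

print("contiguous strides:", plain)
print("after permute/contiguous/permute:", shape_after, relaid)
```

The logical shape stays `(B, H, L, E)` while the strides change, which is why the idiom only matters to code that later depends on the memory layout, such as a view.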

@pytorch-bot

pytorch-bot bot commented Sep 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135297

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6b0f990 with merge base 183c32f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D62278378

sidt-meta added a commit to sidt-meta/pytorch that referenced this pull request Sep 6, 2024
Summary:
Pull Request resolved: pytorch#135297

Updated the stride to match D62009189, which aligns the strides of the output with `aten._scaled_dot_product_attention_math.default`. This makes `t.permute().contiguous().permute()` unnecessary, so it is now removed.

Before removal, this decomp was causing issues with model export.

Test Plan: CI

Differential Revision: D62278378
@sidt-meta sidt-meta added the topic: not user facing topic category label Sep 6, 2024

sidt-meta added a commit to sidt-meta/pytorch that referenced this pull request Sep 6, 2024
Summary:
Pull Request resolved: pytorch#135297

Update the stride to match D62009189, which changes the output tensor stride of flash attention. The decomposition needs to be updated to produce an output tensor with the same shape/stride so that the subsequent view is valid during torch.export.

Test Plan: CI

Differential Revision: D62278378
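
To make the "subsequent view" concern concrete, the sketch below checks when two trailing dimensions can be flattened by a view without copying. It uses plain tuples rather than PyTorch, and everything in it is an assumption for illustration: the helper name, the sizes, the downstream `transpose(1, 2)` followed by a flatten of the last two dimensions, and the two candidate layouts. It is not taken from the decomposition itself.

```python
def can_merge_last_two(shape, strides):
    """True if the last two dims can be flattened into one by a view (no copy).

    Two adjacent dims merge only when stepping once along the outer dim equals
    stepping across the whole inner dim. Size-1 dims never block a merge.
    """
    if shape[-2] == 1 or shape[-1] == 1:
        return True
    return strides[-2] == shape[-1] * strides[-1]


def transpose_1_2(shape, strides):
    """Swap dims 1 and 2 of a 4-D layout, as x.transpose(1, 2) would."""
    order = (0, 2, 1, 3)
    return tuple(shape[d] for d in order), tuple(strides[d] for d in order)


# Hypothetical sizes: batch, heads, sequence length, head dimension.
B, H, L, E = 2, 4, 8, 16
out_shape = (B, H, L, E)

# Layout A: output packed in (B, H, L, E) order.
strides_a = (H * L * E, L * E, E, 1)
# Layout B: output packed in (B, L, H, E) order but exposed as (B, H, L, E).
strides_b = (L * H * E, E, H * E, 1)

# Assumed downstream code: x.transpose(1, 2) then flatten (H, E) into H * E.
view_ok_a = can_merge_last_two(*transpose_1_2(out_shape, strides_a))
view_ok_b = can_merge_last_two(*transpose_1_2(out_shape, strides_b))

print("layout A allows the view:", view_ok_a)
print("layout B allows the view:", view_ok_b)
```

Under these assumptions, two outputs with the same shape but different strides behave differently at the view, which is the kind of mismatch that can surface only when the decomposed op replaces the original one during export.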

sidt-meta added a commit to sidt-meta/pytorch that referenced this pull request Sep 7, 2024
Summary:
Pull Request resolved: pytorch#135297

Update the stride to match D62009189, which changes the output tensor stride of flash attention. The decomposition needs to be updated to produce an output tensor with the same shape/stride so that the subsequent view is valid during torch.export.

Test Plan: CI

Differential Revision: D62278378

@sidt-meta
Contributor Author

Updated the existing test_sdpa in test_decomp.py to make sure the output's stride is as expected.
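
A stride check of this kind could look roughly like the sketch below. This is not the code in test_decomp.py: `assert_same_layout`, `layouts_match`, and `StubTensor` are hypothetical names, and the stub only stands in for a tensor so the example runs without PyTorch. The check relies solely on `.shape` and `.stride()`, which real tensors also provide.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class StubTensor:
    """Hypothetical stand-in exposing only the members the check reads."""

    shape: Tuple[int, ...]
    strides: Tuple[int, ...]

    def stride(self) -> Tuple[int, ...]:
        return self.strides


def assert_same_layout(actual, expected) -> None:
    """Raise AssertionError unless shape and strides both match."""
    if tuple(actual.shape) != tuple(expected.shape):
        raise AssertionError(
            f"shape mismatch: {tuple(actual.shape)} != {tuple(expected.shape)}"
        )
    if tuple(actual.stride()) != tuple(expected.stride()):
        raise AssertionError(
            f"stride mismatch: {tuple(actual.stride())} != {tuple(expected.stride())}"
        )


def layouts_match(actual, expected) -> bool:
    """Boolean wrapper around assert_same_layout, convenient for quick checks."""
    try:
        assert_same_layout(actual, expected)
    except AssertionError:
        return False
    return True


# Made-up layouts: same shape, different strides.
reference = StubTensor((2, 4, 8, 16), (512, 16, 64, 1))
decomposed_ok = StubTensor((2, 4, 8, 16), (512, 16, 64, 1))
decomposed_bad = StubTensor((2, 4, 8, 16), (512, 128, 16, 1))

print(layouts_match(decomposed_ok, reference))
print(layouts_match(decomposed_bad, reference))
```

Comparing values alone would not catch the second case, since both outputs hold the same elements; only the stride comparison distinguishes them.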


sidt-meta added a commit to sidt-meta/pytorch that referenced this pull request Sep 10, 2024
Summary:
Pull Request resolved: pytorch#135297

Update the stride to match D62009189, which changes the output tensor stride of flash attention. The decomposition needs to be updated to produce an output tensor with the same shape/stride so that the subsequent view is valid during torch.export.

Test Plan:
Set up buck target for test_decomp.py

```
buck2 test mode/{opt,inplace} //caffe2/test:test_decomp -- test_sdpa
```
 {F1857936880}

----
CI signals

Reviewed By: StellarrZ

Differential Revision: D62278378

@drisspg drisspg left a comment

Looks good

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 10, 2024
sidt-meta added a commit to sidt-meta/pytorch that referenced this pull request Sep 11, 2024
Summary:
Pull Request resolved: pytorch#135297

Update the stride to match D62009189, which changes the output tensor stride of flash attention. The decomposition needs to be updated to produce an output tensor with the same shape/stride so that the subsequent view is valid during torch.export.

Test Plan:
Set up buck target for test_decomp.py

```
buck2 test mode/{opt,inplace} //caffe2/test:test_decomp -- test_sdpa
```
 {F1857936880}

----
CI signals

Reviewed By: StellarrZ

Differential Revision: D62278378
@sidt-meta sidt-meta removed the request for review from StellarrZ September 11, 2024 17:36
@facebook-github-bot
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Summary: Update the SDPA decomposition to match the updated stride from D62009189, which aligns strides with `aten._scaled_dot_product_attention_math.default` and makes `t.permute().contiguous().permute()` no longer necessary.

Test Plan: CI

Differential Revision: D62278378

Pull Request resolved: pytorch#135297
Approved by: https://github.com/drisspg

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged topic: not user facing topic category

5 participants