[export] Modify SDPA decomposition to decompose _scaled_dot_product_flash_attention_for_cpu by larryliu0820 · Pull Request #117097 · pytorch/pytorch · GitHub

Conversation

@larryliu0820 (Contributor) commented Jan 10, 2024

Stack from ghstack (oldest at bottom):

Summary: As titled. #115913 added
`_scaled_dot_product_flash_attention_for_cpu`, and the export result of
`scaled_dot_product_attention` now includes this op. Add a decomposition
for it so that it is decomposed the same way as
`_scaled_dot_product_attention_math`.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
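
For illustration, here is a minimal sketch of what decomposing the CPU flash-attention op into the math reference could look like. This is not the PR's exact code; the op signatures, return values, and the registration hookup are assumptions based on this discussion.

```python
import torch

aten = torch.ops.aten

def sdpa_flash_for_cpu_decomp(query, key, value, dropout_p=0.0,
                              is_causal=False, *, attn_mask=None, scale=None):
    """Sketch: express _scaled_dot_product_flash_attention_for_cpu in terms of
    _scaled_dot_product_attention_math (signatures are assumptions)."""
    # Reuse the math reference for the actual attention computation.
    output, _attn = aten._scaled_dot_product_attention_math(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,
        is_causal=is_causal, scale=scale)
    # The CPU flash op also returns a logsumexp tensor; compute it from the
    # scaled scores (ignoring masking/causality for brevity in this sketch)
    # instead of leaving it uninitialized as the removed decomposition did.
    d = query.size(-1)
    s = scale if scale is not None else d ** -0.5
    scores = torch.matmul(query, key.transpose(-2, -1)) * s
    logsumexp = torch.logsumexp(scores, dim=-1)
    return output, logsumexp

# In PyTorch this would be registered for the op via
# torch._decomp.register_decomposition (usage assumed; shown only as a
# comment to keep the sketch self-contained).
```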

@pytorch-bot (bot) commented Jan 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/117097

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6a34f62 with merge base 19e93b8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…t_product_flash_attention_for_cpu"

Summary: As titled. #115913 added
`_scaled_dot_product_flash_attention_for_cpu` and the export result of
`scaled_dot_product_attention` includes this op. Adding this
decomposition so that it's being decomposed the same way as
`_scaled_dot_product_attention_math`.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
larryliu0820 added a commit that referenced this pull request Jan 10, 2024

…lash_attention_for_cpu

Test Plan: python test/test_decomp.py -k test_aten_core_operators

ghstack-source-id: 8651762
Pull Request resolved: #117097
@lezcano (Collaborator) left a comment

Don't we want to keep the decomp of _scaled_dot_product_flash_attention?

@kimishpatel (Contributor)

I have been wanting to fix this in a more correct way, which is probably non-trivial. The correct way is really to define a decomp for the aten sdpa op. I remember discussing this with you, but I still think we should look into decomposing aten sdpa.

One not-so-nice ExecuTorch-specific workaround lives inside to_edge: before calling run_decompositions, we manually replace instances of aten sdpa with its decomposed version (a rough sketch follows below).
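
A graph-level swap of that kind might look like the following. `decomposed_sdpa` is a made-up stand-in for a hand-written decomposition, and the pass itself is illustrative, not ExecuTorch's actual to_edge code.

```python
import torch

def replace_sdpa_with_decomp(gm: torch.fx.GraphModule, decomposed_sdpa) -> None:
    """Illustrative pass: retarget aten SDPA call nodes to a decomposed
    implementation before running decompositions. `decomposed_sdpa` is a
    hypothetical callable with the same (q, k, v, ...) calling convention."""
    sdpa = torch.ops.aten.scaled_dot_product_attention.default
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == sdpa:
            # Assumes args/kwargs line up one-to-one with the replacement.
            node.target = decomposed_sdpa
    gm.graph.lint()
    gm.recompile()
```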

@kimishpatel (Contributor)

> Don't we want to keep the decomp of _scaled_dot_product_flash_attention?

Why flash attention?

@lezcano (Collaborator) commented Jan 10, 2024

This PR removes one decomposition and registers a decomposition in one of the functions inside it. Why don't we keep the original decomposition in terms of this new decomposition?

@kimishpatel (Contributor)

> This PR removes one decomposition and registers a decomposition in one of the functions inside it. Why don't we keep the original decomposition in terms of this new decomposition?

Oh I see. Sorry, I didn't follow the first time around. Yeah, makes sense. Basically you are suggesting keeping definitions for both decomps, with one of them just calling the other one. @larryliu0820 ?

@lezcano (Collaborator) commented Jan 10, 2024

Yep, although looking at the removed code, the previous decomposition was completely wrong, with variables like

logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)

that are never filled up, so it may be alright to straight up remove it.

@larryliu0820 (Contributor, Author)

> Yep, although looking at the removed code, the previous decomposition was completely wrong, with variables like
>
> logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
>
> that are never filled up, so it may be alright to straight up remove it.

Yeah, the previous decomposition was pretty awkward because `_scaled_dot_product_flash_attention` returns way more things than `_scaled_dot_product_attention_math` (see the rough comparison below). The current decomp makes much more sense.
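
For context, a rough comparison of the three ops' outputs as discussed in this thread; the return lists are paraphrased from the discussion and may not match the current native_functions.yaml exactly.

```python
# Illustration only, not the authoritative schemas:
#
#   _scaled_dot_product_flash_attention(...)
#       -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
#           philox_seed, philox_offset, debug_attn_mask)
#   _scaled_dot_product_flash_attention_for_cpu(...)
#       -> (output, logsumexp)
#   _scaled_dot_product_attention_math(...)
#       -> (output, attn_weights)
```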

@larryliu0820 (Contributor, Author)

@pytorchbot merge

@lezcano (Collaborator) commented Jan 10, 2024

@larryliu0820 can you please deactivate the test for CUDA?

@larryliu0820 (Contributor, Author)

> @larryliu0820 can you please deactivate the test for CUDA?

Oh sorry, I'm on it.

…t_product_flash_attention_for_cpu"

Summary: As titled. #115913 added
`_scaled_dot_product_flash_attention_for_cpu` and the export result of
`scaled_dot_product_attention` includes this op. Adding this
decomposition so that it's being decomposed the same way as
`_scaled_dot_product_attention_math`.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
larryliu0820 added a commit that referenced this pull request Jan 10, 2024

…lash_attention_for_cpu

Test Plan: python test/test_decomp.py -k test_aten_core_operators

ghstack-source-id: 45bd1d2
Pull Request resolved: #117097
@larryliu0820 (Contributor, Author)

@pytorchbot merge

pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 10, 2024
@pytorchmergebot (Collaborator)

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@larryliu0820 (Contributor, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

larryliu0820 added a commit that referenced this pull request Jan 12, 2024

Summary:

A follow-up to #117097. In that PR I didn't add
`_scaled_dot_product_attention_for_cpu` to the core_aten_decomposition
table. This PR does that and also adds a unit test.

Test Plan: python test/export/test_export.py -k test_scaled_dot_product_attention

[ghstack-poisoned]
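
A rough way to exercise this end to end (a hypothetical sanity check, not the PR's actual test):

```python
import torch
from torch.export import export

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(2, 4, 8, 16)
# Export and apply the default (core ATen) decompositions; with the SDPA
# decomposition present in the table, no SDPA variant should remain.
ep = export(Attn(), (q, k, v)).run_decompositions()
assert not any(
    "scaled_dot_product" in str(node.target)
    for node in ep.graph_module.graph.nodes
    if node.op == "call_function"
)
```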
larryliu0820 added a commit that referenced this pull request Jan 12, 2024

ghstack-source-id: 8c7f1ab
Pull Request resolved: #117390
pytorchmergebot pushed a commit that referenced this pull request Jan 14, 2024

Pull Request resolved: #117390
Approved by: https://github.com/drisspg
facebook-github-bot deleted the gh/larryliu0820/43/head branch January 14, 2024 15:23
suo added a commit that referenced this pull request Jan 16, 2024
suo added a commit that referenced this pull request Jan 16, 2024

Pull Request resolved: #117390
Approved by: https://github.com/drisspg

Differential Revision: [D52788012](https://our.internmc.facebook.com/intern/diff/D52788012/)
ghstack-source-id: 212131226

Labels: ciflow/inductor, ciflow/trunk, Merged, topic: not user facing

4 participants