[cpu] add sdpa choice and UT #105131
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/105131. ✅ No failures as of commit 7be93a6 with merge base 600f9ef.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
`scaled_dot_product_attention` used to be decomposed in pre-autograd: it calls `_scaled_dot_product_attention_math`, and `_scaled_dot_product_attention_math` only has a `CompositeImplicitAutograd` kernel, so it was decomposed into finer-grained ops. However, recent PRs (#103826, #105131) added new logic to `scaled_dot_product_attention`, which now calls `_scaled_dot_product_flash_attention`, an op that has a CPU kernel. As a result, `_scaled_dot_product_flash_attention` shows up in `torch.export()`. This PR adds a decomposition that ensures `scaled_dot_product_attention` is still decomposed the same way as before, i.e., through `_scaled_dot_product_attention_math`. Note that this decomposition rule should be excluded by Inductor.

Differential Revision: [D48762000](https://our.internmc.facebook.com/intern/diff/D48762000/)

Pull Request resolved: #108180. Approved by: https://github.com/SherlockNoMad
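For context, a small sketch of how the change in behavior can be observed; the module, shapes, and the use of `torch.export.export` below are illustrative assumptions, not code from #108180:

```python
# Illustrative sketch only: inspect which SDPA op an exported graph contains.
import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))  # 4-D CPU inputs
ep = torch.export.export(Attn(), (q, k, v))

# Depending on the PyTorch version and which decompositions are applied, the
# graph contains either aten._scaled_dot_product_flash_attention (the fused
# CPU path added by #105131) or the finer-grained ops produced by decomposing
# through _scaled_dot_product_attention_math.
print(ep.graph)
```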
Stack from ghstack (oldest at bottom):
Feature RFC: pytorch/rfcs#56.
Add an SDPA selection function for CPU that automatically chooses one SDPA implementation among the available ones. Two CPU implementations can be chosen: the unfused SDPA and flash attention. In general, flash attention has a higher priority than the unfused SDPA. When flash attention is not applicable, for example because it has been manually disabled or the inputs are not 4-dimensional, the unfused SDPA is chosen instead.
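As an illustration (not code from this PR), here is a minimal sketch of exercising the selection from the Python side. The shapes are arbitrary, and the use of `torch.backends.cuda.sdp_kernel` to toggle the global flash-attention flag is an assumption about the backend-control API available in this PyTorch version:

```python
# Illustrative sketch only: shapes and the backend-toggling context manager
# are assumptions, not part of this PR.
import torch
import torch.nn.functional as F

# 4-D CPU inputs: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 25, 1024, 64)
k = torch.randn(1, 25, 1024, 64)
v = torch.randn(1, 25, 1024, 64)

# With 4-D inputs and flash attention enabled, the CPU selector can pick
# the fused flash-attention kernel.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Disabling flash attention should make the selector fall back to the
# unfused (math) implementation.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out_unfused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out, out_unfused, atol=1e-5))  # results should match numerically
```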
Performance of the stack
NanoGPT's SDPA kernel
Using the benchmark [repo](https://github.com/mingfeima/bench_sdpa/blob/main/README.md), run with one socket.
Shape: Batch size 1, Sequence length 1024, Head number 25, Head size 64.
Machine: SPR.

| Dtype | Causal | Mode | SDPA | Time (ms per iter) | Speedup |
| -------- | -------- | ------- | ------- | ------- | ------- |
| float32 | FALSE | Inference | Unfused | 3.081 | |
| | | | Flash attention | 1.665 | **1.85045** |
| float32 | TRUE | Inference | Unfused | 3.463 | |
| | | | Flash attention | 1.662 | **2.083634** |
| bfloat16 | FALSE | Inference | Unfused | 1.203 | |
| | | | Flash attention | 1.154 | **1.042461** |
| bfloat16 | TRUE | Inference | Unfused | 1.543 | |
| | | | Flash attention | 1.154 | **1.337088** |
| float32 | FALSE | Training | Unfused | 54.938 | |
| | | | Flash attention | 23.029 | **2.385601** |
| float32 | TRUE | Training | Unfused | 58.266 | |
| | | | Flash attention | 17.835 | **3.266947** |
| bfloat16 | FALSE | Training | Unfused | 18.924 | |
| | | | Flash attention | 18.886 | **1.002012** |
| bfloat16 | TRUE | Training | Unfused | 21.08 | |
| | | | Flash attention | 14.172 | **1.48744** |
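For reference, a rough timing sketch for this shape (an illustrative stand-in, not the harness from the benchmark repo; warm-up and iteration counts are assumptions, and single-socket pinning is left to the environment, e.g. `numactl`):

```python
# Rough per-iteration timing for the NanoGPT SDPA shape above.
# The numbers in the table come from the linked benchmark repo, not this snippet.
import time
import torch
import torch.nn.functional as F

def bench(dtype=torch.float32, is_causal=False, iters=50, warmup=5):
    q, k, v = (torch.randn(1, 25, 1024, 64, dtype=dtype) for _ in range(3))
    for _ in range(warmup):
        F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    start = time.perf_counter()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return (time.perf_counter() - start) / iters * 1e3  # ms per iteration

print(f"float32, causal=False: {bench():.3f} ms/iter")
print(f"bfloat16, causal=True: {bench(dtype=torch.bfloat16, is_causal=True):.3f} ms/iter")
```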
Stable Diffusion
Following the model's [BKM](https://github.com/intel-innersource/frameworks.ai.models.intel-models/blob/develop/quickstart/diffusion/pytorch/stable_diffusion/inference/cpu/README.md).
Mode: Inference; Machine: SPR.

| Dtype | SDPA | Throughput (fps) | Speedup SDPA | Total Time (ms) | Speedup |
| -------- | -------- | ------- | ------- | ------- | ------- |
| float32 | Unfused | 1.63 | | 1139 | |
| | Flash attention | 1.983 | 1.216564 | 547.488 | **2.080411** |
| bfloat16 | Flash attention in IPEX | 4.784 | | 429.051 | |
| | Flash attention | 4.857 | 1.015259 | 408.823 | **1.049479** |
LLM models of Torchbench
Dtype: float32; Mode: Inference, single socket; Machine: CPX.

| Model name | SDPA | Inductor_new | Inductor_old | Inductor Ratio (old/new) |
| -- | -- | -- | -- | -- |
| hf_Albert | Unfused -> Flash attention | 0.048629309 | 0.05591545 | **1.14983024** |
| hf_Bert | Unfused -> Flash attention | 0.053156243 | 0.060732115 | **1.142520841** |
| hf_Bert_large | Unfused -> Flash attention | 0.141089502 | 0.155190077 | **1.099940636** |
| llama | Unfused -> Flash attention | 0.033250106 | 0.033720745 | **1.01415451** |
Dtype: bfloat16; Mode: Inference, single socket; Machine: SPR.

| Model name | SDPA | Inductor_new | Inductor_old | Inductor Ratio (old/new) |
| -- | -- | -- | -- | -- |
| hf_Albert | Unfused -> Flash attention | 0.020681298 | 0.020718282 | **1.001788324** |
| hf_Bert | Unfused -> Flash attention | 0.019932816 | 0.019935424 | **1.000130842** |
| hf_Bert_large | Unfused -> Flash attention | 0.047949174 | 0.048312502 | **1.007577355** |
| llama | Unfused -> Flash attention | 0.018528057 | 0.01861126 | **1.0044907** |
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov