-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[CPU] Support GQA for flash attention #157893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157893
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 289b47c with merge base b146ca7 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
genernally OK, just simplify the test cases a little bit to remove the duplicated code.
@parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) | ||
@parametrize("n_heads", [[65, 5], [16, 4], [27, 1], [5, 1]]) | ||
@parametrize("train", [False, True]) | ||
def test_scaled_dot_product_fused_attention_gqa_vs_math_cpu( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
combine this one with test_scaled_dot_product_fused_attention_mask_vs_math_cpu
to remove duplicated code.
### impls
def test_sdpa_vs_math_cpu_helper(...)
def test_scaled_dot_product_fused_attention_mask_vs_math_cpu()
test_sdpa_vs_math_cpu_helper(...)
def test_scaled_dot_product_fused_attention_gqa_vs_math_cpu()
test_sdpa_vs_math_cpu_helper(...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, UT updated.
b7aa830
to
289b47c
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
As many models require GQA, we support it in flash attention for CPU path. Approved by: https://github.com/mingfeima, https://github.com/jansel [ghstack-poisoned]
Summary: For `scaled_dot_product_attention(..., enable_gqa=True)`: - the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True - the Flash backend has no flag, and relies on correct indexing in the C++ kernel - Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant, with GQA inputs. This assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes. Test Plan: test_export Rollback Plan: Differential Revision: D78524147
Summary: For `scaled_dot_product_attention(..., enable_gqa=True)`: - the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True - the Flash backend has no flag, and relies on correct indexing in the C++ kernel - Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant, with GQA inputs. This assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes. Test Plan: test_export Rollback Plan: Reviewed By: angelayi Differential Revision: D78524147
Differential Revision: D78524147 For `scaled_dot_product_attention(..., enable_gqa=True)`: - the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True - the Flash backend has no flag, and relies on correct indexing in the C++ kernel - Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant, with GQA inputs. This assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes. Pull Request resolved: #158604 Approved by: https://github.com/angelayi, https://github.com/drisspg
Differential Revision: D78524147 For `scaled_dot_product_attention(..., enable_gqa=True)`: - the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True - the Flash backend has no flag, and relies on correct indexing in the C++ kernel - Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant, with GQA inputs. This assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes. Pull Request resolved: #158604 Approved by: https://github.com/angelayi, https://github.com/drisspg
As many models require GQA, we support it in flash attention for CPU path.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168