[Flex attention] Fix flex attention head broadcast #163426
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163426
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit bfe0614 with merge base 3938175.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
Good backport candidate into our latest RC!
Also, I think this might change the shape of the returned tensors, right? That might actually be a good thing, but I would double-check that this doesn't break other BlockMask tests.
Yes, it changes the shape of the getitem on the BlockMask; I updated the other tests.
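For reference, a minimal sketch of the shape change in question, using plain tensors as stand-ins for the BlockMask's kv_num_blocks/kv_indices (the shapes here are illustrative assumptions, not the exact BlockMask internals):

```python
import torch

# Stand-ins for the sparse block metadata a BlockMask carries; the
# [B, H, Q_tiles] / [B, H, Q_tiles, KV_tiles] shapes are illustrative.
B, H, Q_TILES, KV_TILES = 2, 4, 8, 8
kv_num_blocks = torch.randint(1, KV_TILES, (B, H, Q_TILES))
kv_indices = torch.randint(0, KV_TILES, (B, H, Q_TILES, KV_TILES))

# Indexing the Q-tile axis with an int collapses that dimension ...
print(kv_num_blocks[:, :, 3].shape)    # torch.Size([2, 4])
print(kv_indices[:, :, 3].shape)       # torch.Size([2, 4, 8])

# ... while a length-1 slice keeps it, preserving [B, H, Q_tiles, ...].
print(kv_num_blocks[:, :, 3:4].shape)  # torch.Size([2, 4, 1])
print(kv_indices[:, :, 3:4].shape)     # torch.Size([2, 4, 1, 8])
```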
Thank you!
I wonder if I should merge this directly, or add some warning first and then merge the change. It's somewhat BC-breaking, since the shapes returned by the BlockMask have changed. @drisspg wdyt?
FWIW, the main usage of this AFAIK is in gpt-fast: https://github.com/meta-pytorch/gpt-fast/blob/6ecad9b5b6b987d17ac4303965545873d0192086/generate.py#L74
That code uses tensors as an index, so the dim is kept. The slicing operation feels a little weird, tbh; I would prefer that users manually edit the bits and then create a new BlockMask with from_kv_blocks, so we wouldn't run into the problem of the score_mod not being set (see the sketch below).
This seems to be a pretty big footgun for what is essentially syntactic sugar, so IMO it's okay to land this and call it a bug fix.
Can you also update the PR with a description of why this fixes the issue? AFAIK this is essentially a bad interaction where the sliced mask_mod, with one less dim, is interpretable even though it shouldn't be, and so the kernel is reading bogus values.
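A rough sketch of the workflow suggested here: edit the block tensors directly while keeping their full shapes, then rebuild the mask with BlockMask.from_kv_blocks and pass the mask_mod explicitly so it isn't dropped. The shapes and the causal mask_mod below are illustrative assumptions, not code from this PR:

```python
import torch
from torch.nn.attention.flex_attention import BlockMask

B, H, Q_TILES, KV_TILES, BLOCK_SIZE = 1, 4, 4, 4, 128

def causal_mask(b, h, q_idx, kv_idx):
    # Illustrative mask_mod; any callable with this signature works.
    return q_idx >= kv_idx

# Hand-edited block metadata, keeping the full [B, H, Q_tiles, ...] shapes
# rather than collapsing a dimension by slicing with an int.
kv_num_blocks = (
    torch.arange(1, Q_TILES + 1, dtype=torch.int32)
    .view(1, 1, Q_TILES)
    .expand(B, H, Q_TILES)
    .contiguous()
)
kv_indices = (
    torch.arange(KV_TILES, dtype=torch.int32)
    .view(1, 1, 1, KV_TILES)
    .expand(B, H, Q_TILES, KV_TILES)
    .contiguous()
)

# Rebuild a BlockMask from the edited blocks, passing mask_mod explicitly so
# the resulting mask still knows how to mask out partial blocks.
block_mask = BlockMask.from_kv_blocks(
    kv_num_blocks,
    kv_indices,
    BLOCK_SIZE=BLOCK_SIZE,
    mask_mod=causal_mask,
)
print(block_mask.kv_num_blocks.shape)  # torch.Size([1, 4, 4])
```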
Updated the description.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes part of pytorch#163314, in particular **Bug 1: H=None Broadcasting Produces Incorrect Results**.

This fixes a shape bug when slicing a BlockMask on the Q-tile axis with an int (**mask[:, :, i]**). That form of indexing collapses the Q dimension, so kv_num_blocks/kv_indices lose their expected [B, H, Q_tiles, …] shape. Even though the mask_mod remains "interpretable", the kernel's stride math then reads wrong offsets, which produces silent numerical mismatches compared to regular SDPA, especially with single-position decoding and H broadcasting.

The case where B=None, H=None works is accidental: with a singleton batch/head the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1`, and with a single Q tile `q_start // SPARSE_Q_MULTIPLE = 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis doesn't move the pointer, and it happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with wrong strides, causing silent errors.

Pull Request resolved: pytorch#163426
Approved by: https://github.com/drisspg
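To make the accidental-correctness argument concrete, here is a small pure-Python sketch of the offset computation described above; the function and stride names mirror the description and are illustrative assumptions, not the actual Triton kernel code:

```python
# Offset into the sparse block metadata: batch term + head term + Q-tile term,
# mirroring sparse_idx_z, sparse_idx_hq and q_start // SPARSE_Q_MULTIPLE above.
def sparse_offset(off_zq, off_hq, q_start, *, Bs, Hs, SPARSE_Q_MULTIPLE,
                  stride_z, stride_h, stride_q):
    sparse_idx_z = off_zq % Bs       # 0 whenever the mask has a singleton batch
    sparse_idx_hq = off_hq % Hs      # 0 whenever the mask has a singleton head
    q_tile = q_start // SPARSE_Q_MULTIPLE
    return sparse_idx_z * stride_z + sparse_idx_hq * stride_h + q_tile * stride_q

# B=1, H=1, single Q tile: every multiplier is 0, so even badly-computed
# strides (999 here) never move the pointer and the first tile reads correctly.
print(sparse_offset(0, 0, 0, Bs=1, Hs=1, SPARSE_Q_MULTIPLE=4,
                    stride_z=999, stride_h=999, stride_q=999))  # 0

# H > 1: the head term is nonzero, so the wrong strides now produce a wrong
# offset and the kernel silently reads the wrong block metadata.
print(sparse_offset(0, 3, 0, Bs=1, Hs=4, SPARSE_Q_MULTIPLE=4,
                    stride_z=999, stride_h=999, stride_q=999))  # 2997
```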
@pytorchbot cherry-pick --onto release/2.9 --c critical |
Cherry picking #163426: the cherry pick PR is at #164368, and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated. Details for Dev Infra team: raised by workflow job.
[Flex attention] Fix flex attention head broadcast (#163426) (cherry picked from commit 1a42656) Co-authored-by: Isalia20 <irakli.salia854@gmail.com>