[Flex attention] Fix flex attention head broadcast by Isalia20 · Pull Request #163426 · pytorch/pytorch · GitHub

Conversation

@Isalia20
Collaborator

@Isalia20 Isalia20 commented Sep 20, 2025

Fixes part of #163314

In particular, Bug 1: H=None Broadcasting Produces Incorrect Results

This fixes a shape bug when slicing a BlockMask on the Q-tile axis with an int (`mask[:, :, i]`). That form of indexing collapses the Q dimension, so kv_num_blocks/kv_indices lose their expected [B, H, Q_tiles, …] shape. Because the shape is lost, the kernel's stride math reads from the wrong offsets even though the mask_mod remains "interpretable". The result is silent numerical mismatches compared to regular SDPA, especially during single-position decoding or H broadcasting.
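
For illustration, a minimal sketch (not code from this PR) of the indexing behavior in question, assuming the public `create_block_mask` API and the default block size; the exact shapes depend on Q_LEN and BLOCK_SIZE:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# H=None broadcasts over heads, so the block tables have a singleton head dim:
# kv_num_blocks is [B=1, H=1, Q_tiles], kv_indices is [B=1, H=1, Q_tiles, KV_tiles].
bm = create_block_mask(causal, B=1, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")
print(bm.kv_num_blocks.shape)         # e.g. torch.Size([1, 1, 8])

# Before this fix, an int index on the Q-tile axis collapsed that dimension,
# so the kernel's [B, H, Q_tiles, ...] stride math no longer matched the data:
collapsed = bm[:, :, 0]
print(collapsed.kv_num_blocks.shape)  # Q-tile dim dropped (pre-fix behavior)

# Indexing with a tensor keeps the dimension, which is the shape the kernel expects:
kept = bm[:, :, torch.tensor([0])]
print(kept.kv_num_blocks.shape)       # e.g. torch.Size([1, 1, 1])
```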

The fact that the B=None, H=None case works is accidental: with a singleton batch/head, the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1`, and with a single Q tile `q_start // SPARSE_Q_MULTIPLE = 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis doesn't move the pointer and it happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with the wrong strides, which causes silent errors.
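
A purely illustrative bit of arithmetic (not the actual Triton kernel code) spelling out why every term cancels in the singleton case:

```python
# Illustrative only: mirrors the offset terms described above, not the real kernel.
SPARSE_Q_MULTIPLE = 1                  # placeholder value for this example
off_zq, off_hq, q_start = 0, 0, 0      # singleton batch, singleton head, single Q tile

sparse_idx_z = off_zq % 1              # % 1 because the batch dim is a singleton -> 0
sparse_idx_hq = off_hq % 1             # % 1 because the head dim is a singleton  -> 0
q_tile = q_start // SPARSE_Q_MULTIPLE  #                                          -> 0

wrong_stride = 12345                   # stand-in for whatever bad stride the collapsed tensor yields
offset = sparse_idx_z * wrong_stride + sparse_idx_hq * wrong_stride + q_tile * wrong_stride
print(offset)  # 0 -> the bad strides never move the pointer, so tile 0 is read correctly

# With H > 1 the head term (off_hq % H) can be nonzero, and with multiple Q tiles
# q_tile can be nonzero; those terms then multiply the wrong strides and the kernel
# reads from bogus offsets, producing silent numerical mismatches.
```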

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Chillee @drisspg @yanboliang @BoyuanFeng

@pytorch-bot

pytorch-bot bot commented Sep 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163426

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bfe0614 with merge base 3938175:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Sep 20, 2025

To add the ciflow label ciflow/inductor please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@Skylion007
Collaborator

Good backport candidate into our latest RC!

@Skylion007 Skylion007 modified the milestones: 2.10.0, 2.9.0 Sep 20, 2025
@albanD albanD removed their request for review September 22, 2025 15:18
Contributor

@drisspg drisspg left a comment


Also, I think this might change the shape of the returned tensors, right? That actually might be a good thing, but I would double check that this doesn't break other BlockMask tests.

@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Sep 22, 2025
@Isalia20
Collaborator Author

Yes, it changes the shape returned by getitem on BlockMask; I updated the other tests.

Contributor

@drisspg drisspg left a comment


Thank you!

@drisspg drisspg added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 22, 2025
@Isalia20
Collaborator Author

I wonder if I should merge this directly or add some warning first and then merge the change. It's somewhat BC-breaking since the shapes returned by BlockMask's getitem have changed. @drisspg wdyt?

Contributor

@drisspg drisspg left a comment


FWIW the main usage for this AFAIK is in gpt-fast: https://github.com/meta-pytorch/gpt-fast/blob/6ecad9b5b6b987d17ac4303965545873d0192086/generate.py#L74

and this uses a tensor as an index, so we keep the dim. The slicing operation feels a little weird tbh; I would prefer that users manually edit the bits and then create a new BlockMask via from_kv_blocks, and we wouldn't run into the problem of not setting a score_mod.

This seems to be a pretty big footgun for what is essentially syntactic sugar. So IMO I think it's okay to land and call it a bug fix.

Can you also update the PR with a description of why this fixes the issue? AFAIK this looks like essentially a bad interaction where the sliced mask_mod with one less dim is interpretable even though it shouldn't be, and so the kernel is reading bogus values.
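
For reference, a rough sketch of the two approaches mentioned above (the tensor-index path gpt-fast takes, and rebuilding via `from_kv_blocks`); this assumes the public `BlockMask` API, and names like `input_pos` are just for illustration, not code from this PR:

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bm = create_block_mask(causal, B=1, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")

# gpt-fast style: index the Q-tile axis with a tensor, which keeps the dim intact.
input_pos = torch.tensor([3])
decode_bm = bm[:, :, input_pos]

# Alternative suggested above: select/edit the block tables directly and rebuild a
# BlockMask with from_kv_blocks, passing mask_mod explicitly so it isn't dropped.
rebuilt = BlockMask.from_kv_blocks(
    bm.kv_num_blocks[:, :, input_pos],
    bm.kv_indices[:, :, input_pos],
    bm.full_kv_num_blocks[:, :, input_pos] if bm.full_kv_num_blocks is not None else None,
    bm.full_kv_indices[:, :, input_pos] if bm.full_kv_indices is not None else None,
    BLOCK_SIZE=bm.BLOCK_SIZE,
    mask_mod=bm.mask_mod,
)
```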

@Isalia20
Collaborator Author

Updated description

@Isalia20
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
Fixes part of pytorch#163314

Pull Request resolved: pytorch#163426
Approved by: https://github.com/drisspg
@Camyll
Contributor

Camyll commented Oct 1, 2025

@pytorchbot cherry-pick --onto release/2.9 --c critical

pytorchbot pushed a commit that referenced this pull request Oct 1, 2025
Fixes part of #163314

Pull Request resolved: #163426
Approved by: https://github.com/drisspg

(cherry picked from commit 1a42656)
@pytorchbot
Collaborator

Cherry picking #163426

The cherry pick PR is at #164368 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:


Camyll pushed a commit that referenced this pull request Oct 1, 2025
[Flex attention] Fix flex attention head broadcast (#163426)

Fixes part of #163314

Pull Request resolved: #163426
Approved by: https://github.com/drisspg

(cherry picked from commit 1a42656)

Co-authored-by: Isalia20 <irakli.salia854@gmail.com>

Labels

ciflow/inductor, ciflow/trunk, Merged, module: flex attention, module: inductor, open source, release notes: nn, triaged
