KEMBAR78
Flex + NJT: cross attention support by jbschlosser · Pull Request #140723 · pytorch/pytorch · GitHub
Skip to content

Conversation

@jbschlosser
Copy link
Contributor

@jbschlosser jbschlosser commented Nov 14, 2024

Stack from ghstack (oldest at bottom):

Fixes #140598

Allows ragged structures for query and key+value sequence lengths to differ (i.e. supports cross attention for Flex + NJT).

Technically, this is BC-breaking thanks to arg renaming and positional arg reordering in create_nested_block_mask(), but Flex + NJT support isn't in a major release yet so I'm hoping we can just do it.

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140723

Note: Links to docs will display an error until the docs builds have been completed.

❗ 2 Active SEVs

There are 2 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 35d6575 with merge base 2675ef8 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jbschlosser jbschlosser added topic: improvements topic category release notes: nested tensor Changes that have a direct impact on nested tensors labels Nov 14, 2024
@jbschlosser jbschlosser requested review from cpuhrsch, drisspg and soulitzer and removed request for albanD November 14, 2024 16:59
Fixes #140598

Allows ragged structures for query and key+value to differ (i.e. supports cross attention for Flex + NJT).

Technically, this is BC-breaking thanks to positional arg reordering in `create_nested_block_mask()`, but Flex + NJT support isn't in a major release yet so I'm hoping we can just do it.

[ghstack-poisoned]
@jbschlosser jbschlosser added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label Nov 14, 2024
Fixes #140598

Allows ragged structures for query and key+value sequence lengths to differ (i.e. supports cross attention for Flex + NJT).

Technically, this is BC-breaking thanks to arg renaming and positional arg reordering in `create_nested_block_mask()`, but Flex + NJT support isn't in a major release yet so I'm hoping we can just do it.

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Nov 14, 2024
ghstack-source-id: 28240de
Pull Request resolved: #140723
Copy link
Contributor

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, I would turn on ROCM jobs just to be sure there aren't any failures

@jbschlosser jbschlosser added the ciflow/rocm Trigger "default" config CI on ROCm label Nov 18, 2024
@jbschlosser
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 18, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

This PR (#140723) was merged in e80b1b2 but it is still open, likely due to a Github bug, so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra.

c-p-i-o pushed a commit to c-p-i-o/pytorch that referenced this pull request Nov 18, 2024
Fixes pytorch#140598

Allows ragged structures for query and key+value sequence lengths to differ (i.e. supports cross attention for Flex + NJT).

Technically, this is BC-breaking thanks to arg renaming and positional arg reordering in `create_nested_block_mask()`, but Flex + NJT support isn't in a major release yet so I'm hoping we can just do it.
Pull Request resolved: pytorch#140723
Approved by: https://github.com/drisspg
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Fixes pytorch#140598

Allows ragged structures for query and key+value sequence lengths to differ (i.e. supports cross attention for Flex + NJT).

Technically, this is BC-breaking thanks to arg renaming and positional arg reordering in `create_nested_block_mask()`, but Flex + NJT support isn't in a major release yet so I'm hoping we can just do it.
Pull Request resolved: pytorch#140723
Approved by: https://github.com/drisspg
@github-actions github-actions bot deleted the gh/jbschlosser/200/head branch December 19, 2024 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rocm Trigger "default" config CI on ROCm ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: nested tensor Changes that have a direct impact on nested tensors suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) topic: improvements topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants