[Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode #16828
Conversation
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Re: the weirdness with torch compile on MI300x, I followed the suggestion of @robertgshaw2-redhat and re-ran everything inside the latest image.
Nice work! Overall looks pretty good! Left a few comments
S = apply_softcap(S, softcap)

S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
             S, float("-inf"))
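The two lines above apply logit soft-capping and then mask out invalid score entries before the softmax. A minimal NumPy sketch of the same two steps (my own illustration of the technique on plain arrays, not the kernel's Triton tile code; function names mirror the snippet):

```python
import numpy as np

def apply_softcap(scores, softcap):
    # Logit soft-capping: smoothly bounds scores to (-softcap, +softcap).
    return softcap * np.tanh(scores / softcap)

def mask_scores(scores, query_pos, key_pos):
    # Causal mask: a query may only attend to keys at or before its own
    # position; invalid entries become -inf so they contribute zero
    # probability after softmax.
    mask = key_pos[None, :] <= query_pos[:, None]
    return np.where(mask, scores, -np.inf)

q_pos = np.arange(4)
k_pos = np.arange(4)
scores = np.random.randn(4, 4).astype(np.float32)
capped = apply_softcap(scores, softcap=30.0)
masked = mask_scores(capped, q_pos, k_pos)
```

In the real kernel the mask additionally combines per-tile query bounds (`query_mask_0`, `query_mask_1`) with the sequence mask, since tiles can run past the end of a sequence.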
Can we add support for non-causal attention too? Could be a future PR, but it's useful for cascade attention and MLA.
Sure, I don't think that would be too hard.
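Supporting non-causal attention would mostly mean making the causal part of the mask optional, e.g. behind a compile-time flag. A rough sketch of the idea in NumPy (my own illustration, not the PR's code):

```python
import numpy as np

def masked_scores(scores, query_pos, key_pos, causal=True):
    # causal=True: each query sees only keys at or before its position.
    # causal=False (useful for cascade attention / MLA-style workloads):
    # every query can see every valid key, so no positions are masked.
    if causal:
        mask = key_pos[None, :] <= query_pos[:, None]
        return np.where(mask, scores, -np.inf)
    return scores

s = np.ones((3, 3), dtype=np.float32)
pos = np.arange(3)
causal_out = masked_scores(s, pos, pos, causal=True)
full_out = masked_scores(s, pos, pos, causal=False)
```

In a Triton kernel this flag would typically be a `tl.constexpr` so the dead branch is compiled out.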
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

review comments + make unit tests pass
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

fix assert
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Force-pushed from 19c5657 to 13c1c87 (compare)
…d we have questions around cache keys Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Head branch was pushed to by a user without write access
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
* [Model] Add GraniteMoeHybrid 4.0 model (vllm-project#17497) Signed-off-by: Thomas Ortner <boh@zurich.ibm.com> Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com> Co-authored-by: Thomas Ortner <boh@zurich.ibm.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
* [easy] Fix logspam on PiecewiseBackend errors (vllm-project#17138) Signed-off-by: rzou <zou3519@gmail.com>
* [Bugfix] Fixed prompt length for random dataset (vllm-project#17408) Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
* [Doc] Update notes for H2O-VL and Gemma3 (vllm-project#17219) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
* [Misc] Fix ScalarType float4 naming (vllm-project#17690) Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
* Fix `dockerfilegraph` pre-commit hook (vllm-project#17698) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
* [Bugfix] Fix triton import with local TritonPlaceholder (vllm-project#17446) Signed-off-by: Mengqing Cao <cmq0113@163.com>
* [V1] Enable TPU V1 backend by default (vllm-project#17673) Signed-off-by: mgoin <mgoin64@gmail.com>
* [V1][PP] Support PP for MultiprocExecutor (vllm-project#14219) Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: jiang.li <jiang1.li@intel.com>
* [v1] AttentionMetadata for each layer (vllm-project#17394) Signed-off-by: Chen Zhang <zhangch99@outlook.com>
* [Feat] Add deprecated=True to CLI args (vllm-project#17426) Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
* [Docs] Use gh-file to add links to tool_calling.md (vllm-project#17709) Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
* [v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (vllm-project#17479) Signed-off-by: Chen Zhang <zhangch99@outlook.com>
* [doc] Add RAG Integration example (vllm-project#17692) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com>
* [Bugfix] Fix modality limits in vision language example (vllm-project#17721) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
* Make right sidebar more readable in "Supported Models" (vllm-project#17723) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
* [TPU] Increase block size and reset block shapes (vllm-project#16458)
* [Misc] Add Next Edit Prediction (NEP) datasets support in `benchmark_serving.py` (vllm-project#16839) Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal> Signed-off-by: dtransposed <> Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
* [Bugfix] Fix for the condition to accept empty encoder inputs for mllama (vllm-project#17732) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
* [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (vllm-project#16828) Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

---------

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
Signed-off-by: reidliu41 <reid201711@gmail.com>
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: Stan Wozniak <77159600+s3woz@users.noreply.github.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Richard Zou <zou3519@users.noreply.github.com>
Co-authored-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Michael Yao <haifeng.yao@daocloud.io>
Co-authored-by: Reid <61492567+reidliu41@users.noreply.github.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: Jevin Jiang <jevin0change@gmail.com>
Co-authored-by: d.transposed <damian.bogunowicz@gmail.com>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
…ll + decode (vllm-project#16828) Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
…ll + decode (vllm-project#16828) Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
…ll + decode (vllm-project#16828) Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
In this PR we add a unified Triton attention kernel (`triton_unified_attention`) that works like `flash_attn_varlen_func` and can handle arbitrary query lengths, so a single kernel serves both prefill and decode. The kernel does GQA "packing" along the query dimension to ensure the Tensor cores are well utilized.

Best performance is obtained when using the jit cache decorator from `triton_dejavu`. In this code I'm using the decorator from the `triton_dejavu` package, but if #16606 is merged we could use it directly from vLLM.

Note that the unit tests don't currently pass when I enable the jit cache, but they all pass if it is disabled. This is because we test different combinations of numbers of heads etc., which the decorator assumes to be constant. We probably need to think of a good testing strategy for kernels with this decorator (cc @bringlein).
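The "GQA packing" idea: with grouped-query attention, several query heads share one KV head, so the query heads of a group can be folded into the row (M) dimension of each tile, giving `tl.dot` larger matmuls and better tensor-core utilization. A NumPy illustration of the equivalence (shapes are hypothetical, not the kernel's actual tiling):

```python
import numpy as np

num_tokens, num_kv_heads, group, d = 4, 2, 4, 8  # group = q heads per kv head
q = np.random.randn(num_tokens, num_kv_heads, group, d).astype(np.float32)
k = np.random.randn(num_tokens, num_kv_heads, d).astype(np.float32)

# Naive: one small (num_tokens x d) matmul per query head.
naive = np.stack(
    [q[:, h, g] @ k[:, h].T for h in range(num_kv_heads) for g in range(group)]
)

# Packed: fold the query-head group into the row dimension so each KV head
# does a single larger (num_tokens*group x d) @ (d x num_tokens) matmul.
packed = np.stack([
    q[:, h].reshape(num_tokens * group, d) @ k[:, h].T
    for h in range(num_kv_heads)
])
```

Row `t*group + g` of the packed result for KV head `h` is exactly the score row of query head `(h, g)` at token `t` in the naive version; the kernel just computes it with far fewer, larger matmuls.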
Initial benchmarking

Here are some initial benchmarking results on H100 for `llama3.1-8b`. Note that with these changes, the Triton backend significantly outperforms the FlashAttention backend on an H100 GPU for this workload.
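For reproducing such comparisons, vLLM selects its attention implementation via the `VLLM_ATTENTION_BACKEND` environment variable. The backend name below is an assumption and may differ across vLLM versions; check the values supported by your install:

```python
import os

# Assumed backend name for the V1 Triton attention backend; verify
# against your vLLM version before relying on it.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
```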
Correctness

@bringlein the correctness check only looks good if I use the `tpa-grid-copy` branch from triton-dejavu; if I use `main` it fails. You should be able to reproduce this on H100.

Further benchmarking
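Correctness checks for kernels like this typically compare the kernel output against a naive reference implementation. A minimal sketch of such a reference for a single causal sequence (my own illustration, not the PR's actual test code):

```python
import numpy as np

def ref_attention(q, k, v, scale):
    # Naive causal attention for one sequence: softmax(Q K^T * scale) V.
    s = (q @ k.T) * scale
    n = s.shape[0]
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))  # stable softmax
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

q = np.random.randn(4, 8).astype(np.float32)
k = np.random.randn(4, 8).astype(np.float32)
v = np.random.randn(4, 8).astype(np.float32)
out = ref_attention(q, k, v, scale=8 ** -0.5)
```

A varlen test would run this per sequence (as delimited by `cu_seqlens`) and compare against the kernel output with a small tolerance.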
I've run the following scenario across both H100 and MI300x using different backends from main, as well as using the changes from this PR:
Here are the results:

Main takeaways:
cc @robertgshaw2-redhat @tlrmchlsmth @SageMoore @bringlein @jvlunteren @LucasWilkinson