[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. #14431
Conversation
Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Please publish accuracy test results as well.
    skip_decode=True,
)

block_size = value_cache.shape[3]
Just to understand it right: should there be a return after the call to `context_attention_fwd`? Otherwise, for `max_query_len > 1`, you are calling two kernels that might compute the same thing.
Right, we added this option to `context_attention_fwd` on main which, if enabled, will skip the sequences in the batch with query_length=1. We then launch another kernel concurrently to handle the ones that were skipped, so the two kernels never compute the same thing.
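To make that concrete, here is a minimal, self-contained sketch of the idea (illustrative only, not the vLLM kernels; `split_prefill_and_decode` is a made-up helper): the prefill kernel is told to skip decode sequences, and a separate decode kernel handles exactly those, so the two launches cover disjoint subsets of the batch.

```python
# Illustrative only: not the vLLM implementation. This just shows why the
# prefill and decode kernels never duplicate work when skip_decode=True is
# used -- they operate on disjoint subsets of the batch.

def split_prefill_and_decode(query_lens):
    """Partition batch indices the way skip_decode=True effectively does:
    the prefill kernel covers query_length > 1, the decode kernel covers
    query_length == 1."""
    prefill_idx = [i for i, q in enumerate(query_lens) if q > 1]
    decode_idx = [i for i, q in enumerate(query_lens) if q == 1]
    return prefill_idx, decode_idx


query_lens = [5, 1, 1, 7, 1]  # a mixed batch of prefill and decode requests
prefill_idx, decode_idx = split_prefill_and_decode(query_lens)

# Disjoint by construction, so launching one kernel per subset (even
# concurrently) computes nothing twice, and no early return is needed.
assert set(prefill_idx).isdisjoint(decode_idx)
print("prefill sequences:", prefill_idx)  # [0, 3]
print("decode sequences:", decode_idx)    # [1, 2, 4]
```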
Accuracy results using V1: [results tables omitted]
cc @maleksan85
Looks great. Thanks for the contribution!
Here are the lm_eval results from an MI300X machine. Results look good.
@tdoublep you have probably already seen this, but if not: https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#triton-kernel-performance-optimization
@tdoublep do you understand the increase in mean TTFT versus main's
One thing I recommend for these types of performance comparisons is adding
@tlrmchlsmth hmm good catch, I hadn't noticed that. Will have another look.
Makes sense, will re-run with that enabled.
@tlrmchlsmth I've re-run everything using the benchmark command: [command omitted]
[ROCmAttentionBackend benchmark results for the commits tested omitted]
I think it looks OK.
@tdoublep thanks for rerunning those tests! Changes look good and the performance optimization makes sense.
@tlrmchlsmth The multi-modal test that is failing does not look related to these changes.
TLDR: This PR adds some further optimizations to the `chunked_prefill_paged_decode` op to better handle models with GQA. Serving benchmarks using V1 indicate that with these changes, we see a 25% improvement in throughput for `llama3.1-8b` on an H100 vs. the current Triton implementation. With these changes, the throughput of the Triton implementation is only 8% worse than the V1 CUDA backend (FlashAttention).

Using `FlashAttentionBackend` from main on H100: [benchmark results omitted]

Using `ROCmAttentionBackend` from main on H100: [benchmark results omitted]

And finally, using `ROCmAttentionBackend` from this PR on H100: [benchmark results omitted]

cc @SageMoore @maleksan85
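For context, here is a small, self-contained PyTorch reference of grouped-query attention at decode time (a sketch of the general GQA structure only, not the Triton kernel in this PR; shapes, names, and the head configuration are illustrative): the property a GQA-aware kernel can exploit is that several query heads share one KV head, so each K/V block needs to be read once per group rather than once per query head.

```python
# A self-contained PyTorch reference of grouped-query attention (GQA) for a
# single decode token. Illustrative only -- this is not the Triton kernel in
# this PR, just the structure a GQA-aware kernel can take advantage of.
import torch


def gqa_decode_reference(q: torch.Tensor,
                         k_cache: torch.Tensor,
                         v_cache: torch.Tensor) -> torch.Tensor:
    """q: [num_q_heads, head_dim] for one decode token.
    k_cache, v_cache: [num_kv_heads, seq_len, head_dim]."""
    num_q_heads, head_dim = q.shape
    num_kv_heads = k_cache.shape[0]
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    scale = head_dim ** -0.5

    out = torch.empty_like(q)
    for kv_head in range(num_kv_heads):
        k = k_cache[kv_head]   # read K/V once per KV head...
        v = v_cache[kv_head]
        for g in range(group_size):  # ...and reuse them for the whole group
            q_head = kv_head * group_size + g
            scores = (q[q_head] @ k.T) * scale      # [seq_len]
            probs = torch.softmax(scores, dim=-1)   # [seq_len]
            out[q_head] = probs @ v                 # [head_dim]
    return out


# Example with a llama3.1-8b-like head layout: 32 query heads sharing 8 KV
# heads (group size 4), head_dim 128, and a context of 256 cached tokens.
q = torch.randn(32, 128)
k_cache = torch.randn(8, 256, 128)
v_cache = torch.randn(8, 256, 128)
print(gqa_decode_reference(q, k_cache, v_cache).shape)  # torch.Size([32, 128])
```

In an actual kernel the same grouping typically maps to processing a whole group of query heads per program so that K/V blocks are fetched from the paged cache once per group; the reference above only shows the arithmetic, not the memory layout.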