[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. by tdoublep · Pull Request #14431 · vllm-project/vllm

Conversation


@tdoublep tdoublep commented Mar 7, 2025

TLDR: This PR adds further optimizations to the chunked_prefill_paged_decode op to better handle models with GQA. Serving benchmarks using V1 indicate that, with these changes, we see a 25% improvement in throughput for llama3.1-8b on an H100 vs. the current Triton implementation. With these changes, the throughput of the Triton implementation is only 8% worse than the V1 CUDA backend (FlashAttention).
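
For context, grouped-query attention (GQA) lets several query heads share a single KV head, so a decode kernel can load each KV block once and reuse it across a whole group of query heads. Below is a minimal PyTorch sketch of that head mapping; the shapes are illustrative (chosen to be llama3.1-8b-like), and this is not the Triton kernel itself:

import torch

# Illustrative GQA shapes (llama3.1-8b-like); not the actual kernel code.
num_query_heads, num_kv_heads, head_dim = 32, 8, 128
group_size = num_query_heads // num_kv_heads  # 4 query heads per KV head

q = torch.randn(num_query_heads, head_dim)     # query for one decode token
k = torch.randn(num_kv_heads, 1024, head_dim)  # cached keys (seq_len = 1024)
v = torch.randn(num_kv_heads, 1024, head_dim)  # cached values

# Each KV head is loaded once and shared by its whole group of query heads,
# so a GQA-aware kernel can batch the group into one matmul per KV head.
q_grouped = q.view(num_kv_heads, group_size, head_dim)
scores = torch.einsum("ghd,gsd->ghs", q_grouped, k) / head_dim ** 0.5
out = torch.einsum("ghs,gsd->ghd", scores.softmax(dim=-1), v)
out = out.reshape(num_query_heads, head_dim)   # back to per-query-head layout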

Using FlashAttentionBackend from main on H100:

$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  20.77     
Total input tokens:                      215196    
Total generated tokens:                  198001    
Request throughput (req/s):              48.15     
Output token throughput (tok/s):         9534.13   
Total Token throughput (tok/s):          19896.23  
---------------Time to First Token----------------
Mean TTFT (ms):                          3563.95   
Median TTFT (ms):                        3386.57   
P99 TTFT (ms):                           6444.51   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          84.11     
Median TPOT (ms):                        48.47     
P99 TPOT (ms):                           213.98    
---------------Inter-token Latency----------------
Mean ITL (ms):                           37.38     
Median ITL (ms):                         23.74     
P99 ITL (ms):                            216.44    
==================================================

Using ROCmAttentionBackend from main on H100:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  28.22     
Total input tokens:                      215196    
Total generated tokens:                  197281    
Request throughput (req/s):              35.44     
Output token throughput (tok/s):         6991.12   
Total Token throughput (tok/s):          14617.11  
---------------Time to First Token----------------
Mean TTFT (ms):                          3483.50   
Median TTFT (ms):                        3607.34   
P99 TTFT (ms):                           6616.79   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          99.14     
Median TPOT (ms):                        63.35     
P99 TPOT (ms):                           246.93    
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.02     
Median ITL (ms):                         40.98     
P99 ITL (ms):                            253.97    
==================================================

And finally, using ROCmAttentionBackend from this PR on H100:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  22.63     
Total input tokens:                      215196    
Total generated tokens:                  197122    
Request throughput (req/s):              44.19     
Output token throughput (tok/s):         8711.51   
Total Token throughput (tok/s):          18221.76  
---------------Time to First Token----------------
Mean TTFT (ms):                          3947.90   
Median TTFT (ms):                        3809.46   
P99 TTFT (ms):                           7080.03   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          89.17     
Median TPOT (ms):                        50.22     
P99 TPOT (ms):                           230.38    
---------------Inter-token Latency----------------
Mean ITL (ms):                           39.51     
Median ITL (ms):                         24.95     
P99 ITL (ms):                            233.80    
==================================================

cc @SageMoore @maleksan85

Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

github-actions bot commented Mar 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@tdoublep tdoublep changed the title Further optimizations to 2D kernel to better handle GQA. [Kernel] [V1] Further optimizations to 2D kernel to better handle GQA. Mar 7, 2025
@tdoublep tdoublep changed the title [Kernel] [V1] Further optimizations to 2D kernel to better handle GQA. [Kernel] [V1] Further optimizations to ROCmAttentionBackend to better handle GQA. Mar 7, 2025
@tdoublep tdoublep changed the title [Kernel] [V1] Further optimizations to ROCmAttentionBackend to better handle GQA. [Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. Mar 7, 2025
@maleksan85 (Contributor)

Please publish accuracy test results as well.

maleksan85 also left an inline review comment on this snippet (the end of the context_attention_fwd call, followed by the next statement):

        skip_decode=True,
    )

    block_size = value_cache.shape[3]
Just to understand it right: should there be a return after the call to context_attention_fwd? Otherwise, for max_query_len > 1, you are calling two kernels that might compute the same thing.

@tdoublep (Member, Author) replied:

Right. We added this option to context_attention_fwd on main: when enabled, it skips the sequences in the batch with query_length == 1. We then launch another kernel concurrently to handle the ones that were skipped, so no return is needed and nothing is computed twice.
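
A simplified, self-contained sketch of that dispatch follows. Only context_attention_fwd and its skip_decode flag come from the actual op; decode_only_kernel and the list-based batch representation are hypothetical stand-ins for illustration:

def context_attention_fwd(batch, skip_decode=False):
    # Prefill kernel: with skip_decode=True it ignores the pure-decode
    # (query_length == 1) sequences in the batch.
    return [s for s in batch if not (skip_decode and s["query_len"] == 1)]

def decode_only_kernel(batch):
    # Hypothetical stand-in: handles exactly the sequences skipped above.
    return [s for s in batch if s["query_len"] == 1]

def chunked_prefill_paged_decode(batch, max_query_len):
    if max_query_len > 1:
        context_attention_fwd(batch, skip_decode=True)
    # Launched unconditionally, with no early return needed: the two
    # kernels partition the batch rather than duplicating work.
    decode_only_kernel(batch)

So even when max_query_len > 1, each sequence is processed by exactly one of the two kernels.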


tdoublep commented Mar 7, 2025

Accuracy results

Using V1 FlashAttentionBackend on main:

$ VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=/models/llama3.1-8b/instruct/ --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.792|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.770|±  |0.0188|

Using V1 ROCmAttentionBackend from this branch:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.788|±  |0.0183|
|     |       |strict-match    |     5|exact_match|↑  |0.768|±  |0.0189|

cc @maleksan85

@SageMoore SageMoore left a comment

Looks great. Thanks for the contribution!

@SageMoore (Contributor)

Here are the lm_eval results from an MI300X machine. Results look good.

vllm (pretrained=meta-llama/meta-llama-3.1-8b-instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.796|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.778|±  |0.0186|


@tlrmchlsmth (Member)

@tdoublep do you understand the increase in mean TTFT versus main's ROCmAttentionBackend?

@tlrmchlsmth (Member)

One thing I recommend for these types of performance comparisons is adding --ignore-eos, as the different number of generated output tokens in these three cases can make performance difficult to reason about.

@tdoublep (Member, Author)

do you understand the increase in mean TTFT versus main's ROCmAttentionBackend?

@tlrmchlsmth hmm good catch, I hadn't noticed that. Will have another look.

One thing I recommend for these types of performance comparisons is adding --ignore-eos

Makes sense, will re-run with that enabled.

@tdoublep (Member, Author)

@tlrmchlsmth I've re-run everything using --ignore-eos with multiple repetitions. Note that I stop and restart the server between repetitions, since otherwise APC (automatic prefix caching) changes the results. I can't reproduce the increase in TTFT vs. main. I guess it was either related to --ignore-eos somehow (although I can't really see why this would affect TTFT, unless there are some weird scheduler dynamics), or a blip due to something else running on my system.

benchmark command:

python benchmarks/benchmark_serving.py \
    --model /models/llama3.1-8b/instruct/ \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --ignore-eos

ROCmAttentionBackend @ main [repetition 1]

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  28.55     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              35.03     
Output token throughput (tok/s):         6947.10   
Total Token throughput (tok/s):          14484.48  
---------------Time to First Token----------------
Mean TTFT (ms):                          3656.09   
Median TTFT (ms):                        3457.98   
P99 TTFT (ms):                           6960.72   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          101.13    
Median TPOT (ms):                        62.61     
P99 TPOT (ms):                           249.56    
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.34     
Median ITL (ms):                         40.91     
P99 ITL (ms):                            256.30    
==================================================

ROCmAttentionBackend @ main [repetition 2]

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  28.61     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              34.96     
Output token throughput (tok/s):         6933.30   
Total Token throughput (tok/s):          14455.72  
---------------Time to First Token----------------
Mean TTFT (ms):                          3578.91   
Median TTFT (ms):                        3357.48   
P99 TTFT (ms):                           7000.44   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          102.12    
Median TPOT (ms):                        63.15     
P99 TPOT (ms):                           252.93    
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.75     
Median ITL (ms):                         41.04     
P99 ITL (ms):                            259.23    
==================================================

ROCmAttentionBackend @ main [repetition 3]

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  28.38     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              35.24     
Output token throughput (tok/s):         6988.73   
Total Token throughput (tok/s):          14571.29  
---------------Time to First Token----------------
Mean TTFT (ms):                          3506.50   
Median TTFT (ms):                        3352.46   
P99 TTFT (ms):                           6795.89   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          99.63     
Median TPOT (ms):                        62.60     
P99 TPOT (ms):                           243.01    
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.15     
Median ITL (ms):                         41.20     
P99 ITL (ms):                            249.40    
==================================================

ROCmAttentionBackend @ tpa-gqa-opt [repetition 1]

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  22.36     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              44.72     
Output token throughput (tok/s):         8869.87   
Total Token throughput (tok/s):          18493.40  
---------------Time to First Token----------------
Mean TTFT (ms):                          3506.37   
Median TTFT (ms):                        3338.07   
P99 TTFT (ms):                           6751.05   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          93.56     
Median TPOT (ms):                        55.73     
P99 TPOT (ms):                           237.69    
---------------Inter-token Latency----------------
Mean ITL (ms):                           41.86     
Median ITL (ms):                         23.26     
P99 ITL (ms):                            243.69    
==================================================

ROCmAttentionBackend @ tpa-gqa-opt [repetition 2]

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  22.29     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              44.86     
Output token throughput (tok/s):         8897.57   
Total Token throughput (tok/s):          18551.15  
---------------Time to First Token----------------
Mean TTFT (ms):                          3491.51   
Median TTFT (ms):                        3289.22   
P99 TTFT (ms):                           6723.10   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          91.50     
Median TPOT (ms):                        51.03     
P99 TPOT (ms):                           238.13    
---------------Inter-token Latency----------------
Mean ITL (ms):                           39.98     
Median ITL (ms):                         24.54     
P99 ITL (ms):                            243.43    
==================================================

ROCmAttentionBackend @ tpa-gqa-opt [repetition 3]

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  22.33     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              44.79     
Output token throughput (tok/s):         8883.44   
Total Token throughput (tok/s):          18521.70  
---------------Time to First Token----------------
Mean TTFT (ms):                          3536.76   
Median TTFT (ms):                        3383.08   
P99 TTFT (ms):                           6583.68   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          91.64     
Median TPOT (ms):                        52.71     
P99 TPOT (ms):                           233.16    
---------------Inter-token Latency----------------
Mean ITL (ms):                           40.30     
Median ITL (ms):                         24.87     
P99 ITL (ms):                            238.99    
==================================================

I think it looks OK.

@tlrmchlsmth tlrmchlsmth left a comment

@tdoublep thanks for rerunning those tests!

Changes look good and the performance optimization makes sense.

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 13, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 13, 2025 20:27
@tdoublep (Member, Author)

@tlrmchlsmth The multi-modal test that is failing does not look related to these changes.

@vllm-bot vllm-bot merged commit fb4c7f8 into vllm-project:main Mar 14, 2025
54 of 56 checks passed
@tdoublep tdoublep deleted the tpa-gqa-opt branch March 14, 2025 10:43
richardsliu pushed a commit to richardsliu/vllm that referenced this pull request Mar 14, 2025
[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. (vllm-project#14431)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Signed-off-by: Richard Liu <ricliu@google.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. (vllm-project#14431)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. (vllm-project#14431)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. (vllm-project#14431)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>