KEMBAR78
[None][feat] support attention dp for qwen3 dense model by Nekofish-L · Pull Request #7618 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@Nekofish-L
Copy link
Contributor

@Nekofish-L Nekofish-L commented Sep 8, 2025

Summary by CodeRabbit

  • New Features
    • Enhanced distributed training for Qwen3 with optional attention data parallelism.
    • Per-layer control of cross-GPU synchronization to reduce unnecessary communication.
    • Improved scalability by propagating cross-rank token metadata through the MLP path.
    • Better latency characteristics via updated execution settings for mixed attention/MLP workloads.

Description

This PR introduces modifications to the Qwen3 Dense model to efficiently support dp attention, enabling near-linear performance scaling as more GPUs are added within a dp attention setup.

Decoding performance(TPOT/ms)

  • model: Qwen3-32B-FP8
  • device: H20
batch_size_per_gpu 64
TP1 32.53
TP8 43.30
TP8+attention dp 32.42

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
@Nekofish-L Nekofish-L requested review from a team as code owners September 8, 2025 12:02
@Nekofish-L Nekofish-L requested review from 2ez4bz and byshiue September 8, 2025 12:02
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 8, 2025

📝 Walkthrough

Walkthrough

Introduces per-layer all-reduce control in Qwen3DecoderLayer by adding AllReduceParams and attention data-parallel flags. Forward now threads all-reduce enablement into self-attention and MLP, passes distributed token-count metadata to MLP, and adjusts GatedMLP tensor-parallel override based on attention DP.

Changes

Cohort / File(s) Change Summary
Qwen3DecoderLayer distributed all-reduce threading
tensorrt_llm/_torch/models/modeling_qwen3.py
Import AllReduceParams; derive mapping and enable_attention_dp from model_config; set disable_allreduce when tp_size==1 or attention DP enabled; pass all_reduce_params to self-attention; pass all_rank_num_tokens/all_rank_max_num_tokens and final_all_reduce_params plus cutlass_min_latency_mode to MLP; set GatedMLP.overridden_tp_size=1 when attention DP enabled.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant C as Caller
  participant L as Qwen3DecoderLayer
  participant A as SelfAttention
  participant M as MLP

  Note over L: Compute enable_attention_dp, mapping, disable_allreduce
  C->>L: forward(x, attn_metadata, ...)
  L->>A: forward(..., all_reduce_params=AllReduceParams(enable=!disable_allreduce))
  A-->>L: attn_out
  Note over L: Extract all_rank_num_tokens, all_rank_max_num_tokens
  L->>M: forward(...,<br/>all_rank_num_tokens,<br/>all_rank_max_num_tokens,<br/>final_all_reduce_params=AllReduceParams(enable=!disable_allreduce),<br/>cutlass_min_latency_mode=false)
  M-->>L: mlp_out
  L-->>C: output
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

Community want to contribute

Suggested reviewers

  • byshiue
  • nv-yilinf
  • mikeiovine
  • yizhang-nv
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_qwen3.py (1)

1-1: Add NVIDIA Apache-2.0 copyright header (2025).

Per repo guidelines, prepend the standard NVIDIA Apache-2.0 header to this Python file.

Here’s a header you can paste at the very top (outside diffs since it’s outside the changed hunk):

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
🧹 Nitpick comments (4)
tensorrt_llm/_torch/models/modeling_qwen3.py (4)

84-86: Guard mapping attribute for backward compatibility.

Older ModelConfig.mapping may not have enable_attention_dp. Use getattr with a default.

-        self.mapping = model_config.mapping
-        self.enable_attention_dp = self.mapping.enable_attention_dp
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = getattr(self.mapping, "enable_attention_dp", False)

137-144: Prune unsupported/unused kwargs in MLP call.

  • GatedMLP.forward doesn’t take all_rank_max_num_tokens (ignored via **kwargs).
  • cutlass_min_latency_mode is not consumed by GatedMLP.forward or forwarded to down_proj in current code.
         self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(
             hidden_states,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
-            final_all_reduce_params=AllReduceParams(
-                enable_allreduce=not self.disable_allreduce),
-            cutlass_min_latency_mode=False,
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_mlp_allreduce),
         )

Please confirm if any downstream expects these args; if so, wire them through GatedMLP.forward and into down_proj explicitly.


124-145: Numerical-correctness and comms plan for attention-DP need a brief docstring.

Given the changes alter per-layer reduction behavior and MLP TP, add a short class-level or method docstring explaining:

  • When attention-DP is enabled, which modules change TP/AR behavior and why.
  • What reductions are skipped vs moved, and any constraints on mapping.tp_size.

I can draft a concise docstring if helpful.


168-172: Sanity tests for attention-DP path.

Please add unit/integration tests that:

  • Validate numerics parity vs baseline at tp_size=1 with attention-DP on/off.
  • Check shapes and absence of partial outputs when tp_size>1 and attention-DP enabled.
  • Exercise batched decode to cover all_rank_num_tokens propagation.

I can provide test scaffolding for Torch path with randomized seeds and small configs.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4a1e138 and 37685a7.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_qwen3.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/models/modeling_qwen3.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/models/modeling_qwen3.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/models/modeling_qwen3.py
🧠 Learnings (2)
📓 Common learnings
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.
📚 Learning: 2025-08-14T06:36:40.701Z
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_qwen3.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/models/modeling_qwen3.py (3)
tensorrt_llm/functional.py (1)
  • AllReduceParams (3900-3939)
tensorrt_llm/_torch/modules/gated_mlp.py (1)
  • GatedMLP (19-175)
tensorrt_llm/_torch/attention_backend/interface.py (2)
  • all_rank_num_tokens (171-172)
  • all_rank_num_tokens (175-178)
🔇 Additional comments (1)
tensorrt_llm/_torch/models/modeling_qwen3.py (1)

92-92: Overriding MLP TP to 1 is a big behavior change; verify perf/memory intent.

Setting overridden_tp_size=1 disables TP inside GatedMLP. This can degrade throughput and increase per-rank memory. If the sole goal is to avoid cross-rank all-reduce at MLP output, keep TP and control reduction via final_all_reduce_params or smaller TP groups rather than forcing TP=1.

Would you share profiling showing TP8 vs TP1 MLP under attention-DP for Qwen3 Dense? If the override is required, document why.

@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Sep 8, 2025
@byshiue
Copy link
Collaborator

byshiue commented Sep 10, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18269 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18269 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #13693 completed with status: 'FAILURE'

@byshiue
Copy link
Collaborator

byshiue commented Sep 10, 2025

Can you run pre-commit run -a to fix the code format issue? The CI error is

[2025-09-10T00:58:16.726Z] isort....................................................................Failed

[2025-09-10T00:58:16.726Z] - hook id: isort

[2025-09-10T00:58:16.726Z] - files were modified by this hook

[2025-09-10T00:58:16.726Z] 

[2025-09-10T00:58:16.726Z] Fixing /home/jenkins/agent/workspace/LLM/main/L0_MergeRequest_PR/llm/tensorrt_llm/_torch/models/modeling_qwen3.py

[2025-09-10T00:58:16.726Z] Skipped 178 files

[2025-09-10T00:58:16.726Z] 

[2025-09-10T00:58:16.726Z] CRLF end-lines remover...................................................Passed

[2025-09-10T00:58:16.726Z] yapf.....................................................................Failed

[2025-09-10T00:58:16.726Z] - hook id: yapf

[2025-09-10T00:58:16.726Z] - files were modified by this hook

[2025-09-10T00:58:16.726Z] check for added large files..............................................Passed

[2025-09-10T00:58:16.726Z] check for merge conflicts................................................Passed

[2025-09-10T00:58:16.726Z] check for broken symlinks............................(no files to check)Skipped

[2025-09-10T00:58:16.726Z] detect private key.......................................................Passed

[2025-09-10T00:58:16.726Z] fix end of files.........................................................Passed

[2025-09-10T00:58:16.726Z] check yaml...............................................................Passed

[2025-09-10T00:58:16.726Z] trim trailing whitespace.................................................Passed

[2025-09-10T00:58:16.726Z] check toml...............................................................Passed

[2025-09-10T00:58:16.726Z] mixed line ending........................................................Passed

[2025-09-10T00:58:16.726Z] debug statements (python)................................................Passed

[2025-09-10T00:58:16.726Z] check json...........................................(no files to check)Skipped

[2025-09-10T00:58:16.726Z] autoflake................................................................Passed

[2025-09-10T00:58:16.726Z] clang-format.............................................................Passed

[2025-09-10T00:58:16.726Z] cmake-format.............................................................Passed

[2025-09-10T00:58:16.726Z] codespell................................................................Passed

[2025-09-10T00:58:16.726Z] ruff.....................................................................Passed

[2025-09-10T00:58:16.726Z] ruff-format..............................................................Passed

[2025-09-10T00:58:16.726Z] mdformat.................................................................Passed

[2025-09-10T00:58:16.726Z] pre-commit hook(s) made changes.

[2025-09-10T00:58:16.726Z] If you are seeing this message in CI, reproduce locally with: `pre-commit run --all-files`.

[2025-09-10T00:58:16.726Z] To run `pre-commit` as part of git workflow, use `pre-commit install`.

[2025-09-10T00:58:16.726Z] All changes made by hooks:

[2025-09-10T00:58:16.726Z] diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py

[2025-09-10T00:58:16.726Z] index b95f3fc..84bc2f8 100644

[2025-09-10T00:58:16.726Z] --- a/tensorrt_llm/_torch/models/modeling_qwen3.py

[2025-09-10T00:58:16.726Z] +++ b/tensorrt_llm/_torch/models/modeling_qwen3.py

[2025-09-10T00:58:16.726Z] @@ -7,8 +7,8 @@ from transformers import Qwen3Config

[2025-09-10T00:58:16.726Z]  from tensorrt_llm.functional import PositionEmbeddingType

[2025-09-10T00:58:16.726Z]  

[2025-09-10T00:58:16.726Z]  from ..attention_backend import AttentionMetadata

[2025-09-10T00:58:16.726Z] -from ..distributed import AllReduceParams

[2025-09-10T00:58:16.726Z]  from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams

[2025-09-10T00:58:16.726Z] +from ..distributed import AllReduceParams

[2025-09-10T00:58:16.726Z]  from ..model_config import ModelConfig

[2025-09-10T00:58:16.726Z]  from ..modules.decoder_layer import DecoderLayer

[2025-09-10T00:58:16.726Z]  from ..modules.embedding import Embedding

[2025-09-10T00:58:16.726Z] @@ -102,7 +102,6 @@ class Qwen3DecoderLayer(DecoderLayer):

[2025-09-10T00:58:16.726Z]          self.disable_allreduce = (self.mapping.tp_size == 1

[2025-09-10T00:58:16.726Z]                                    or self.enable_attention_dp)

[2025-09-10T00:58:16.726Z]  

[2025-09-10T00:58:16.726Z] -

[2025-09-10T00:58:16.726Z]      def forward(

[2025-09-10T00:58:16.726Z]          self,

[2025-09-10T00:58:16.726Z]          position_ids: torch.IntTensor,

[2025-09-10T00:58:16.726Z] 

[2025-09-10T00:58:16.726Z] 

[2025-09-10T00:58:16.726Z] Error: pre-commit checks failed

[2025-09-10T00:58:16.726Z] Please refer to our coding style guidelines at: https://github.com/NVIDIA/TensorRT-LLM/blob/main/CONTRIBUTING.md#coding-style to fix this issue

[2025-09-10T00:58:16.726Z] + git restore .

[2025-09-10T00:58:16.726Z] + false

Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
@Nekofish-L
Copy link
Contributor Author

Hi @byshiue , Thank you for the review!
I have successfully executed pre-commit run -a locally, and all tests have passed.
image

@byshiue
Copy link
Collaborator

byshiue commented Sep 15, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18536 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18536 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13913 completed with status: 'SUCCESS'

@byshiue byshiue enabled auto-merge (squash) September 16, 2025 01:33
@byshiue byshiue merged commit 96f11b1 into NVIDIA:main Sep 16, 2025
5 checks passed
byshiue added a commit that referenced this pull request Sep 16, 2025
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Community want to contribute PRs initiated from Community

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants