KEMBAR78
[TRTLLM-6342][bug] Fix shape propagation after TP sharding by greg-kwasniewski1 · Pull Request #7912 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@greg-kwasniewski1
Copy link
Collaborator

@greg-kwasniewski1 greg-kwasniewski1 commented Sep 22, 2025

When doing column-row shard on attention modules (resulting in head parallelism), some models (e..g., Nemotron explicitly sets the num_heads in the view node after q, k, v projections. This results in shape mismatch in the later all-reduce after O projection and a crash.

The correct way (one of the correct ways), is to keep it implicit (-1), e.g., similarly to Llama3.

This PR sets implicit num_heads value for views after column sharding.

Summary by CodeRabbit

  • Bug Fixes
    • Corrected tensor shape handling after column-sharding of linear layers, preventing invalid 4D reshapes that could cause dimension mismatches or runtime crashes in non-distributed deployments.
    • View operations now adapt dynamically to the sharded output, improving robustness across varying batch and sequence sizes.
    • Enhances stability and reliability of inference and auto-deploy workflows without requiring configuration changes.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@greg-kwasniewski1
Copy link
Collaborator Author

/bot run

@greg-kwasniewski1 greg-kwasniewski1 self-assigned this Sep 22, 2025
@greg-kwasniewski1 greg-kwasniewski1 changed the title [TRTLLM-6342][feat] Fix shape propagation after TP sharding [TRTLLM-6342][bug] Fix shape propagation after TP sharding Sep 22, 2025
@greg-kwasniewski1 greg-kwasniewski1 added bug Something isn't working AutoDeploy <NV> AutoDeploy Backend labels Sep 22, 2025
@github-project-automation github-project-automation bot moved this from Backlog to In review in AutoDeploy Board Sep 22, 2025
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 22, 2025

📝 Walkthrough

Walkthrough

Introduces a helper to adjust aten.view shapes after column sharding and wires it into the non-distributed branch of _insert_sharded_matmul. The helper scans a node’s users for view ops and normalizes the third dimension in their shape argument to -1 when needed. No public APIs changed.

Changes

Cohort / File(s) Summary
Sharding utilities
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Added _update_view_nodes(node) to detect/fix hardcoded 4D view shapes post column-shard; invoked from _insert_sharded_matmul when add_dist=False to update downstream aten.view shapes; added logging of updated args; no signature changes.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller as _insert_sharded_matmul
  participant Node as Node (matmul result)
  participant Utils as _update_view_nodes
  participant Views as aten.view users

  Caller->>Node: produce sharded output (non-distributed path)
  Note right of Caller: add_dist == False
  Caller->>Utils: _update_view_nodes(Node)
  Utils->>Views: scan users for aten.view
  alt view shape has fixed 3rd dim
    Utils->>Views: rewrite shape: dims[2] = -1
    Utils-->>Caller: log updated view args
  else no change needed
    Utils-->>Caller: return
  end
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title Check ✅ Passed The title "[TRTLLM-6342][bug] Fix shape propagation after TP sharding" succinctly and accurately summarizes the primary change in this PR, which corrects view/shape propagation after tensor-parallel (TP) column/row sharding. It includes a ticket reference and a type tag and avoids noise or file lists, so a teammate scanning history will understand the main intent. Therefore the title meets the repository expectations for clarity and conciseness.
Description Check ✅ Passed The PR description clearly explains the bug (explicit num_heads in view nodes causing shape mismatch after column-row TP sharding) and the implemented fix (use implicit num_heads = -1 for views after column sharding), with concrete examples and links that make intent and scope easy to review. However, the repository-required PR title header is missing and the template's "## Description" and "## Test Coverage" sections are left empty. Because the core rationale and change are well documented, the description is mostly complete for review purposes.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

1-1: Add required NVIDIA Apache-2.0 copyright header (2025).

This file is missing the mandated header at the very top per repo guidelines.

Apply this diff:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Sharding config definitions for the inference optimizer."""
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

57-74: Make the view-shape rewrite robust to reshape variants and kwarg forms.

Covers aten.reshape in addition to aten.view; safely handle kwargs and non-list/tuple shapes; drop stray newline in log.

Apply this diff:

-def _update_view_nodes(node: Node) -> None:
+def _update_view_nodes(node: Node) -> None:
     """
     After sharding weights of the linear node, using column split
     in attention module (Q, K, V),
     the output Y = X @ W^T is [batch, seq, num_heads // TP_size, head_dim]
     Some models hardcode the shape of the output to be [batch, seq, num_heads, head_dim]
     instead of implicit [batch, seq, -1, head_dim].
     Detect such cases and update the shape of the view node accordingly.
     """
-    view_nodes = [n for n in node.users if is_op(n, torch.ops.aten.view)]
+    view_like_ops = (torch.ops.aten.view, torch.ops.aten.reshape)
+    view_nodes = [n for n in node.users if is_op(n, view_like_ops)]
     for view_node in view_nodes:
-        view_shape = view_node.args[1]
-        if len(view_shape) == 4 and view_shape[2] != -1:
-            args = list(view_node.args)
-            args[1] = [view_shape[0], view_shape[1], -1, view_shape[3]]
-            view_node.args = tuple(args)
-            ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}")
+        # Extract shape from args or kwargs
+        args = list(view_node.args)
+        kwargs = dict(view_node.kwargs)
+        view_shape = None
+        if len(args) > 1:
+            view_shape = args[1]
+        else:
+            # aten.view uses "size", aten.reshape uses "shape"
+            view_shape = kwargs.get("size") or kwargs.get("shape")
+
+        # Only handle static containers (may contain dynamic Node entries)
+        if isinstance(view_shape, (list, tuple, torch.Size)) and len(view_shape) == 4 and view_shape[2] != -1:
+            new_shape = list(view_shape)
+            new_shape[2] = -1
+            if len(args) > 1:
+                args[1] = new_shape
+            else:
+                key = "shape" if is_op(view_node, torch.ops.aten.reshape) else "size"
+                kwargs[key] = new_shape
+            view_node.args = tuple(args)
+            view_node.kwargs = kwargs
+            ad_logger.debug(f"Updated view-like node {view_node} arguments to {new_shape}")

57-74: Optional: Handle pass-through ops between matmul and view (contiguous/clone).

If a no-op aliasing op sits between the sharded matmul and the view/reshape, the current direct-user scan will miss it. Consider a shallow BFS through alias-preserving ops to reach view/reshape nodes.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b5391b4 and 248f5f2.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (179-202)
tensorrt_llm/logger.py (1)
  • debug (144-145)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

179-183: Call site placement looks correct for colwise-without-gather.

Early return after updating view/reshape nodes ensures sharded outputs keep implicit heads, avoiding post-O all-reduce mismatch. No change requested.

Please validate on:

  • Nemotron model with column-row TP (original reproducer).
  • A Llama3 path where views already use -1 (should be a no-op).

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19615 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19615 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14754 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@greg-kwasniewski1
Copy link
Collaborator Author

@lucaslie can you merge and close please?

@lucaslie lucaslie merged commit 6fd2258 into NVIDIA:main Oct 1, 2025
9 of 13 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in AutoDeploy Board Oct 1, 2025
faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 2, 2025
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Faradawn Yang <faradawny@gmail.com>
evezhier pushed a commit to evezhier/TensorRT-LLM that referenced this pull request Oct 3, 2025
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 3, 2025
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Faradawn Yang <faradawny@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

AutoDeploy <NV> AutoDeploy Backend bug Something isn't working

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants