[TRTLLM-6342][feat] TP Sharding read from the model config by greg-kwasniewski1 · Pull Request #6972 · NVIDIA/TensorRT-LLM · GitHub

Conversation

@greg-kwasniewski1
Collaborator

@greg-kwasniewski1 greg-kwasniewski1 commented Aug 18, 2025

Description

If base_model_tp_plan is present in the model config and ad_config.use_sharding_from_factory == True, skip sharding pattern detection and instead apply the sharding plan read from the model config.

Test Coverage

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py has been updated to test new sharding logic.
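For reference, a minimal sketch of the transform config that toggles the new path, mirroring the unit tests in this PR (the surrounding InferenceOptimizer wiring is assumed from the test helpers):

optimizer_config = {
    "detect_sharding": {
        "stage": "sharding",
        "use_sharding_from_factory": True,  # apply the plan read from the model config
        "sharding_dims": ["tp", "ep", "dp"],  # heuristic fallback dimensions
    },
    "sharding_transform_executor": {
        "stage": "sharding",
        "run_shape_prop": True,
    },
}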

Summary by CodeRabbit

  • New Features

    • Per-model/factory sharding configuration and public accessor; new deploy options to choose factory sharding and which sharding dimensions to run; shape-propagation after sharding enabled.
  • Refactor

    • Consolidated sharding detection into a single, modular flow with a dedicated executor and pluggable TP/EP/DP entrypoints; supports factory-provided plans.
  • Behavior

    • Standardized split-dimension semantics and default simple-column sharding behavior.
  • Tests

    • Updated tests for factory vs default paths, renamed entrypoint usage, added expert-count skip, and adjusted expectations.
  • Other

    • Node filtering now accepts predicate-based filters.

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@greg-kwasniewski1 greg-kwasniewski1 requested a review from a team as a code owner August 18, 2025 01:38
@coderabbitai
Contributor

coderabbitai bot commented Aug 18, 2025

📝 Walkthrough

Walkthrough

Consolidates sharding detection into a single detect_sharding transform with modular TP/EP/DP detectors and optional factory-driven plans; adds sharding config plumbing to ModelFactory/HF factories; extends sharding and node utilities; updates default config and multigpu sharding tests.

Changes

Cohort / File(s) Summary
Config update
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Rename transform key detect_column_row_shard → detect_sharding; remove explicit detect_ep_shard/detect_dp_bmm_shard; add use_sharding_from_factory: false and sharding_dims: ['tp','ep','dp'] under detect_sharding; enable run_shape_prop: true for sharding_transform_executor.
Model factory
tensorrt_llm/_torch/auto_deploy/models/factory.py
Add ShardingConfigSource(Enum) (HUGGINGFACE, UNKNOWN); initialize internal _sharding_config and add public get_sharding_config() -> Dict.
HF model factories
tensorrt_llm/_torch/auto_deploy/models/hf.py
Import ShardingConfigSource; set factory source to HUGGINGFACE in HF factories; call _set_sharding_config(model.config) during model build; add _set_sharding_config handlers (with text_config overrides) to populate head_dim, tp_plan, and num_hidden_layers.
Modular sharding framework
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Replace monolithic ColumnRowShard with Sharding/ShardingTransformConfig; change registry to detect_sharding; add ShardingTransformExecutor and modular detectors (detect_sharding_from_factory_config, detect_column_row_shard, detect_dp_bmm_shard, detect_ep_shard); support use_sharding_from_factory and sharding_dims; aggregate TP/EP/DP results and add regex-based TP plan matching and head_dim-aware local-shape logic.
Sharding utilities & semantics
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Expand ShardingConfig with factory_source, rank, world_size, predefined_config, simple_shard_only, use_sharding_from_factory, sharding_dims; add _validate_and_normalize(), validate_config(), get_predefined_config(); flip SplitDimension mapping to COLUMN=0, ROW=1; adjust TPShardingInfo.validate(); add helper _append_simple_shard.
Node utilities
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Change filtered_nodes signature to accept either a callable predicate (target) or op(s) (ops); implement branching to support predicate-based filtering while preserving op-based usage.
LLM args / CLI config
tensorrt_llm/_torch/auto_deploy/llm_args.py
Add use_sharding_from_factory: bool = False and sharding_dims: List[str] = ["tp","ep","dp"] to AutoDeployConfig.
Tests: multigpu sharding
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/*
.../test_bmm_sharding.py, .../test_ep_sharding.py, .../test_tp_sharding.py
Update tests to use detect_sharding with use_sharding_from_factory: False; TP tests add base_model_tp_plan/predefined_config, parameterize from_config paths and update expected split_dim semantics; EP tests add early skip when world_size > num_experts and adjust gate parameter computation; BMM/EP config keys and usages aligned to new API.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant IO as InferenceOptimizer
  participant DS as detect_sharding
  participant MF as ModelFactory
  participant Det as Detectors (TP/EP/DP)
  participant EX as sharding_transform_executor
  participant GM as GraphModule

  IO->>DS: apply(config, GM)
  alt use_sharding_from_factory && predefined_config present
    DS->>MF: get_sharding_config()
    DS->>GM: detect_sharding_from_factory_config(GM, sharding_config)
    DS-->>IO: TransformInfo (factory-derived)
  else heuristics path
    DS->>Det: run TP / EP / DP detectors (filtered by sharding_dims)
    Det-->>DS: TransformInfos
    DS-->>IO: Aggregated TransformInfo
  end
  IO->>EX: apply(GM, TransformInfo)
  alt run_shape_prop enabled
    EX->>GM: run shape propagation
  end
  EX-->>IO: transformed GraphModule

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • suyoggupta
  • pcastonguay
  • schetlur-nv


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

54-56: ShardingTransformExecutor doesn’t consume run_shape_prop

You’ve enabled run_shape_prop: true for the sharding_transform_executor in default.yaml, but ShardingTransformExecutor._apply never reads self.config.run_shape_prop (no shape-propagation is performed). Please update the executor to respect this flag or remove it from the config:

File to update
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
– In _apply(...), add a conditional shape-propagation step when self.config.run_shape_prop is true (e.g., call canonicalize_graph(gm, shape_prop=True) around the transformation loop).
Config cleanup
If shape propagation isn’t supported yet, remove run_shape_prop: true from the sharding_transform_executor section in default.yaml to avoid confusion.
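A minimal sketch of the first option, assuming canonicalize_graph accepts a shape_prop flag as suggested above (the executor's surrounding logic is elided):

def _apply(self, gm, cm, factory, shared_config):
    ...  # apply the collected sharding transforms to gm
    if self.config.run_shape_prop:
        # re-propagate tensor shapes so downstream passes see sharded metadata
        gm = canonicalize_graph(gm, shape_prop=True)
    ...  # build and return the TransformInfo as before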

🧹 Nitpick comments (17)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)

210-213: Tighten type hints to avoid ambiguity between ops and predicates

OperatorLike currently includes Callable, and target repeats Callable, which is confusing for type-checkers and readers. Consider introducing a dedicated NodePredicate alias and narrowing OperatorLike to only op types.

Proposed refinement outside the selected lines (illustrative):

# near lines ~28-31
from typing import Callable
NodePredicate = Callable[[Node], bool]
OpOrOverload = Union[OpOverloadPacket, OpOverload]
OperatorLike = OpOrOverload  # remove Callable from here

# function signature (lines 210-213)
def filtered_nodes(
    nodes: Iterable[Node],
    target: Union[NodePredicate, Union[OperatorLike, Iterable[OperatorLike]]] = None,
    ops: Union[OperatorLike, Iterable[OperatorLike]] = None,
) -> Iterable[Node]:
    ...

241-251: Guard against silent no-op when both target and ops are None

If both target and ops are None, filtered_nodes silently yields nothing. Being explicit will save debugging time.

Apply this diff to validate inputs:

-    else:
-        # Handle the case where target or ops contains operations
-        operations = ops if ops is not None else target
-        for node in nodes:
-            if is_op(node, operations):
-                yield node
+    else:
+        # Handle the case where target or ops contains operations
+        operations = ops if ops is not None else target
+        if operations is None:
+            raise ValueError("filtered_nodes requires either a predicate (target) or ops to be set.")
+        for node in nodes:
+            if is_op(node, operations):
+                yield node
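For illustration, both calling conventions under the extended signature (a sketch; the specific aten op is an assumption about typical usage):

# op-based filtering continues to work unchanged
linear_nodes = filtered_nodes(gm.graph.nodes, ops=torch.ops.aten.linear)
# predicate-based filtering is now possible as well
attr_nodes = filtered_nodes(gm.graph.nodes, target=lambda n: n.op == "get_attr")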
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3)

67-70: Explicit False is redundant

use_sharding_from_factory defaults to False in default.yaml. Keeping it explicit is fine for clarity, but can be omitted to reduce verbosity.

-            "detect_sharding": {
-                "stage": "sharding",
-                "use_sharding_from_factory": False,
-            },
+            "detect_sharding": {
+                "stage": "sharding",
+            },

128-131: Same minor nit: explicit False can be omitted

To reduce noise in tests, you can rely on the default.

-            "detect_sharding": {
-                "stage": "sharding",
-                "use_sharding_from_factory": False,
-            },
+            "detect_sharding": {
+                "stage": "sharding",
+            },

67-70: Consider adding a test for factory-driven sharding

Since this PR introduces reading sharding plans from a model factory/config, consider adding a focused test where use_sharding_from_factory=True and a minimal FakeFactory returns a BMM plan. It will protect against regressions in the new path.

I can draft a small test fixture with a FakeFactory exposing get_sharding_config and get_sharding_config_source to exercise this path. Want me to open a follow-up?
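A possible shape for that fixture (a sketch; the plan contents are purely illustrative):

from tensorrt_llm._torch.auto_deploy.models.factory import ShardingConfigSource

class FakeFactory:
    """Minimal stand-in exposing the accessors the detection path reads."""

    def get_sharding_config(self):
        return {
            "source": ShardingConfigSource.HUGGINGFACE,
            "tp_plan": {"linear": "colwise"},
            "head_dim": 64,
            "num_hidden_layers": 2,
        }

    def get_sharding_config_source(self):
        return ShardingConfigSource.HUGGINGFACE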

tensorrt_llm/_torch/auto_deploy/models/factory.py (2)

49-49: Provide a typed, validated setter for sharding config

Right now factories must mutate _sharding_config directly. Expose a small helper to set it with optional validation so different factories follow a consistent contract.

Outside the selected lines, add something like:

def _set_sharding_config(self, cfg: Dict[str, Any]) -> None:
    # Optionally validate minimal keys, e.g., "tp_plan" when available
    self._sharding_config = copy.deepcopy(cfg or {})

120-127: Docstring clarity: this returns the sharding config source

Nit: “source of the model factory” reads oddly. Clarify it’s the source of the sharding config.

-    def get_sharding_config_source(self) -> ShardingConfigSource:
-        """Return the source of the model factory.
-
-        Returns:
-            The source identifier for this model factory.
-        """
+    def get_sharding_config_source(self) -> ShardingConfigSource:
+        """Return the source of the sharding configuration.
+
+        Returns:
+            The origin identifier for the sharding config provided by this factory.
+        """
         return ShardingConfigSource.UNKNOWN
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)

26-29: Validate world_size > num_experts in the sharding implementation

The current EP-sharding utility in
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
does not guard against world_size > num_experts. In that case

experts_per_rank = num_experts // world_size

yields zero, leading to empty partitions or downstream errors. I recommend adding an early check at the top of _insert_sharded_moe:

def _insert_sharded_moe(gm, node, world_size, rank, num_experts, *args):
    # remaining args: (selected_experts, final_scales, ...)
    if world_size > num_experts:
        raise ValueError(
            f"world_size ({world_size}) cannot exceed num_experts ({num_experts})"
        )
    # existing logic…

— Location to update:
• tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
• Function: _insert_sharded_moe

tensorrt_llm/_torch/auto_deploy/models/hf.py (1)

174-183: Validate model_config attributes before accessing

The method assumes that the HuggingFace model config may not have the expected attributes. This is handled correctly with hasattr() checks, but consider adding debug logging when expected attributes are missing to aid troubleshooting.

 def _set_sharding_config(self, model_config: PretrainedConfig):
     """Set the sharding config for the model."""
     self._sharding_config["head_dim"] = 1
     if hasattr(model_config, "base_model_tp_plan"):
         self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
+    else:
+        ad_logger.debug("base_model_tp_plan not found in model config")
     if hasattr(model_config, "head_dim"):
         self._sharding_config["head_dim"] = model_config.head_dim
     if hasattr(model_config, "num_hidden_layers"):
         self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (3)

26-49: Document the base_model_tp_plan structure and usage

The base_model_tp_plan dictionary defines sharding strategies but lacks documentation about what each value means and why some entries are commented out.

 base_model_tp_plan = {
+    # Tensor parallel sharding plan mapping module names to sharding strategies:
+    # - "colwise": Split along columns (dimension 0) without communication
+    # - "rowwise": Split along rows (dimension 1) with all-reduce
+    # - "gather": Simple shard with row split and all-gather
     "q_proj": "colwise",
     "k_proj": "colwise",
     "v_proj": "colwise",
     "o_proj": "rowwise",
     "gate_proj": "colwise",
     "up_proj": "colwise",
     "down_proj": "rowwise",
     "linear1": "colwise",
     "linear2": "rowwise",
     "linear": "gather",
-    # "input_layernorm.weight": "sequence_parallel",
-    # "post_attention_layernorm.weight": "sequence_parallel",
-    # "norm.weight": "sequence_parallel",
-    # "shared_expert.gate_proj": "local_colwise",
-    # "shared_expert.up_proj": "local_colwise",
-    # "shared_expert.down_proj": "local_rowwise",
-    # "experts.gate_up_proj": "local_packed_rowwise",
-    # "experts.down_proj": "local_colwise",
-    # "experts": "local",
+    # TODO: The following strategies require hybrid EP+TP and/or SP support:
+    # "input_layernorm.weight": "sequence_parallel",
+    # "post_attention_layernorm.weight": "sequence_parallel",
+    # "norm.weight": "sequence_parallel",
+    # "shared_expert.gate_proj": "local_colwise",
+    # "shared_expert.up_proj": "local_colwise",
+    # "shared_expert.down_proj": "local_rowwise",
+    # "experts.gate_up_proj": "local_packed_rowwise",
+    # "experts.down_proj": "local_colwise",
+    # "experts": "local",
     "feed_forward": "gather",
     "self": "gather",
     "weight": "gather",
 }

322-323: Consider using proper logging instead of print statements

Debug print statements should be replaced with proper logging for better control over output in production environments.

-    print(f"detected_transformations: {detected_transformations}")
-    print(f"expected_transformations: {expected_transformations}")
+    ad_logger.debug(f"detected_transformations: {detected_transformations}")
+    ad_logger.debug(f"expected_transformations: {expected_transformations}")

Import ad_logger at the top of the file:

from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger

378-379: Remove or properly integrate the main block

The __main__ block appears to be debug code that should either be removed or converted to a proper test case.

-if __name__ == "__main__":
-    _run_pattern_detection_job(nn.Linear, False, 0, 8, False)
+# Note: To run pattern detection manually for debugging:
+# pytest test_tp_sharding.py::test_sharding_pattern_detection -k "Linear" --verbose
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

186-191: Critical comment about dimension mapping needs clarification

The comment explains that COLUMN/ROW names reflect HuggingFace notation but the actual split is counterintuitive (COLUMN splits dimension 0). This is a critical detail that affects correctness. Consider adding an example to make this clearer.

     # NOTE: The names COLUMN/ROW reflect the hugging face
     # base_tp_plan sharding notation, but since we assume Y = W @ X^T,
     # when splitting weight matrix W^T across columns, the actual split
     # is over dimension 0
+    # Example: For weight W of shape [out_features, in_features]:
+    #   - COLUMN (0): splits out_features dimension
+    #   - ROW (1): splits in_features dimension
     COLUMN = 0
     ROW = 1

519-527: Improve error message for unsupported factory sources

The error message has a typo and could be more informative about which sources are supported.

     if self.factory_source != ShardingConfigSource.HUGGINGFACE:
         ad_logger.warning(
-            "Sharding config is is currently only " + "supported for HuggingFace. Skipping."
+            f"Sharding config from {self.factory_source.value} is currently only supported for HuggingFace sources. Skipping."
         )

547-564: Consider making allowed sharding values a class constant

The allowed values for tp_plan should be defined as a class constant for better maintainability and reusability.

+    ALLOWED_TP_PLAN_VALUES = {
+        "colwise",  # row split and no collective
+        "rowwise",  # column split and all-reduce
+        "gather",  # simple shard (row + all_gather)
+        # TODO: remaining values are not supported yet.
+        # They require hybrid EP+TP and/or SP support.
+        # "sequence_parallel", # sequence parallelism
+        # "local_colwise",
+        # "local_rowwise",
+        # "local_packed_rowwise",
+        # "local",
+    }
+
     def validate_config(self) -> bool:
         # ... existing validation code ...
         values = set(tp_plan.values())
-        allowed_values = {
-            "colwise",  # row split and no collective
-            "rowwise",  # column split and all-reduce
-            "gather",  # simple shard (row + all_gather)
-            # TODO: remaining values are not supported yet.
-            # They require hybrid EP+TP and/or SP support.
-            # "sequence_parallel", # sequence parallelism
-            # "local_colwise",
-            # "local_rowwise",
-            # "local_packed_rowwise",
-            # "local",
-        }
-        if not values.issubset(allowed_values):
+        if not values.issubset(self.ALLOWED_TP_PLAN_VALUES):
             ad_logger.warning("Sharding config contains invalid values. Skipping.")
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)

273-280: Improve regex pattern construction for better readability

The regex pattern construction using @ as a placeholder is clever but could be more readable with a direct approach.

         # use regex to find if module_name matches any of the keys in sharding_config
         for key in tp_plan.keys():
-            pattern_string = "*" + key + "*"
-            # convert it to regex. Escape dots, replace * with .*
-            # First, we substitute * with an unlikely character, e.g. @
-            # Then we escape dots, and finally we replace @ with .*
-            pattern_string = pattern_string.replace("*", "@")
-            pattern_regex = re.escape(pattern_string).replace("@", ".*")
+            # Convert glob pattern to regex: * becomes .*, other chars are escaped
+            pattern_string = f"*{key}*"
+            # Replace * with placeholder before escaping
+            parts = pattern_string.split("*")
+            pattern_regex = ".*".join(re.escape(part) for part in parts)
             if re.match(pattern_regex, module_name):
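As a standalone check of the simplified conversion (glob_to_regex is a hypothetical helper name):

import re

def glob_to_regex(key: str) -> str:
    # "q_proj" -> ".*q_proj.*"; literal dots in keys are escaped
    return ".*".join(re.escape(part) for part in f"*{key}*".split("*"))

assert re.match(glob_to_regex("q_proj"), "model.layers.0.q_proj.weight")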

306-311: Consolidate TODO comments for unsupported features

The warning messages for unsupported features could be more informative by including what configuration value was encountered.

                 elif "sequence" in config:
                     # TODO: Sequence parallelism is not supported yet.
-                    ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
+                    ad_logger.warning(f"Sequence parallelism config '{config}' for '{module_name}' is not supported yet. Skipping.")
                 elif "local" in config:
                     # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
-                    ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+                    ad_logger.warning(f"Local EP+TP sharding config '{config}' for '{module_name}' is not supported yet. Skipping.")
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between ce0b13e and 6a61d1c.

📒 Files selected for processing (10)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/factory.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (12 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/factory.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/factory.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
🧬 Code Graph Analysis (7)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (2)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)
  • SharedConfig (51-56)
tensorrt_llm/mapping.py (1)
  • local_rank (399-400)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (1)
  • target (502-503)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
  • get_sharding_config_source (203-209)
tensorrt_llm/_torch/auto_deploy/models/hf.py (3)
tensorrt_llm/_torch/auto_deploy/models/factory.py (4)
  • ModelFactory (23-228)
  • ShardingConfigSource (16-20)
  • model (52-54)
  • get_sharding_config_source (120-126)
tensorrt_llm/models/modeling_utils.py (1)
  • PretrainedConfig (366-567)
tests/unittest/_torch/test_pytorch_model_engine.py (1)
  • head_dim (33-34)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (5)
tensorrt_llm/executor/request.py (1)
  • path (48-49)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  • SplitDimension (183-191)
tensorrt_llm/_torch/modules/linear.py (2)
  • split_dim (49-50)
  • Linear (1499-1708)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
  • run_sharding_pattern_detection_test (228-242)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)
  • _run_pattern_detection_job (69-116)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ShardingConfigSource (16-20)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (49-50)
tensorrt_llm/logger.py (1)
  • warning (131-132)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
  • _append_simple_shard (100-121)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (4)
tensorrt_llm/_torch/auto_deploy/models/factory.py (5)
  • ModelFactory (23-228)
  • ShardingConfigSource (16-20)
  • register (235-240)
  • get_sharding_config_source (120-126)
  • get_sharding_config (108-110)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • filtered_nodes (209-251)
  • identify_regions_between_residuals (306-359)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6)
  • SplitDimension (183-191)
  • get_predefined_config (567-568)
  • ShardingConfig (482-568)
  • TPShardingInfo (230-265)
  • BMMShardingInfo (268-367)
  • EPShardingInfo (457-479)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
  • get_sharding_config_source (203-209)
🔇 Additional comments (16)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (2)

49-56: Unifying sharding detection surface looks good

Consolidating the prior TP/EP/BMM detectors into detect_sharding simplifies config and reduces ordering pitfalls. The explicit use_sharding_from_factory gate also makes the factory-driven path clear.


49-52: Ensure consistent naming: use_sharding_from_factory

All occurrences in code, config, and tests use the key use_sharding_from_factory—there are no instances of use_sharding_from_config. To avoid confusion, please update the PR description and any related documentation to refer to use_sharding_from_factory instead of use_sharding_from_config.

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)

210-213: Signature expansion to support predicate-based filtering: LGTM

Allowing a callable predicate in filtered_nodes is a useful and backward-compatible enhancement.

tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)

30-32: Ignore factory keyword concern on SharedConfig

SharedConfig subclasses Pydantic’s BaseModel without a Config forbidding extras, so unknown fields like factory are silently ignored rather than raising. Moreover, all transforms receive factory as a separate argument (e.g. transform(gm, cm, factory, shared_config)), so no code depends on shared_config.factory. You can safely leave the constructor call as-is.

Likely an incorrect or invalid review comment.
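A quick illustration of that default (SharedConfigDemo is a stand-in, not the real class; pydantic ignores unknown constructor kwargs unless extras are explicitly forbidden):

from pydantic import BaseModel

class SharedConfigDemo(BaseModel):  # stand-in for SharedConfig
    local_rank: int = 0

cfg = SharedConfigDemo(local_rank=1, factory="ignored")  # no ValidationError
print(cfg)  # local_rank=1; the unknown 'factory' kwarg was silently dropped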

tensorrt_llm/_torch/auto_deploy/models/factory.py (1)

16-21: Enum addition looks good

ShardingConfigSource with HUGGINGFACE and UNKNOWN is clear and leaves room for future sources (e.g., vendor/model-zoo specific).

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3)

46-48: Configuration key consistency looks good

The renaming from "detect_ep_shard" to "detect_sharding" is consistent with the broader PR changes, and the addition of "use_sharding_from_factory": False correctly preserves the existing test behavior.


104-106: Duplicate configuration is appropriate

The configuration is consistent with the first instance (line 46-48) and appropriately placed for the pattern detection test.


120-120: Test parameter update aligns with early return logic

Changing num_experts from [3, 8] to [4, 8] ensures that with device_count values (typically 1, 2, 4, 8), the world_size <= num_experts condition is satisfied, avoiding the early return case introduced at line 27-29.

tensorrt_llm/_torch/auto_deploy/models/hf.py (2)

203-209: LGTM! Source identification is properly implemented

The method correctly returns ShardingConfigSource.HUGGINGFACE to identify this factory as a HuggingFace source for sharding configuration.


363-375: Image-text model configuration override is well structured

The method properly calls the parent implementation first and then specializes by checking for text_config attributes, which is the appropriate pattern for multi-modal models.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

256-260: Confirm SplitDimension semantics in tests

The change from SplitDimension.ROW (1) to SplitDimension.COLUMN (0) flips the actual split axis in your TP sharding tests. Please ensure that:

  • TPShardingInfo.apply (and any downstream logic) indeed interprets split_dim=0 as the intended column-wise split.
  • All assertions in
    • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py lines 256–260,
    • 277–281, and
    • 299
      have been updated to reflect splitting on dimension 0 rather than 1.

If the apply method (or its helpers) still expects a row-wise split at index 1, adjust either the tests or the implementation so they remain in sync.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

240-251: Validation logic correctly aligns with the new dimension semantics

The validation now correctly enforces that COLUMN split (dim 0) only supports all_gather and ROW split (dim 1) only supports all_reduce, which aligns with tensor parallelism best practices.
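Written out as a standalone sketch of the enforced pairing (the actual validator in sharding_utils.py may differ in structure):

def check_split_collective_pairing(split_dim: int, dist_op: str) -> None:
    # COLUMN (dim 0) shards output features; all_gather stitches outputs back
    if split_dim == 0 and dist_op == "all_reduce":
        raise ValueError("COLUMN split pairs with all_gather, not all_reduce")
    # ROW (dim 1) shards the reduction dim; partial sums need all_reduce
    if split_dim == 1 and dist_op == "all_gather":
        raise ValueError("ROW split pairs with all_reduce, not all_gather")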


571-592: Function implementation is well-structured

The _append_simple_shard helper correctly creates TPShardingInfo entries with COLUMN split dimension and all_gather operation, which is consistent with the simple sharding strategy.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)

187-193: Good architectural decision on factory-driven sharding

The implementation correctly prioritizes factory-provided sharding configuration when available, providing a clean override mechanism while preserving the autodeploy heuristics as a fallback.


114-114: Dimension change for simple sharding is correct

The switch to SplitDimension.COLUMN intentionally maps to tensor axis 0 (as confirmed by the enum comments, the validate logic in sharding_utils.py, and the existing multigpu tests in test_tp_sharding.py which expect simple shards to use dim 0 with all_gather). No further changes are required.


503-510: split_dim assignment is correct and matches enum values
The loop index i (0 or 1) directly corresponds to SplitDimension.COLUMN and SplitDimension.ROW respectively, and Pydantic will coerce the integer into the enum. This aligns with the intended two-way shard semantics. No change is required (you may optionally wrap i with SplitDimension(i) for readability).

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@suyoggupta
Collaborator

/bot run


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🔭 Outside diff range comments (1)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

1-1: Missing NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 """Tests for basic graph sharding."""
♻️ Duplicate comments (1)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

10-11: Use proper package imports instead of local helpers at repo root

Importing from _dist_test_utils and _graph_test_helpers assumes ad-hoc PYTHONPATH configuration. Prefer importing via the tests package path.

-from _dist_test_utils import get_device_counts
-from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm
+from tests.unittest._torch.auto_deploy._utils_test._dist_test_utils import get_device_counts
+from tests.unittest._torch.auto_deploy._utils_test._graph_test_helpers import (
+    run_sharding_pattern_detection_test,
+    run_test_transformed_gm,
+)
🧹 Nitpick comments (4)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)

22-45: Unused base_model_tp_plan / predefined_config — config-driven path likely not exercised

These dicts are declared but never used. As-is, parametrizing from_config=True won’t cause the tests to verify “apply sharding from model config” unless the pipeline is wired to consume this plan for these toy models (MLP, nn.Linear, GQA_Block), which typically aren’t backed by a ModelFactory.

Action options:

  • Wire them into the test when from_config is True (preferred), e.g., inject the plan into the optimizer’s shared sharding config prior to calling optimizer(None, gm), or monkeypatch the factory/config provider for these models.
  • Or remove them if you rely solely on factory-backed models elsewhere for config-driven sharding tests.

If the intent is to validate the new “read sharding from config” path, please hook these into the detection stage and assert that the detected transforms exactly match the plan. I can help sketch a minimal monkeypatching approach for the test.

Also applies to: 47-50
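A minimal monkeypatching sketch for the first option (pytest; it assumes the optimizer resolves the plan through the factory's get_sharding_config, and 'factory' is whatever factory instance the test wires in):

def test_tp_sharding_from_config(monkeypatch, factory):
    monkeypatch.setattr(
        factory,
        "get_sharding_config",
        lambda: {
            "source": ShardingConfigSource.HUGGINGFACE,
            "tp_plan": base_model_tp_plan,
            "head_dim": 64,
        },
    )
    # build gm, run optimizer(None, gm) with use_sharding_from_factory=True,
    # then assert the detected transforms exactly match base_model_tp_plan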


252-257: Prefer canonical DistOp/enum over string literals and clarify SplitDimension comment

  • Using string literals for dist_op ("all_reduce", "all_gather") is brittle if the detector returns enums or ops; prefer the same type the detector emits (e.g., DistOp enum or the concrete torch operator) to avoid fragile equality.
  • Minor: the comment “Simple shard uses dim=0” is ambiguous without mapping to SplitDimension; clarify to avoid confusion.

If DistOp enum is available:

-                        dist_op = "all_reduce"
+                        dist_op = DistOp.ALL_REDUCE
-                        dist_op = "all_gather"
+                        dist_op = DistOp.ALL_GATHER

Comment tweak for clarity:

-                            split_dim=SplitDimension.COLUMN,  # Simple shard uses dim=0
+                            split_dim=SplitDimension.COLUMN,  # Simple shard uses COLUMN (dim=0)

Also applies to: 273-277, 295-301


318-319: Drop or gate debug prints to keep test output clean

These prints will spam CI logs. Either remove or guard behind a verbosity flag/env var.

-    print(f"detected_transformations: {detected_transformations}")
-    print(f"expected_transformations: {expected_transformations}")
+    # if os.environ.get("TP_SHARDING_DEBUG") == "1":
+    #     print(f"detected_transformations: {detected_transformations}")
+    #     print(f"expected_transformations: {expected_transformations}")

374-375: Remove main runner from the test module

Test files generally shouldn’t be executable scripts. This block can cause confusion and accidental execution paths in some environments.

Apply:

-if __name__ == "__main__":
-    _run_pattern_detection_job(nn.Linear, False, 0, 8, False)
+# Intentionally no __main__ block in test modules.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 6a61d1c and 56df6cf.

📒 Files selected for processing (1)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (11 hunks)
🧰 Additional context used
📓 Path-based instructions (2): the same CODING_GUIDELINES.md rules listed above, applied to test_tp_sharding.py.
⏰ Context from checks skipped due to timeout of 90000ms. (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (7)

114-114: Threading from_config through _run_job looks good

Plumbing this flag into the job runner is the right direction.


182-186: Flag name mismatch with PR objective — ensure the pipeline actually toggles the right path

PR objective mentions ad_config.use_sharding_from_config, but the tests set detect_sharding.use_sharding_from_factory. If the implementation expects use_sharding_from_config, these tests won’t exercise the intended path.

Safest fix: set both keys in the test config to cover either naming, until the codebase is fully unified.

             "detect_sharding": {
                 "stage": "sharding",
-                "use_sharding_from_factory": from_config,
+                "use_sharding_from_factory": from_config,
+                "use_sharding_from_config": from_config,
             },

Apply the same change in both optimizer configs in this file.

Also applies to: 307-311


326-327: from_config parametrization may not validate the config-driven path for toy models

For nn.Linear and MLP that aren’t factory-backed, from_config=True likely won’t load a model config, so the tests may either:

  • silently fallback to detection (not testing the new path), or
  • fail nondeterministically depending on the implementation.

Consider skipping from_config=True for models without a factory-config, or inject a sharding plan for them to truly test the config path.


335-341: LGTM on threading from_config into test_sharding

The signature and spawn wiring align with the new test matrix.


350-351: Duplicate note: from_config parametrization for pattern-detection

Same concern as above — ensure from_config=True actually consumes a config for the tested model, otherwise this branch isn’t validated.


360-365: LGTM on test_sharding_pattern_detection signature updates

The new parameter is correctly threaded.


371-371: Routing from_config into the pattern-detection job looks correct

Invocation aligns with the new signature.

@tensorrt-cicd
Collaborator

PR_Github #15640 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #15640 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11773 completed with status: 'FAILURE'

@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #15769 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #15769 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11855 completed with status: 'FAILURE'

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)

177-186: head_dim detection is too narrow; use head_size or derive from hidden_size/num_attention_heads.

As implemented, many configs won’t have head_dim; validate_config then invalidates the factory config, breaking the “use sharding from config” path. Prefer head_size when present (see PretrainedConfig) or compute it from hidden_size // num_attention_heads. Keep explicit overrides if provided.

Apply this refactor:

 def _set_sharding_config(self, model_config: PretrainedConfig):
     """Set the sharding config for the model."""
-    self._sharding_config["head_dim"] = 1
-    if hasattr(model_config, "base_model_tp_plan"):
-        self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
-    if hasattr(model_config, "head_dim"):
-        self._sharding_config["head_dim"] = model_config.head_dim
-    if hasattr(model_config, "num_hidden_layers"):
-        self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
+    # TP plan (if provided by the model config)
+    if hasattr(model_config, "base_model_tp_plan"):
+        self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
+
+    # Determine head_dim robustly: prefer explicit head_dim/head_size, else derive.
+    head_dim = None
+    if hasattr(model_config, "head_dim"):
+        head_dim = getattr(model_config, "head_dim")
+    elif hasattr(model_config, "head_size"):
+        head_dim = getattr(model_config, "head_size")
+    elif hasattr(model_config, "hidden_size") and hasattr(model_config, "num_attention_heads"):
+        try:
+            head_dim = model_config.hidden_size // model_config.num_attention_heads
+        except Exception:
+            head_dim = None
+    self._sharding_config["head_dim"] = head_dim if head_dim is not None else 1
+
+    # Propagate layer count if available
+    if hasattr(model_config, "num_hidden_layers"):
+        self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
🧹 Nitpick comments (9)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)

1-1: Add NVIDIA copyright header (2025).

Per coding guidelines, prepend the NVIDIA copyright header to all source files.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (8)

1-1: Add NVIDIA copyright header (2025).

Per coding guidelines, prepend the NVIDIA copyright header to all source files.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

100-121: Fix comment to reflect SplitDimension choice (avoid row/column ambiguity).

The code uses SplitDimension.COLUMN (value 0) while the comment says “row_split (dim 0)”. Given the ROW/COLUMN naming inversion wrt W vs W^T, this is confusing. Clarify the comment to match the actual enum used.

Apply this small comment update:

-    # --> row_split (dim 0 of weight) + all_gather (dim -1 of output)
+    # --> column split on weight (SplitDimension.COLUMN / dim 0) + all_gather (dim -1 of output)

171-186: Guard against missing 'source' key in factory config to avoid KeyError.

If an external ModelFactory returns a dict without 'source', this will KeyError. Use .get with UNKNOWN fallback.

Apply this guard:

-        shared_config.sharding_config.predefined_config = (
-            factory.get_sharding_config() if factory else {}
-        )
-        shared_config.sharding_config.factory_source = (
-            shared_config.sharding_config.predefined_config["source"]
-            if factory
-            else ShardingConfigSource.UNKNOWN
-        )
+        shared_config.sharding_config.predefined_config = factory.get_sharding_config() if factory else {}
+        source = (
+            shared_config.sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN)
+            if factory
+            else ShardingConfigSource.UNKNOWN
+        )
+        shared_config.sharding_config.factory_source = source

219-223: Return type annotation mismatch (returns TransformInfo but annotated as None).

Static type mismatch; also update docstring “Returns” section accordingly.

Apply:

-def detect_sharding_from_factory_config(
+def detect_sharding_from_factory_config(
     gm: GraphModule,
     sharding_config: ShardingConfig,
-) -> None:
+) -> TransformInfo:
@@
-    Create sharding transformations from the predefined config.
+    Create sharding transformations from the predefined config and return summary info.
@@
-    """
+    Returns:
+        TransformInfo: Aggregated info for detected and appended transforms.
+    """

Also applies to: 332-337


340-359: TP detector: return type annotation mismatch (returns TransformInfo).

Fix the signature to match usage in Sharding._apply.

Apply:

-def detect_column_row_shard(
+def detect_column_row_shard(
     gm: GraphModule,
     sharding_config: ShardingConfig,
-) -> None:
+) -> TransformInfo:

Also applies to: 515-518


521-598: DP BMM detector: return type mismatch and dead branch for remainder.

  • Signature should return TransformInfo.
  • Since you early-continue when remainder != 0, the rank<remainder branch is unreachable. Simplify the index computation for clarity.

Apply:

-def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:
+def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> TransformInfo:
@@
-        # Calculate start and end indices for this rank
-        if rank < remainder:
-            start_idx = rank * (base_size + 1)
-            end_idx = start_idx + base_size + 1
-        else:
-            start_idx = remainder + rank * base_size
-            end_idx = start_idx + base_size
+        # Calculate start and end indices for this rank (remainder==0 here)
+        start_idx = rank * base_size
+        end_idx = start_idx + base_size

600-634: EP detector: return type mismatch (returns TransformInfo).

Align signature with returned value and usage.

Apply:

-def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:
+def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> TransformInfo:

265-331: Use a helper to resolve the weight module path instead of assuming args[1] is always a get_attr
Accessing lin_node.args[1].target works for pure aten.linear calls, but fused or quantized variants (and any call_module‐based Linear) may wrap or reorder the weight argument. Introduce a small utility that walks from the second arg back to its originating get_attr node:

def _resolve_weight_target(node: Node) -> Optional[str]:
    # start from the weight argument
    candidate = node.args[1]
    # unwind any pack/unpack or other call_function wrappers
    while isinstance(candidate, Node) and candidate.op != "get_attr":
        # assume the true weight is the first arg of the wrapper
        candidate = candidate.args[0]
    return candidate.target if isinstance(candidate, Node) and candidate.op == "get_attr" else None

# then in your loop:
module_name = _resolve_weight_target(lin_node)
if module_name is None:
    ad_logger.warning(f"Unable to resolve weight for node {lin_node.name}, skipping.")
    continue

– Replace all direct uses of lin_node.args[1].target with this helper
– This will cover pure, fused, quantized, and call_module‐traced linears uniformly

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 56df6cf and 02eac6a.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/auto_deploy/models/factory.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/_torch/auto_deploy/models/factory.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
🧰 Additional context used
📓 Path-based instructions (2): the same CODING_GUIDELINES.md rules listed above, applied to transform/library/sharding.py and models/hf.py.
🧬 Code Graph Analysis (2)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
tensorrt_llm/_torch/auto_deploy/models/factory.py (2)
  • ShardingConfigSource (16-20)
  • register (228-233)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • filtered_nodes (209-251)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (8)
  • SplitDimension (183-191)
  • validate_config (518-564)
  • get_predefined_config (566-567)
  • ShardingConfig (482-567)
  • TPShardingInfo (230-265)
  • _append_simple_shard (570-591)
  • BMMShardingInfo (268-367)
  • EPShardingInfo (457-479)
tensorrt_llm/_torch/auto_deploy/models/hf.py (2)
tensorrt_llm/_torch/auto_deploy/models/factory.py (3)
  • ModelFactory (23-221)
  • ShardingConfigSource (16-20)
  • model (53-55)
tensorrt_llm/models/modeling_utils.py (1)
  • PretrainedConfig (366-567)
🔇 Additional comments (7)
tensorrt_llm/_torch/auto_deploy/models/hf.py (4)

32-32: Importing ShardingConfigSource here is correct.

This aligns the factory with the sharding config source lifecycle.


97-99: Good: sharding source marked as HuggingFace.

Setting the source early ensures downstream validation can trust the origin.


167-169: Good: initialize sharding config from model.config post-init.

This enables the factory-driven path to be used by the transform when requested.


177-186: Convert Enum to a primitive in _sharding_config before serialization

Storing raw ShardingConfigSource enum members in the _sharding_config dict will break JSON serialization (e.g. json.dumps) unless downstream code knows how to handle Enums. Replace these assignments with the enum’s .value (or .name) so only primitives are stored:

• In tensorrt_llm/_torch/auto_deploy/models/factory.py
  • Line ~16: replace self._sharding_config["source"] = ShardingConfigSource.UNKNOWN with self._sharding_config["source"] = ShardingConfigSource.UNKNOWN.value
  • Line 50: same change.
• In tensorrt_llm/_torch/auto_deploy/models/hf.py
  • Line 98: replace self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE with self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE.value

This ensures get_sharding_config() returns only JSON‐serializable primitives.
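
For illustration, a minimal standalone sketch of the failure mode (the enum member values here are stand-ins, not necessarily the real ones):

import json
from enum import Enum


class ShardingConfigSource(Enum):
    """Stand-in for the real enum; actual member values may differ."""

    HUGGINGFACE = "huggingface"
    UNKNOWN = "unknown"


config = {"source": ShardingConfigSource.HUGGINGFACE}
try:
    json.dumps(config)
except TypeError as err:
    # TypeError: Object of type ShardingConfigSource is not JSON serializable
    print(err)

config["source"] = ShardingConfigSource.HUGGINGFACE.value
print(json.dumps(config))  # {"source": "huggingface"}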

Likely an incorrect or invalid review comment.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)

28-36: Imports look correct.

Good to centralize ModelFactory and node/sharding utilities here.


124-129: Config surface extension LGTM.

Adding use_sharding_from_factory to the transform config wires the factory-driven path cleanly.


191-197: Factory-driven early return makes sense.

Skips heuristics when a valid sharding plan is supplied and requested.
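
A rough sketch of this control flow; apply_factory_sharding and run_heuristic_detection are hypothetical stand-ins rather than the actual helpers in sharding.py, while get_sharding_config() is the factory accessor added in this PR:

def detect_sharding(gm, factory, config):
    """Prefer a factory-supplied TP plan; otherwise run heuristic detection."""
    if config.use_sharding_from_factory:
        plan = factory.get_sharding_config().get("tp_plan")
        if plan:  # non-empty plan: apply it and skip all heuristics
            return apply_factory_sharding(gm, plan)
    # fall back to heuristic detection over the configured dimensions
    return run_heuristic_detection(gm, config.sharding_dims)


def apply_factory_sharding(gm, plan):
    ...  # hypothetical: translate the factory tp_plan into sharding transforms


def run_heuristic_detection(gm, dims):
    ...  # hypothetical: dispatch to the TP/EP/DP detectors listed in dims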

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16327 [ run ] triggered by Bot

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@greg-kwasniewski1
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16328 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16327 [ run ] completed with state ABORTED

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)

1-3: Add NVIDIA copyright header (2025) at top of file.

Per repository guidelines, prepend the standard NVIDIA copyright header.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/llm_args.py (2)

168-171: Constrain values and normalize duplicates for sharding_dims.

Constrain to known dimension names and dedupe while preserving order to avoid subtle misconfigurations.

Apply stricter typing (import Literal from typing if it is not already imported):

-    sharding_dims: List[str] = Field(
+    sharding_dims: List[Literal["tp", "ep", "dp"]] = Field(
         default=["tp", "ep", "dp"],
         description="The dimensions to use for sharding. Allowed values: 'tp', 'ep', 'dp'.",
     )

Then add this validator within AutoDeployConfig (outside the changed hunk):

@field_validator("sharding_dims")
@classmethod
def _normalize_sharding_dims(cls, v: List[str]) -> List[str]:
    # remove duplicates but preserve order
    seen = set()
    deduped = []
    for d in v:
        if d not in seen:
            seen.add(d)
            deduped.append(d)
    return deduped

162-167: Clarify factory sharding behavior in llm_args.py

The use_sharding_from_factory field is already wired through transform/library/sharding.py (lines 195–197) to bypass detection when set and only applies if the factory supplies a non‐empty plan; otherwise it falls back to regular auto-detection. Tests in tests/unittest/_torch/auto_deploy/.../test_tp_sharding.py (lines 184, 309) and the other sharding tests cover both False and True cases.

– File: tensorrt_llm/_torch/auto_deploy/llm_args.py (lines 162–167)
– Replace the Field description with:

     use_sharding_from_factory: bool = Field(
         default=False,
-        description="If True, use sharding from the model factory. If False, use sharding from the "
-        "AutoDeployConfig.",
+        description=(
+            "If True, TP/EP/DP sharding is taken from the model factory (e.g., `base_model_tp_plan`), bypassing "
+            "auto-detection. If the factory returns no plan (i.e., `predefined_config` is empty or `None`), "
+            "factory sharding is ignored and auto-detection is used. If False, sharding comes from AutoDeployConfig."
+        ),
     )
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

55-60: detect_sharding plumbing is wired correctly and fully exercised by existing tests

  • The detect_sharding block in config/default.yaml (lines 55–60) is consumed by the Sharding transform in
    tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py, which reads
    simple_shard_only, use_sharding_from_factory, and sharding_dims and applies them in the TP/EP/BMM stages.
  • Unit tests in
    tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/
    cover each combination of use_sharding_from_factory and sharding_dims for "tp", "ep", and "bmm" (see
    test_tp_sharding.py, test_ep_sharding.py, test_bmm_sharding.py).

No plumbing changes are required.

Optional nitpick: add a brief inline comment above the detect_sharding section in config/default.yaml documenting

  • the allowed values for sharding_dims ("tp", "ep", "dp"; the "dp" stage is applied via BMM sharding) and
  • that setting use_sharding_from_factory: true takes precedence over the heuristic-based detection; a possible shape for such a comment is sketched below.
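
One possible shape for that comment; the values mirror the detect_sharding block this PR adds to default.yaml, and the exact surrounding keys are assumed:

detect_sharding:
  simple_shard_only: false
  # Heuristic sharding stages to run. Allowed values: 'tp', 'ep', 'dp'
  # ('dp' is applied as BMM sharding).
  sharding_dims: ['tp', 'ep', 'dp']
  # When true, a factory-provided plan (e.g. base_model_tp_plan) takes
  # precedence and heuristic detection is skipped.
  use_sharding_from_factory: false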
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 6720933 and 33d267f.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/llm_args.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
tensorrt_llm/llmapi/llm_args.py (1)
  • Field (67-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/llm_args.py (2)

168-171: Constrain and normalize sharding_dims; prevent invalid values and duplicates.

Current type List[str] allows typos, duplicates, and unexpected values. Constraining the item type and adding a validator will harden the config interface and make downstream logic simpler.

Minimal, backwards-compatible tightening in-place:

-    sharding_dims: List[str] = Field(
-        default=["tp", "ep", "dp"],
-        description="The sharding methods to apply by the heuristic sharding stage.",
-    )
+    sharding_dims: List[Literal["tp", "ep", "dp"]] = Field(
+        default=["tp", "ep", "dp"],
+        description=(
+            "Heuristic sharding dimensions to consider (subset of: 'tp', 'ep', 'dp'). "
+            "Ignored when use_sharding_from_factory=True."
+        ),
+        json_schema_extra={"allowed": ["tp", "ep", "dp"]},
+    )

Add this validator in AutoDeployConfig to deduplicate, normalize case, and fail fast on invalid entries (place alongside existing validators):

@field_validator("sharding_dims", mode="before")
@classmethod
def _validate_sharding_dims(cls, value: Any) -> List[str]:
    if value is None:
        # Preserve current semantics: default is all dims.
        return ["tp", "ep", "dp"]
    if not isinstance(value, (list, tuple)):
        raise TypeError("sharding_dims must be a list/tuple of strings")
    allowed = {"tp", "ep", "dp"}
    seen = set()
    out: List[str] = []
    for d in value:
        if not isinstance(d, str):
            raise TypeError(f"sharding_dims entries must be strings, got: {type(d).__name__}")
        d_norm = d.lower()
        if d_norm not in allowed:
            raise ValueError(f"Invalid sharding dim '{d}'. Allowed: {sorted(allowed)}")
        if d_norm not in seen:
            seen.add(d_norm)
            out.append(d_norm)
    return out

Consider adding unit tests that (a sketch follows the list):

  • Assert invalid values (e.g., ["tp", "foo"]) raise ValidationError.
  • Assert duplicates (["tp", "TP", "ep"]) deduplicate to ["tp", "ep"].
  • Assert flag interaction: when use_sharding_from_factory=True, sharding_dims is ignored by the detector.
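
A hypothetical pytest sketch along those lines; it assumes AutoDeployConfig can be instantiated with sharding_dims alone, with defaults covering any other required fields:

import pytest
from pydantic import ValidationError

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig


def test_invalid_sharding_dim_rejected():
    # "foo" is not in the allowed set and must fail validation
    with pytest.raises(ValidationError):
        AutoDeployConfig(sharding_dims=["tp", "foo"])


def test_duplicates_deduplicated_case_insensitively():
    # case-insensitive dedupe preserves first-seen order
    cfg = AutoDeployConfig(sharding_dims=["tp", "TP", "ep"])
    assert cfg.sharding_dims == ["tp", "ep"]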

162-166: Align PR description & clarify use_sharding_from_factory semantics

  • Naming consistency verified: the entire codebase (including tests and default YAML) uses use_sharding_from_factory; no occurrences of use_sharding_from_config were found. Please update the PR description to reference use_sharding_from_factory to avoid confusion.
  • Enhance the Field description in tensorrt_llm/_torch/auto_deploy/llm_args.py to make the flag’s behavior explicit. Proposed diff:
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -160,7 +160,11 @@ class AutoDeployConfig(BaseModel):
     use_sharding_from_factory: bool = Field(
         default=False,
-        description="If True, use sharding from the model factory. If False, use sharding from the AutoDeployConfig.",
+        description=(
+            "If True, bypass heuristic detection and apply sharding provided by the model "
+            "factory/config (e.g., base_model_tp_plan). "
+            "If False, run detect_sharding with sharding_dims."
+        ),
     )
  • Optional: mirror this updated description on the use_sharding_from_factory field in tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py for documentation consistency.
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 33d267f and 8ac9aba.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/auto_deploy/llm_args.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
tensorrt_llm/llmapi/llm_args.py (1)
  • Field (67-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16328 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12276 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@greg-kwasniewski1
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16377 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16377 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #12305 completed with status: 'ABORTED'

@greg-kwasniewski1
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16399 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16399 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12324 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@suyoggupta
Copy link
Collaborator

/bot reuse-pipeline

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16478 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16478 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #16399 for commit cbe821b


Labels

AutoDeploy <NV> AutoDeploy Backend
