[TRTLLM-6342][feat] TP Sharding read from the model config #6972
Conversation
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
📝 Walkthrough

Consolidates sharding detection into a single detect_sharding transform.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant IO as InferenceOptimizer
participant DS as detect_sharding
participant MF as ModelFactory
participant Det as Detectors (TP/EP/DP)
participant EX as sharding_transform_executor
participant GM as GraphModule
IO->>DS: apply(config, GM)
alt use_sharding_from_factory && predefined_config present
DS->>MF: get_sharding_config()
DS->>GM: detect_sharding_from_factory_config(GM, sharding_config)
DS-->>IO: TransformInfo (factory-derived)
else heuristics path
DS->>Det: run TP / EP / DP detectors (filtered by sharding_dims)
Det-->>DS: TransformInfos
DS-->>IO: Aggregated TransformInfo
end
IO->>EX: apply(GM, TransformInfo)
alt run_shape_prop enabled
EX->>GM: run shape propagation
end
    EX-->>IO: transformed GraphModule
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested reviewers
Actionable comments posted: 7
🔭 Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
54-56: ShardingTransformExecutor doesn't consume run_shape_prop

You've enabled `run_shape_prop: true` for the `sharding_transform_executor` in `default.yaml`, but `ShardingTransformExecutor._apply` never reads `self.config.run_shape_prop`, so no shape propagation is performed. Please update the executor to respect this flag or remove it from the config:

• File to update
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
– In `_apply(...)`, add a conditional shape-propagation step when `self.config.run_shape_prop` is true (e.g., call `canonicalize_graph(gm, shape_prop=True)` around the transformation loop); a sketch follows below.
• Config cleanup
If shape propagation isn't supported yet, remove `run_shape_prop: true` from the `sharding_transform_executor` section in `default.yaml` to avoid confusion.
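As a rough illustration of the first option, here is a minimal, self-contained sketch that uses PyTorch's stock `torch.fx.passes.shape_prop.ShapeProp` pass; `apply_with_optional_shape_prop` and `example_inputs` are illustrative names rather than existing project APIs, and the real executor would call its own helper (e.g., `canonicalize_graph`) instead:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


def apply_with_optional_shape_prop(gm, example_inputs, run_shape_prop: bool):
    # ... the sharding transformation loop would mutate `gm` here ...
    if run_shape_prop:
        # Re-annotate every node with tensor metadata so later passes can rely on shapes.
        ShapeProp(gm).propagate(*example_inputs)
    gm.recompile()
    return gm


gm = symbolic_trace(nn.Linear(8, 8))
gm = apply_with_optional_shape_prop(gm, (torch.randn(2, 8),), run_shape_prop=True)
```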
🧹 Nitpick comments (17)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
210-213: Tighten type hints to avoid ambiguity between ops and predicates

OperatorLike currently includes Callable, and target repeats Callable, which is confusing for type-checkers and readers. Consider introducing a dedicated NodePredicate alias and narrowing OperatorLike to only op types.
Proposed refinement outside the selected lines (illustrative):
```python
# near lines ~28-31
from typing import Callable

NodePredicate = Callable[[Node], bool]
OpOrOverload = Union[OpOverloadPacket, OpOverload]
OperatorLike = OpOrOverload  # remove Callable from here

# function signature (lines 210-213)
def filtered_nodes(
    nodes: Iterable[Node],
    target: Union[NodePredicate, Union[OperatorLike, Iterable[OperatorLike]]] = None,
    ops: Union[OperatorLike, Iterable[OperatorLike]] = None,
) -> Iterable[Node]:
    ...
```
241-251: Guard against silent no-op when both target and ops are None

If both target and ops are None, filtered_nodes silently yields nothing. Being explicit will save debugging time.
Apply this diff to validate inputs:
```diff
-    else:
-        # Handle the case where target or ops contains operations
-        operations = ops if ops is not None else target
-        for node in nodes:
-            if is_op(node, operations):
-                yield node
+    else:
+        # Handle the case where target or ops contains operations
+        operations = ops if ops is not None else target
+        if operations is None:
+            raise ValueError("filtered_nodes requires either a predicate (target) or ops to be set.")
+        for node in nodes:
+            if is_op(node, operations):
+                yield node
```

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3)
67-70: Explicit False is redundant

use_sharding_from_factory defaults to False in default.yaml. Keeping it explicit is fine for clarity, but it can be omitted to reduce verbosity.
- "detect_sharding": { - "stage": "sharding", - "use_sharding_from_factory": False, - }, + "detect_sharding": { + "stage": "sharding", + },
128-131: Same minor nit: explicit False can be omitted

To reduce noise in tests, you can rely on the default.
- "detect_sharding": { - "stage": "sharding", - "use_sharding_from_factory": False, - }, + "detect_sharding": { + "stage": "sharding", + },
67-70: Consider adding a test for factory-driven sharding

Since this PR introduces reading sharding plans from a model factory/config, consider adding a focused test where use_sharding_from_factory=True and a minimal FakeFactory returns a BMM plan. It will protect against regressions in the new path.
I can draft a small test fixture with a FakeFactory exposing get_sharding_config and get_sharding_config_source to exercise this path. Want me to open a follow-up?
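For reference, one hypothetical shape such a fixture could take (the class and test names below are invented; only `get_sharding_config` / `get_sharding_config_source` mirror the factory interface discussed in this review, and a real test would still need to wire the factory into the optimizer):

```python
class FakeFactory:
    """Minimal stand-in for a ModelFactory that serves a predefined sharding plan."""

    def __init__(self, tp_plan):
        self._sharding_config = {
            "tp_plan": tp_plan,
            "head_dim": 64,  # arbitrary illustrative value
            "source": "huggingface",  # stand-in for ShardingConfigSource.HUGGINGFACE
        }

    def get_sharding_config(self):
        return self._sharding_config

    def get_sharding_config_source(self):
        return self._sharding_config["source"]


def test_factory_driven_sharding_fixture():
    factory = FakeFactory(tp_plan={"linear1": "colwise", "linear2": "rowwise"})
    # The real test would pass `factory` into the optimizer together with
    # detect_sharding.use_sharding_from_factory=True and assert that the detected
    # transformations match the plan above; here we only sanity-check the fixture.
    assert factory.get_sharding_config()["tp_plan"]["linear1"] == "colwise"
```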
tensorrt_llm/_torch/auto_deploy/models/factory.py (2)
49-49: Provide a typed, validated setter for sharding config

Right now factories must mutate _sharding_config directly. Expose a small helper to set it with optional validation so different factories follow a consistent contract.
Outside the selected lines, add something like:
```python
def _set_sharding_config(self, cfg: Dict[str, Any]) -> None:
    # Optionally validate minimal keys, e.g., "tp_plan" when available
    self._sharding_config = copy.deepcopy(cfg or {})
```
120-127: Docstring clarity: this returns the sharding config source

Nit: "source of the model factory" reads oddly. Clarify it's the source of the sharding config.
```diff
-    def get_sharding_config_source(self) -> ShardingConfigSource:
-        """Return the source of the model factory.
-
-        Returns:
-            The source identifier for this model factory.
-        """
+    def get_sharding_config_source(self) -> ShardingConfigSource:
+        """Return the source of the sharding configuration.
+
+        Returns:
+            The origin identifier for the sharding config provided by this factory.
+        """
         return ShardingConfigSource.UNKNOWN
```

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)
26-29: Validate `world_size > num_experts` in the sharding implementation

The current EP-sharding utility in
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
does not guard against `world_size > num_experts`. In that case `experts_per_rank = num_experts // world_size` yields zero, leading to empty partitions or downstream errors. I recommend adding an early check at the top of `_insert_sharded_moe`:

```python
def _insert_sharded_moe(gm, node, *args):
    # args = (selected_experts, final_scales, ..., world_size, rank, num_experts, ...)
    if world_size > num_experts:
        raise ValueError(
            f"world_size ({world_size}) cannot exceed num_experts ({num_experts})"
        )
    # existing logic…
```

Location to update:
• tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
• Function: `_insert_sharded_moe`

tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
174-183: Validate model_config attributes before accessing

The method correctly handles the possibility that the HuggingFace model config lacks the expected attributes via `hasattr()` checks, but consider adding debug logging when expected attributes are missing to aid troubleshooting.

```diff
 def _set_sharding_config(self, model_config: PretrainedConfig):
     """Set the sharding config for the model."""
     self._sharding_config["head_dim"] = 1
     if hasattr(model_config, "base_model_tp_plan"):
         self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
+    else:
+        ad_logger.debug("base_model_tp_plan not found in model config")
     if hasattr(model_config, "head_dim"):
         self._sharding_config["head_dim"] = model_config.head_dim
     if hasattr(model_config, "num_hidden_layers"):
         self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
```

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (3)
26-49: Document the base_model_tp_plan structure and usage

The `base_model_tp_plan` dictionary defines sharding strategies but lacks documentation about what each value means and why some entries are commented out.

```diff
 base_model_tp_plan = {
+    # Tensor parallel sharding plan mapping module names to sharding strategies:
+    # - "colwise": Split along columns (dimension 0) without communication
+    # - "rowwise": Split along rows (dimension 1) with all-reduce
+    # - "gather": Simple shard with row split and all-gather
     "q_proj": "colwise",
     "k_proj": "colwise",
     "v_proj": "colwise",
     "o_proj": "rowwise",
     "gate_proj": "colwise",
     "up_proj": "colwise",
     "down_proj": "rowwise",
     "linear1": "colwise",
     "linear2": "rowwise",
     "linear": "gather",
-    # "input_layernorm.weight": "sequence_parallel",
-    # "post_attention_layernorm.weight": "sequence_parallel",
-    # "norm.weight": "sequence_parallel",
-    # "shared_expert.gate_proj": "local_colwise",
-    # "shared_expert.up_proj": "local_colwise",
-    # "shared_expert.down_proj": "local_rowwise",
-    # "experts.gate_up_proj": "local_packed_rowwise",
-    # "experts.down_proj": "local_colwise",
-    # "experts": "local",
+    # TODO: The following strategies require hybrid EP+TP and/or SP support:
+    # "input_layernorm.weight": "sequence_parallel",
+    # "post_attention_layernorm.weight": "sequence_parallel",
+    # "norm.weight": "sequence_parallel",
+    # "shared_expert.gate_proj": "local_colwise",
+    # "shared_expert.up_proj": "local_colwise",
+    # "shared_expert.down_proj": "local_rowwise",
+    # "experts.gate_up_proj": "local_packed_rowwise",
+    # "experts.down_proj": "local_colwise",
+    # "experts": "local",
     "feed_forward": "gather",
     "self": "gather",
     "weight": "gather",
 }
```
322-323: Consider using proper logging instead of print statements

Debug print statements should be replaced with proper logging for better control over output in production environments.

```diff
-    print(f"detected_transformations: {detected_transformations}")
-    print(f"expected_transformations: {expected_transformations}")
+    ad_logger.debug(f"detected_transformations: {detected_transformations}")
+    ad_logger.debug(f"expected_transformations: {expected_transformations}")
```

Import ad_logger at the top of the file:

```python
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
```
378-379: Remove or properly integrate the main block

The `__main__` block appears to be debug code that should either be removed or converted to a proper test case.

```diff
-if __name__ == "__main__":
-    _run_pattern_detection_job(nn.Linear, False, 0, 8, False)
+# Note: To run pattern detection manually for debugging:
+# pytest test_tp_sharding.py::test_sharding_pattern_detection -k "Linear" --verbose
```

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)
186-191: Critical comment about dimension mapping needs clarification

The comment explains that the COLUMN/ROW names reflect HuggingFace notation but the actual split is counterintuitive (COLUMN splits dimension 0). This is a critical detail that affects correctness. Consider adding an example to make this clearer.

```diff
     # NOTE: The names COLUMN/ROW reflect the hugging face
     # base_tp_plan sharding notation, but since we assume Y = W @ X^T,
     # when splitting weight matrix W^T across columns, the actual split
     # is over dimension 0
+    # Example: For weight W of shape [out_features, in_features]:
+    # - COLUMN (0): splits out_features dimension
+    # - ROW (1): splits in_features dimension
     COLUMN = 0
     ROW = 1
```
519-527: Improve error message for unsupported factory sources

The error message has a typo and could be more informative about which sources are supported.

```diff
         if self.factory_source != ShardingConfigSource.HUGGINGFACE:
             ad_logger.warning(
-                "Sharding config is is currently only " + "supported for HuggingFace. Skipping."
+                f"Sharding config from {self.factory_source.value} is currently only supported for HuggingFace sources. Skipping."
             )
```
547-564: Consider making allowed sharding values a class constant

The allowed values for tp_plan should be defined as a class constant for better maintainability and reusability.

```diff
+    ALLOWED_TP_PLAN_VALUES = {
+        "colwise",  # row split and no collective
+        "rowwise",  # column split and all-reduce
+        "gather",  # simple shard (row + all_gather)
+        # TODO: remaining values are not supported yet.
+        # They require hybrid EP+TP and/or SP support.
+        # "sequence_parallel",  # sequence parallelism
+        # "local_colwise",
+        # "local_rowwise",
+        # "local_packed_rowwise",
+        # "local",
+    }
+
     def validate_config(self) -> bool:
         # ... existing validation code ...
         values = set(tp_plan.values())
-        allowed_values = {
-            "colwise",  # row split and no collective
-            "rowwise",  # column split and all-reduce
-            "gather",  # simple shard (row + all_gather)
-            # TODO: remaining values are not supported yet.
-            # They require hybrid EP+TP and/or SP support.
-            # "sequence_parallel",  # sequence parallelism
-            # "local_colwise",
-            # "local_rowwise",
-            # "local_packed_rowwise",
-            # "local",
-        }
-        if not values.issubset(allowed_values):
+        if not values.issubset(self.ALLOWED_TP_PLAN_VALUES):
             ad_logger.warning("Sharding config contains invalid values. Skipping.")
```

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)
273-280: Improve regex pattern construction for better readability

The regex pattern construction using `@` as a placeholder is clever but could be more readable with a direct approach.

```diff
         # use regex to find if module_name matches any of the keys in sharding_config
         for key in tp_plan.keys():
-            pattern_string = "*" + key + "*"
-            # convert it to regex. Escape dots, replace * with .*
-            # First, we substitute * with an unlikely character, e.g. @
-            # Then we escape dots, and finally we replace @ with .*
-            pattern_string = pattern_string.replace("*", "@")
-            pattern_regex = re.escape(pattern_string).replace("@", ".*")
+            # Convert glob pattern to regex: * becomes .*, other chars are escaped
+            pattern_string = f"*{key}*"
+            # Replace * with placeholder before escaping
+            parts = pattern_string.split("*")
+            pattern_regex = ".*".join(re.escape(part) for part in parts)
             if re.match(pattern_regex, module_name):
```
306-311: Consolidate TODO comments for unsupported features

The warning messages for unsupported features could be more informative by including what configuration value was encountered.

```diff
     elif "sequence" in config:
         # TODO: Sequence parallelism is not supported yet.
-        ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
+        ad_logger.warning(f"Sequence parallelism config '{config}' for '{module_name}' is not supported yet. Skipping.")
     elif "local" in config:
         # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
-        ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+        ad_logger.warning(f"Local EP+TP sharding config '{config}' for '{module_name}' is not supported yet. Skipping.")
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (10)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (5 hunks)
- tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (4 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (12 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in `__init__`
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
- tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
- tensorrt_llm/_torch/auto_deploy/models/factory.py
- tensorrt_llm/_torch/auto_deploy/models/hf.py
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
- tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
- tensorrt_llm/_torch/auto_deploy/models/factory.py
- tensorrt_llm/_torch/auto_deploy/models/hf.py
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
🧬 Code Graph Analysis (7)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (2)
- tensorrt_llm/_torch/auto_deploy/transform/interface.py (1): SharedConfig (51-56)
- tensorrt_llm/mapping.py (1): local_rank (399-400)

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
- tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (1): target (502-503)

tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
- tensorrt_llm/_torch/auto_deploy/models/hf.py (1): get_sharding_config_source (203-209)

tensorrt_llm/_torch/auto_deploy/models/hf.py (3)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (4): ModelFactory (23-228), ShardingConfigSource (16-20), model (52-54), get_sharding_config_source (120-126)
- tensorrt_llm/models/modeling_utils.py (1): PretrainedConfig (366-567)
- tests/unittest/_torch/test_pytorch_model_engine.py (1): head_dim (33-34)

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (5)
- tensorrt_llm/executor/request.py (1): path (48-49)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1): SplitDimension (183-191)
- tensorrt_llm/_torch/modules/linear.py (2): split_dim (49-50), Linear (1499-1708)
- tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1): run_sharding_pattern_detection_test (228-242)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1): _run_pattern_detection_job (69-116)

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (1): ShardingConfigSource (16-20)
- tensorrt_llm/_torch/modules/linear.py (1): split_dim (49-50)
- tensorrt_llm/logger.py (1): warning (131-132)
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1): _append_simple_shard (100-121)

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (4)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (5): ModelFactory (23-228), ShardingConfigSource (16-20), register (235-240), get_sharding_config_source (120-126), get_sharding_config (108-110)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2): filtered_nodes (209-251), identify_regions_between_residuals (306-359)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6): SplitDimension (183-191), get_predefined_config (567-568), ShardingConfig (482-568), TPShardingInfo (230-265), BMMShardingInfo (268-367), EPShardingInfo (457-479)
- tensorrt_llm/_torch/auto_deploy/models/hf.py (1): get_sharding_config_source (203-209)
🔇 Additional comments (16)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (2)
49-56: Unifying sharding detection surface looks good

Consolidating the prior TP/EP/BMM detectors into detect_sharding simplifies config and reduces ordering pitfalls. The explicit use_sharding_from_factory gate also makes the factory-driven path clear.
49-52: Ensure consistent naming: use_sharding_from_factory

All occurrences in code, config, and tests use the key `use_sharding_from_factory`; there are no instances of `use_sharding_from_config`. To avoid confusion, please update the PR description and any related documentation to refer to `use_sharding_from_factory` instead of `use_sharding_from_config`.

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
210-213: Signature expansion to support predicate-based filtering: LGTM

Allowing a callable predicate in filtered_nodes is a useful and backward-compatible enhancement.
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
30-32: Ignore factory keyword concern on SharedConfig

SharedConfig subclasses Pydantic's BaseModel without a Config forbidding extras, so unknown fields like `factory` are silently ignored rather than raising. Moreover, all transforms receive `factory` as a separate argument (e.g. `transform(gm, cm, factory, shared_config)`), so no code depends on `shared_config.factory`. You can safely leave the constructor call as-is.

Likely an incorrect or invalid review comment.
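A small standalone illustration of the Pydantic behavior being relied on here (`SharedConfigLike` and its fields are a made-up stand-in, not the real `SharedConfig`):

```python
from pydantic import BaseModel


class SharedConfigLike(BaseModel):
    """Illustrative stand-in for SharedConfig; real fields are omitted."""

    local_rank: int = 0
    world_size: int = 1


# With Pydantic's default configuration (extra fields are ignored unless
# extra="forbid" is set), an unknown constructor field is dropped silently
# instead of raising a validation error.
cfg = SharedConfigLike(local_rank=0, world_size=2, factory="ignored")
print(hasattr(cfg, "factory"))  # False
```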
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
16-21: Enum addition looks good

ShardingConfigSource with HUGGINGFACE and UNKNOWN is clear and leaves room for future sources (e.g., vendor/model-zoo specific).
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3)
46-48: Configuration key consistency looks good

The renaming from `"detect_ep_shard"` to `"detect_sharding"` is consistent with the broader PR changes, and the addition of `"use_sharding_from_factory": False` correctly preserves the existing test behavior.
104-106: Duplicate configuration is appropriate

The configuration is consistent with the first instance (lines 46-48) and appropriately placed for the pattern detection test.
120-120: Test parameter update aligns with early return logic

Changing num_experts from `[3, 8]` to `[4, 8]` ensures that with `device_count` values (typically 1, 2, 4, 8), the `world_size <= num_experts` condition is satisfied, avoiding the early return case introduced at lines 27-29.

tensorrt_llm/_torch/auto_deploy/models/hf.py (2)
203-209: LGTM! Source identification is properly implemented

The method correctly returns `ShardingConfigSource.HUGGINGFACE` to identify this factory as a HuggingFace source for sharding configuration.
363-375: Image-text model configuration override is well structured

The method properly calls the parent implementation first and then specializes by checking for `text_config` attributes, which is the appropriate pattern for multi-modal models.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
256-260: Confirm SplitDimension semantics in tests

The change from `SplitDimension.ROW` (1) to `SplitDimension.COLUMN` (0) flips the actual split axis in your TP sharding tests. Please ensure that:

- `TPShardingInfo.apply` (and any downstream logic) indeed interprets `split_dim=0` as the intended column-wise split.
- All assertions in tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py lines 256-260, 277-281, and 299 have been updated to reflect splitting on dimension 0 rather than 1.

If the `apply` method (or its helpers) still expects a row-wise split at index 1, adjust either the tests or the implementation so they remain in sync.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
240-251: Validation logic correctly aligns with the new dimension semantics

The validation now correctly enforces that COLUMN split (dim 0) only supports all_gather and ROW split (dim 1) only supports all_reduce, which aligns with tensor parallelism best practices.
571-592: Function implementation is well-structured

The `_append_simple_shard` helper correctly creates TPShardingInfo entries with COLUMN split dimension and all_gather operation, which is consistent with the simple sharding strategy.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
187-193: Good architectural decision on factory-driven sharding

The implementation correctly prioritizes factory-provided sharding configuration when available, providing a clean override mechanism while preserving the autodeploy heuristics as a fallback.
114-114: Dimension change for simple sharding is correct

The switch to `SplitDimension.COLUMN` intentionally maps to tensor axis 0 (as confirmed by the enum comments, the `validate` logic in `sharding_utils.py`, and the existing multigpu tests in `test_tp_sharding.py`, which expect simple shards to use dim 0 with `all_gather`). No further changes are required.
503-510: split_dim assignment is correct and matches enum values
The loop index `i` (0 or 1) directly corresponds to `SplitDimension.COLUMN` and `SplitDimension.ROW` respectively, and Pydantic will coerce the integer into the enum. This aligns with the intended two-way shard semantics. No change is required (you may optionally wrap `i` with `SplitDimension(i)` for readability).
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
/bot run
Actionable comments posted: 0
🔭 Outside diff range comments (1)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
1-1: Missing NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 """Tests for basic graph sharding."""
```
♻️ Duplicate comments (1)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
10-11: Use proper package imports instead of local helpers at repo root

Importing from _dist_test_utils and _graph_test_helpers assumes ad-hoc PYTHONPATH configuration. Prefer importing via the tests package path.

```diff
-from _dist_test_utils import get_device_counts
-from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm
+from tests.unittest._torch.auto_deploy._utils_test._dist_test_utils import get_device_counts
+from tests.unittest._torch.auto_deploy._utils_test._graph_test_helpers import (
+    run_sharding_pattern_detection_test,
+    run_test_transformed_gm,
+)
```
🧹 Nitpick comments (4)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)
22-45: Unused base_model_tp_plan / predefined_config: config-driven path likely not exercised

These dicts are declared but never used. As-is, parametrizing from_config=True won't cause the tests to verify "apply sharding from model config" unless the pipeline is wired to consume this plan for these toy models (MLP, nn.Linear, GQA_Block), which typically aren't backed by a ModelFactory.
Action options:
- Wire them into the test when from_config is True (preferred), e.g., inject the plan into the optimizer’s shared sharding config prior to calling optimizer(None, gm), or monkeypatch the factory/config provider for these models.
- Or remove them if you rely solely on factory-backed models elsewhere for config-driven sharding tests.
If the intent is to validate the new “read sharding from config” path, please hook these into the detection stage and assert that the detected transforms exactly match the plan. I can help sketch a minimal monkeypatching approach for the test.
Also applies to: 47-50
252-257: Prefer canonical DistOp/enum over string literals and clarify SplitDimension comment
- Using string literals for dist_op ("all_reduce", "all_gather") is brittle if the detector returns enums or ops; prefer the same type the detector emits (e.g., DistOp enum or the concrete torch operator) to avoid fragile equality.
- Minor: the comment “Simple shard uses dim=0” is ambiguous without mapping to SplitDimension; clarify to avoid confusion.
If DistOp enum is available:
- dist_op = "all_reduce" + dist_op = DistOp.ALL_REDUCE- dist_op = "all_gather" + dist_op = DistOp.ALL_GATHERComment tweak for clarity:
- split_dim=SplitDimension.COLUMN, # Simple shard uses dim=0 + split_dim=SplitDimension.COLUMN, # Simple shard uses COLUMN (dim=0)Also applies to: 273-277, 295-301
318-319: Drop or gate debug prints to keep test output clean

These prints will spam CI logs. Either remove or guard behind a verbosity flag/env var.

```diff
-    print(f"detected_transformations: {detected_transformations}")
-    print(f"expected_transformations: {expected_transformations}")
+    # if os.environ.get("TP_SHARDING_DEBUG") == "1":
+    #     print(f"detected_transformations: {detected_transformations}")
+    #     print(f"expected_transformations: {expected_transformations}")
```
374-375: Remove main runner from the test module

Test files generally shouldn't be executable scripts. This block can cause confusion and accidental execution paths in some environments.
Apply:
-if __name__ == "__main__": - _run_pattern_detection_job(nn.Linear, False, 0, 8, False) +# Intentionally no __main__ block in test modules.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (11 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (7)
114-114: Threading from_config through _run_job looks good

Plumbing this flag into the job runner is the right direction.
182-186: Flag name mismatch with PR objective: ensure the pipeline actually toggles the right path

PR objective mentions ad_config.use_sharding_from_config, but the tests set detect_sharding.use_sharding_from_factory. If the implementation expects use_sharding_from_config, these tests won't exercise the intended path.
Safest fix: set both keys in the test config to cover either naming, until the codebase is fully unified.
"detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": from_config, + "use_sharding_from_factory": from_config, + "use_sharding_from_config": from_config, },Apply the same change in both optimizer configs in this file.
Also applies to: 307-311
326-327: from_config parametrization may not validate the config-driven path for toy models

For nn.Linear and MLP that aren't factory-backed, from_config=True likely won't load a model config, so the tests may either:
- silently fallback to detection (not testing the new path), or
- fail nondeterministically depending on the implementation.
Consider skipping from_config=True for models without a factory-config, or inject a sharding plan for them to truly test the config path.
335-341: LGTM on threading from_config into test_sharding

The signature and spawn wiring align with the new test matrix.
350-351: Duplicate note: from_config parametrization for pattern-detection

Same concern as above: ensure from_config=True actually consumes a config for the tested model, otherwise this branch isn't validated.
360-365: LGTM on test_sharding_pattern_detection signature updates

The new parameter is correctly threaded.
371-371: Routing from_config into the pattern-detection job looks correct

Invocation aligns with the new signature.
PR_Github #15640 [ run ] triggered by Bot

PR_Github #15640 [ run ] completed with state

/bot run

PR_Github #15769 [ run ] triggered by Bot

PR_Github #15769 [ run ] completed with state
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Actionable comments posted: 2
♻️ Duplicate comments (1)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
177-186: head_dim detection is too narrow; use head_size or derive from hidden_size/num_attention_heads

As implemented, many configs won't have head_dim; validate_config then invalidates the factory config, breaking the "use sharding from config" path. Prefer head_size when present (see PretrainedConfig) or compute it from hidden_size // num_attention_heads. Keep explicit overrides if provided.
Apply this refactor:
```diff
 def _set_sharding_config(self, model_config: PretrainedConfig):
     """Set the sharding config for the model."""
-    self._sharding_config["head_dim"] = 1
-    if hasattr(model_config, "base_model_tp_plan"):
-        self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
-    if hasattr(model_config, "head_dim"):
-        self._sharding_config["head_dim"] = model_config.head_dim
-    if hasattr(model_config, "num_hidden_layers"):
-        self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
+    # TP plan (if provided by the model config)
+    if hasattr(model_config, "base_model_tp_plan"):
+        self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
+
+    # Determine head_dim robustly: prefer explicit head_dim/head_size, else derive.
+    head_dim = None
+    if hasattr(model_config, "head_dim"):
+        head_dim = getattr(model_config, "head_dim")
+    elif hasattr(model_config, "head_size"):
+        head_dim = getattr(model_config, "head_size")
+    elif hasattr(model_config, "hidden_size") and hasattr(model_config, "num_attention_heads"):
+        try:
+            head_dim = model_config.hidden_size // model_config.num_attention_heads
+        except Exception:
+            head_dim = None
+    self._sharding_config["head_dim"] = head_dim if head_dim is not None else 1
+
+    # Propagate layer count if available
+    if hasattr(model_config, "num_hidden_layers"):
+        self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
```
🧹 Nitpick comments (9)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
1-1: Add NVIDIA copyright header (2025)

Per coding guidelines, prepend the NVIDIA copyright header to all source files.
Apply at file top:
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (8)
1-1: Add NVIDIA copyright header (2025)

Per coding guidelines, prepend the NVIDIA copyright header to all source files.
Apply at file top:
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
100-121: Fix comment to reflect SplitDimension choice (avoid row/column ambiguity)

The code uses SplitDimension.COLUMN (value 0) while the comment says "row_split (dim 0)". Given the ROW/COLUMN naming inversion wrt W vs W^T, this is confusing. Clarify the comment to match the actual enum used.
Apply this small comment update:
```diff
-    # --> row_split (dim 0 of weight) + all_gather (dim -1 of output)
+    # --> column split on weight (SplitDimension.COLUMN / dim 0) + all_gather (dim -1 of output)
```
171-186: Guard against missing 'source' key in factory config to avoid KeyError

If an external ModelFactory returns a dict without 'source', this will KeyError. Use .get with UNKNOWN fallback.
Apply this guard:
```diff
-        shared_config.sharding_config.predefined_config = (
-            factory.get_sharding_config() if factory else {}
-        )
-        shared_config.sharding_config.factory_source = (
-            shared_config.sharding_config.predefined_config["source"]
-            if factory
-            else ShardingConfigSource.UNKNOWN
-        )
+        shared_config.sharding_config.predefined_config = factory.get_sharding_config() if factory else {}
+        source = (
+            shared_config.sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN)
+            if factory
+            else ShardingConfigSource.UNKNOWN
+        )
+        shared_config.sharding_config.factory_source = source
```
219-223: Return type annotation mismatch (returns TransformInfo but annotated as None)

Static type mismatch; also update the docstring "Returns" section accordingly.
Apply:
```diff
-def detect_sharding_from_factory_config(
+def detect_sharding_from_factory_config(
     gm: GraphModule,
     sharding_config: ShardingConfig,
-) -> None:
+) -> TransformInfo:
 @@
-    Create sharding transformations from the predefined config.
+    Create sharding transformations from the predefined config and return summary info.
 @@
-    """
+    Returns:
+        TransformInfo: Aggregated info for detected and appended transforms.
+    """
```

Also applies to: 332-337
340-359: TP detector: return type annotation mismatch (returns TransformInfo)

Fix the signature to match usage in Sharding._apply.
Apply:
```diff
-def detect_column_row_shard(
+def detect_column_row_shard(
     gm: GraphModule,
     sharding_config: ShardingConfig,
-) -> None:
+) -> TransformInfo:
```

Also applies to: 515-518
521-598: DP BMM detector: return type mismatch and dead branch for remainder.
- Signature should return TransformInfo.
- Since you early-continue when remainder != 0, the rank<remainder branch is unreachable. Simplify the index computation for clarity.
Apply:
```diff
-def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:
+def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> TransformInfo:
 @@
-    # Calculate start and end indices for this rank
-    if rank < remainder:
-        start_idx = rank * (base_size + 1)
-        end_idx = start_idx + base_size + 1
-    else:
-        start_idx = remainder + rank * base_size
-        end_idx = start_idx + base_size
+    # Calculate start and end indices for this rank (remainder==0 here)
+    start_idx = rank * base_size
+    end_idx = start_idx + base_size
```
600-634: EP detector: return type mismatch (returns TransformInfo)

Align signature with returned value and usage.
Apply:
```diff
-def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:
+def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> TransformInfo:
```
265-331: Use a helper to resolve the weight module path instead of assuming `args[1]` is always a `get_attr`

Accessing `lin_node.args[1].target` works for pure aten.linear calls, but fused or quantized variants (and any call_module-based Linear) may wrap or reorder the weight argument. Introduce a small utility that walks from the second arg back to its originating `get_attr` node:

```python
def _resolve_weight_target(node: Node) -> Optional[str]:
    # start from the weight argument
    candidate = node.args[1]
    # unwind any pack/unpack or other call_function wrappers
    while isinstance(candidate, Node) and candidate.op != "get_attr":
        # assume the true weight is the first arg of the wrapper
        candidate = candidate.args[0]
    return candidate.target if isinstance(candidate, Node) and candidate.op == "get_attr" else None


# then in your loop:
module_name = _resolve_weight_target(lin_node)
if module_name is None:
    ad_logger.warning(f"Unable to resolve weight for node {lin_node.name}, skipping.")
    continue
```

– Replace all direct uses of `lin_node.args[1].target` with this helper
– This will cover pure, fused, quantized, and call_module-traced linears uniformly
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (4)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (4 hunks)
- tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/_torch/auto_deploy/models/factory.py
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Files:
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
- tensorrt_llm/_torch/auto_deploy/models/hf.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
- tensorrt_llm/_torch/auto_deploy/models/hf.py
🧬 Code Graph Analysis (2)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (2): ShardingConfigSource (16-20), register (228-233)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2): filtered_nodes (209-251), is_op (183-206)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (8): SplitDimension (183-191), validate_config (518-564), get_predefined_config (566-567), ShardingConfig (482-567), TPShardingInfo (230-265), _append_simple_shard (570-591), BMMShardingInfo (268-367), EPShardingInfo (457-479)

tensorrt_llm/_torch/auto_deploy/models/hf.py (2)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (3): ModelFactory (23-221), ShardingConfigSource (16-20), model (53-55)
- tensorrt_llm/models/modeling_utils.py (1): PretrainedConfig (366-567)
🔇 Additional comments (7)
tensorrt_llm/_torch/auto_deploy/models/hf.py (4)
32-32: Importing ShardingConfigSource here is correct

This aligns the factory with the sharding config source lifecycle.
97-99: Good: sharding source marked as HuggingFace

Setting the source early ensures downstream validation can trust the origin.
167-169: Good: initialize sharding config from model.config post-init

This enables the factory-driven path to be used by the transform when requested.
177-186: Convert Enum to a primitive in `_sharding_config` before serialization

Storing raw `ShardingConfigSource` enum members in the `_sharding_config` dict will break JSON serialization (e.g. `json.dumps`) unless downstream code knows how to handle Enums. Replace these assignments with the enum's `.value` (or `.name`) so only primitives are stored:

• In
tensorrt_llm/_torch/auto_deploy/models/factory.py
- Line ~16:
- self._sharding_config["source"] = ShardingConfigSource.UNKNOWN
- self._sharding_config["source"] = ShardingConfigSource.UNKNOWN.value
- Line 50: same change.
• In
tensorrt_llm/_torch/auto_deploy/models/hf.py
- Line 98:
- self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE
- self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE.value
This ensures `get_sharding_config()` returns only JSON-serializable primitives.

Likely an incorrect or invalid review comment.
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
28-36: Imports look correct

Good to centralize ModelFactory and node/sharding utilities here.
124-129: Config surface extension LGTM

Adding use_sharding_from_factory to the transform config wires the factory-driven path cleanly.
191-197: Factory-driven early return makes sense

Skips heuristics when a valid sharding plan is supplied and requested.
PR_Github #16327 [ run ] triggered by Bot
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
/bot run

PR_Github #16328 [ run ] triggered by Bot

PR_Github #16327 [ run ] completed with state
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
1-3: Add NVIDIA copyright header (2025) at top of file

Per repository guidelines, prepend the standard NVIDIA copyright header.
Apply this diff at the top of the file:
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/llm_args.py (2)
168-171: Constrain values and normalize duplicates for sharding_dims

Constrain to known dimension names and dedupe while preserving order to avoid subtle misconfigurations.
Apply stricter typing:
```diff
-    sharding_dims: List[str] = Field(
+    sharding_dims: List[Literal["tp", "ep", "dp"]] = Field(
         default=["tp", "ep", "dp"],
-        description="The dimensions to use for sharding. Allowed values: 'tp', 'ep', 'dp'.",
+        description="The dimensions to use for sharding. Allowed values: 'tp', 'ep', 'dp'.",
     )
```

Then add this validator within AutoDeployConfig (outside the changed hunk):
@field_validator("sharding_dims") @classmethod def _normalize_sharding_dims(cls, v: List[str]) -> List[str]: # remove duplicates but preserve order seen = set() deduped = [] for d in v: if d not in seen: seen.add(d) deduped.append(d) return deduped
162-167: Clarify factory sharding behavior in llm_args.py

The `use_sharding_from_factory` field is already wired through `transform/library/sharding.py` (lines 195–197) to bypass detection when set, and only applies if the factory supplies a non-empty plan; otherwise it falls back to regular auto-detection. Tests in `tests/unittest/_torch/auto_deploy/.../test_tp_sharding.py` (lines 184, 309) and the other sharding tests cover both `False` and `True` cases.

– File: tensorrt_llm/_torch/auto_deploy/llm_args.py (lines 162–167)
– Replace the Field description with:

```diff
     use_sharding_from_factory: bool = Field(
         default=False,
-        description="If True, use sharding from the model factory. If False, use sharding from the "
-        "AutoDeployConfig.",
+        description=(
+            "If True, TP/EP/DP sharding is taken from the model factory (e.g., `base_model_tp_plan`), bypassing "
+            "auto-detection. If the factory returns no plan (i.e., `predefined_config` is empty or `None`), "
+            "factory sharding is ignored and auto-detection is used. If False, sharding comes from AutoDeployConfig."
+        ),
     )
```

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
55-60: detect_sharding plumbing is wired correctly and fully exercised by existing tests

- The `detect_sharding` block in `config/default.yaml` (lines 55–60) is consumed by the `Sharding` transform in `tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py`, which reads `simple_shard_only`, `use_sharding_from_factory`, and `sharding_dims` and applies them in the TP/EP/BMM stages.
- Unit tests in `tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/` cover each combination of `use_sharding_from_factory` and `sharding_dims` for `"tp"`, `"ep"`, and `"bmm"` (see `test_tp_sharding.py`, `test_ep_sharding.py`, `test_bmm_sharding.py`).

No plumbing changes are required.

Optional nitpick: add a brief inline comment above the `detect_sharding` section in `config/default.yaml` documenting:
- the allowed values for `sharding_dims` (`"tp"`, `"ep"`, `"bmm"`), and
- that setting `use_sharding_from_factory: true` takes precedence over the heuristic-based detection.
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
- tensorrt_llm/_torch/auto_deploy/llm_args.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Files:
tensorrt_llm/_torch/auto_deploy/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/auto_deploy/llm_args.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
tensorrt_llm/llmapi/llm_args.py (1)
Field(67-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Actionable comments posted: 0
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/llm_args.py (2)
168-171: Constrain and normalize `sharding_dims`; prevent invalid values and duplicates

Current type `List[str]` allows typos, duplicates, and unexpected values. Constraining the item type and adding a validator will harden the config interface and make downstream logic simpler.

Minimal, backwards-compatible tightening in-place:
```diff
-    sharding_dims: List[str] = Field(
-        default=["tp", "ep", "dp"],
-        description="The sharding methods to apply by the heuristic sharding stage.",
-    )
+    sharding_dims: List[Literal["tp", "ep", "dp"]] = Field(
+        default=["tp", "ep", "dp"],
+        description=(
+            "Heuristic sharding dimensions to consider (subset of: 'tp', 'ep', 'dp'). "
+            "Ignored when use_sharding_from_factory=True."
+        ),
+        json_schema_extra={"allowed": ["tp", "ep", "dp"]},
+    )
```

Add this validator in `AutoDeployConfig` to deduplicate, normalize case, and fail fast on invalid entries (place alongside existing validators):

```python
@field_validator("sharding_dims", mode="before")
@classmethod
def _validate_sharding_dims(cls, value: Any) -> List[str]:
    if value is None:
        # Preserve current semantics: default is all dims.
        return ["tp", "ep", "dp"]
    if not isinstance(value, (list, tuple)):
        raise TypeError("sharding_dims must be a list/tuple of strings")
    allowed = {"tp", "ep", "dp"}
    seen = set()
    out: List[str] = []
    for d in value:
        if not isinstance(d, str):
            raise TypeError(f"sharding_dims entries must be strings, got: {type(d).__name__}")
        d_norm = d.lower()
        if d_norm not in allowed:
            raise ValueError(f"Invalid sharding dim '{d}'. Allowed: {sorted(allowed)}")
        if d_norm not in seen:
            seen.add(d_norm)
            out.append(d_norm)
    return out
```

Consider adding unit tests that:
- Assert invalid values (e.g., ["tp", "foo"]) raise ValidationError.
- Assert duplicates (["tp", "TP", "ep"]) deduplicate to ["tp", "ep"].
- Assert flag interaction: when `use_sharding_from_factory=True`, `sharding_dims` is ignored by the detector.
162-166: Align PR description & clarify `use_sharding_from_factory` semantics

- Naming consistency verified: the entire codebase (including tests and default YAML) uses `use_sharding_from_factory`; no occurrences of `use_sharding_from_config` were found. Please update the PR description to reference `use_sharding_from_factory` to avoid confusion.
- Enhance the `Field` description in `tensorrt_llm/_torch/auto_deploy/llm_args.py` to make the flag's behavior explicit. Proposed diff:

```diff
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -160,7 +160,11 @@ class AutoDeployConfig(BaseModel):
     use_sharding_from_factory: bool = Field(
         default=False,
-        description="If True, use sharding from the model factory. If False, use sharding from the AutoDeployConfig.",
+        description=(
+            "If True, bypass heuristic detection and apply sharding provided by the model "
+            "factory/config (e.g., base_model_tp_plan). "
+            "If False, run detect_sharding with sharding_dims."
+        ),
     )
```

- Optional: mirror this updated description on the `use_sharding_from_factory` field in `tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py` for documentation consistency.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
- tensorrt_llm/_torch/auto_deploy/llm_args.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Files:
tensorrt_llm/_torch/auto_deploy/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/auto_deploy/llm_args.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
tensorrt_llm/llmapi/llm_args.py (1)
Field(67-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
PR_Github #16328 [ run ] completed with state

/bot run

PR_Github #16377 [ run ] triggered by Bot

PR_Github #16377 [ run ] completed with state

/bot run

PR_Github #16399 [ run ] triggered by Bot

PR_Github #16399 [ run ] completed with state

/bot reuse-pipeline

PR_Github #16478 [ reuse-pipeline ] triggered by Bot

PR_Github #16478 [ reuse-pipeline ] completed with state
…VIDIA#6972)" This reverts commit 2101d46. Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Description
If `base_model_tp_plan` is present in the model config and `ad_config.use_sharding_from_config == True`, skip sharding pattern detection and instead apply the sharding from the config (an illustrative config sketch follows below).
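For illustration, a minimal transform config enabling this path might look like the sketch below; the key names mirror the updated unit tests and `default.yaml`, where (as noted in the review comments) the flag is spelled `use_sharding_from_factory`:

```python
# Sketch only: mirrors the dict-style transform configs used in the unit tests.
# The flag name below follows the implementation (`use_sharding_from_factory`);
# the surrounding optimizer wiring is omitted.
optimizer_config = {
    "detect_sharding": {
        "stage": "sharding",
        # Apply the plan read from the model config / factory instead of heuristics.
        "use_sharding_from_factory": True,
    },
    "sharding_transform_executor": {
        "stage": "sharding",
        "run_shape_prop": True,
    },
}
```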
Test Coverage

`tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py` has been updated to test the new sharding logic.

Summary by CodeRabbit
New Features
Refactor
Behavior
Tests
Other