[#4403][refactor] Move fusion, kvcache, and compile to modular inference optimizer #7057
Conversation
📝 Walkthrough

Reworks the AutoDeploy transform pipeline: moves legacy function-style transforms to TransformRegistry-registered classes, adds MOE pattern matching, KV-cache and compile transforms, introduces an attention-backend setting, removes legacy transformation modules, and updates tests to use InferenceOptimizer-driven GraphModule transformations.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant IO as InferenceOptimizer
    participant Factory as ModelFactory
    participant Export as Export/FX
    participant PM as PatternMatcher
    participant WL as LoadWeights
    participant PF as PostLoadFusion
    participant CI as CacheInit
    participant CM as CompileModel
    User->>IO: invoke optimizer (config, model/gm)
    IO->>Factory: build_model(device)
    IO->>Export: export_to_gm(model, args, dynamic_shapes)
    IO->>PM: run pattern-matcher transforms (incl. match_moe_pattern)
    IO->>WL: apply load_weights (weight_load)
    IO->>PF: apply post_load_fusion transforms (fuse_collectives, fuse_allreduce_residual_rmsnorm, fuse_rmsnorm)
    IO->>CI: apply cache_init transforms (update_in_out_nodes, insert_cached_attention / mla, initialize_cache, resize_kv_cache)
    IO->>CM: optionally run compile_model (compile stage)
    IO-->>User: return transformed GraphModule (compiled if compile_model ran)
```
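The walkthrough above centers on moving legacy function-style transforms into TransformRegistry-registered classes driven by the InferenceOptimizer. Purely as orientation, here is a minimal sketch of that pattern; the class, config field, registry key, and the exact `_apply` return convention are illustrative assumptions inferred from the BaseTransform/TransformRegistry/TransformInfo interface referenced later in this review, not the actual implementation.

```python
# Hypothetical sketch only: shows the class-based transform pattern this PR adopts.
# The real contract lives in tensorrt_llm/_torch/auto_deploy/transform/interface.py;
# field names, defaults, and the _apply return value here are assumptions.
from pydantic import Field
from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.transform.interface import (
    BaseTransform,
    TransformConfig,
    TransformInfo,
    TransformRegistry,
)


class FuseMyOpConfig(TransformConfig):
    """Per-transform options, settable from the YAML pipeline config."""

    backend: str = Field(default="triton", description="Illustrative backend knob.")


@TransformRegistry.register("fuse_my_op")  # key referenced from default.yaml stages
class FuseMyOp(BaseTransform):
    """Illustrative transform that rewrites matching nodes in the exported graph."""

    @classmethod
    def get_config_class(cls):
        return FuseMyOpConfig

    def _apply(self, gm: GraphModule, cm, factory, shared_config):
        num_matches = 0
        for node in list(gm.graph.nodes):  # snapshot, since rewrites may erase nodes
            pass  # match and rewrite nodes here
        # Report back so the optimizer can decide on cleanup/recompile
        # (field names assumed; see TransformInfo in interface.py).
        info = TransformInfo(skipped=False, num_matches=num_matches, is_clean=False)
        return gm, info
```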
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Force-pushed f7d24e8 to e912d47

/bot run

Force-pushed e912d47 to 9b1e879

/bot run

PR_Github #15813 [ run ] triggered by Bot
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)
1-1: Missing NVIDIA copyright header (compliance). Per repo guidelines, prepend the current-year NVIDIA copyright header to all source files.
Apply a header consistent with the repo’s standard. Example (adjust to match the project’s exact template):
+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (1)
1-1: Add NVIDIA copyright header (2025). All Python sources must include the NVIDIA copyright header. Please prepend it.
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
🧹 Nitpick comments (19)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (3)
268-276: Avoid overshadowing the outer `gm` variable in the lambda (readability). The lambda parameter named `gm` shadows the outer `gm` variable. Rename the parameter for clarity.

- lambda gm: sum(is_linear_op(n, include_quantization=True) for n in gm.graph.nodes)
+ lambda g: sum(is_linear_op(n, include_quantization=True) for n in g.graph.nodes)
10-10: Use explicit package import for test helper. The `run_test_transformed_gm` and `count_buffers` functions are defined in tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py, but there are no sys.path adjustments in the test tree to allow a bare `_graph_test_helpers` import. To prevent import errors when PYTHONPATH isn't set, update the import to the full module path:

• File: tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
• Change:

- from _graph_test_helpers import count_buffers, run_test_transformed_gm
+ from tests.unittest._torch.auto_deploy._utils_test._graph_test_helpers import \
+     count_buffers, run_test_transformed_gm

If you'd rather keep the top-level import, add a conftest.py or an explicit sys.path.insert(...) in your test setup to include _utils_test on sys.path.
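If the conftest route is preferred, a minimal sketch could look like the following; the conftest location and the relative path to `_utils_test` are assumptions based on the test-tree layout quoted above.

```python
# tests/unittest/_torch/auto_deploy/conftest.py  (hypothetical location)
# Make the shared helpers in _utils_test importable as top-level modules,
# e.g. `from _graph_test_helpers import run_test_transformed_gm`.
import sys
from pathlib import Path

_UTILS_TEST_DIR = Path(__file__).parent / "_utils_test"
if str(_UTILS_TEST_DIR) not in sys.path:
    sys.path.insert(0, str(_UTILS_TEST_DIR))
```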
258-267: Extract optimizer config and instantiation for clarity. Verified that `InferenceOptimizer.__call__(self, cm: CachedSequenceInterface, gm: Optional[GraphModule] = None)` supports `optimizer(None, gm)`. To improve readability and make future extensions easier, pull the inline dict into a named `opt_config` and instantiate the optimizer in its own statement:

- gm = torch_export_to_gm(model, args=(x,), clone=True)
- gm_transformed = InferenceOptimizer(
-     None,
-     {
-         "fuse_gemms": {
-             "stage": "post_load_fusion",
-         },
-     },
- )(None, gm)
+ gm = torch_export_to_gm(model, args=(x,), clone=True)
+ opt_config = {
+     "fuse_gemms": {"stage": "post_load_fusion"},
+ }
+ optimizer = InferenceOptimizer(None, opt_config)
+ gm_transformed = optimizer(None, gm)

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (4)
16-33: Make DummyFactory.build_model deterministic and minimal.
- Set the model to eval() to avoid train-time behaviors (e.g., dropout) in tests.
- Use `pass` for no-op methods to be explicit.

  class DummyFactory(ModelFactory):
@@
- def build_model(self, device: str):
-     return self._model.to(device=device)
+ def build_model(self, device: str):
+     # Ensure deterministic inference behavior in tests.
+     return self._model.to(device=device).eval()
@@
- def _build_model(self, device: str):
-     return
+ def _build_model(self, device: str):
+     pass
@@
- def _load_checkpoint(self, model, device):
-     return
+ def _load_checkpoint(self, model, device):
+     pass
40-56: Assert shape assumptions for safer head-dim math. Add guards to catch invalid configurations early (hidden_size divisible by num_heads; grouped KV valid).

  def __init__(
      self,
      *args,
      **kwargs,
  ):
      super().__init__(*args, **kwargs)
      # Store the head dimensions explicitly
      self.num_heads = args[0]  # First argument is num_attention_heads
      self.num_kv_heads = args[2]  # Third argument is num_key_value_heads
-     self.head_dim = args[1] // self.num_heads  # hidden_size / num_heads
+     assert args[1] % self.num_heads == 0, "hidden_size must be divisible by num_heads"
+     self.head_dim = args[1] // self.num_heads  # hidden_size / num_heads
@@
-     if self.num_heads != self.num_kv_heads:
+     if self.num_heads != self.num_kv_heads:
+         assert self.num_heads % self.num_kv_heads == 0, "num_heads must be a multiple of num_kv_heads"
          self.num_key_value_groups = self.num_heads // self.num_kv_heads
      else:
          self.num_key_value_groups = None
108-112: Guard test on custom SDPA op availability. Add a skip if `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` is not present to avoid opaque failures on environments lacking the custom op.

  @torch.inference_mode()
  def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
      """Test the SDPA transformation with KV cache."""
      # flashinfer doesn't support float32 data type
      if attn_backend == "flashinfer" and dtype == torch.float32:
          pytest.skip("flashinfer doesn't support float32 data type")
+     if not hasattr(torch.ops, "auto_deploy") or \
+             not hasattr(torch.ops.auto_deploy, "torch_attention_bsnd_grouped_sdpa"):
+         pytest.skip("auto_deploy.torch_attention_bsnd_grouped_sdpa is not available in this build")
148-176: Pass dtype into CacheConfig to align cache tensor dtype with model math. This ensures KV cache tensors are allocated with the same dtype as QKV, avoiding implicit casts or kernel mismatches across backends.

- optimizer = InferenceOptimizer(
-     DummyFactory(model, CacheConfig()),
+ optimizer = InferenceOptimizer(
+     DummyFactory(model, CacheConfig(dtype=dtype)),
      {
          "build_model": {
              "stage": "factory",
              "device": "cuda",
              "run_graph_cleanup": False,
              "requires_clean_graph": False,
          },
          "export_to_gm": {
              "stage": "export",
              "strict": False,
              "clone_state_dict": True,
              "run_graph_cleanup": False,
              "requires_clean_graph": False,
          },
          "cleanup_input_constraints": {
              "stage": "post_export",
          },
          "update_in_out_nodes": {
              "stage": "cache_init",
          },
          "insert_cached_attention": {
              "stage": "cache_init",
              "attn_backend": attn_backend,
          },
      },
  )  # type: ignore

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py (2)
1-2: Missing NVIDIA copyright header. All Python sources require the current-year NVIDIA header.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+#
 """Tests for basic fusion of the collective."""
65-73: Potential API mismatch and brittle construction; prefer explicit config object or helper. Constructing `InferenceOptimizer` with a raw dict of transforms and invoking with `(None, gm)` tightly couples the test to internal optimizer signatures. If available, use an explicit AD config class or a helper method that takes a GM and returns a transformed GM to reduce churn when internals change.

If a helper like `apply_post_load_fusion(gm, {"fuse_collectives": {...}})` exists, prefer it. Otherwise, consider a small wrapper in test utils to abstract optimizer invocation.
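One possible shape for such a wrapper, sketched under the assumption that `InferenceOptimizer(None, transforms)(None, gm)` stays the underlying call; the helper name and its home in the shared test utilities are hypothetical.

```python
# Hypothetical helper for the shared test utilities (e.g. _graph_test_helpers.py).
# It hides the optimizer construction/invocation details so tests only declare
# which transforms to run on an already-exported GraphModule.
from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer


def apply_transform_config(gm: GraphModule, transforms: dict) -> GraphModule:
    """Run the modular optimizer on `gm` with the given transform config dict."""
    return InferenceOptimizer(None, transforms)(None, gm)
```

A test would then call, for example, `apply_transform_config(gm, {"fuse_collectives": {"stage": "post_load_fusion"}})` and carry no direct dependency on the optimizer's constructor signature.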
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)

9-11: Missing NVIDIA copyright header. Comply with repository guidelines for source headers.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+#
 import operator
 from collections import defaultdict
 from typing import List, Tuple
121-154: Graph mutation is fine; consider minor robustness improvements.
- After in-graph transformations, BaseTransform cleanup should recompile. If any transform disables cleanup, ensure to call `gm.recompile()` explicitly.
- Node iteration safety: you correctly DCE and prune submodules after replacements inside `_insert_fused_gemm`. Keep relying on post-cleanup to sanitize.

Optionally, consider returning `info.is_clean=True` if you know the transform preserves cleanliness and you don't depend on shape-prop; this can skip extra canonicalization work in subsequent stages.
87-109: Quantized fusion path: preserve scale/buffer semantics consistently. The implementation ensures fused scales are registered and referenced in the fused node's kwargs. LGTM. Consider adding a small comment that unfused linear ops are fully erasable because fused scale get_attr nodes are created explicitly to avoid accidental shared references.
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (3)
1-4: Missing NVIDIA copyright header. All Python sources need the current-year NVIDIA header.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+#
 """Graph transform to optimize RMSNorm execution using FlashInfer."""
72-76: Docstring/backend list is slightly inconsistent with implementation. Docstring mentions only "flashinfer" or "triton", but `_BACKEND_OPS` also exposes "torch". Either remove "torch" backend support or update the docstring to reflect it.

Apply:

- description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
+ description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').",
133-146: Pattern strictness and shape/dtype handling. The pattern registers for bf16/fp16/fp32 which is good. To further reduce brittleness:

- Consider ignoring the exact `mean` dims literal via `op_ignore_types` so `[-1]` vs `(-1,)` variants don't break matching.
- Keep `scalar_workaround` for eps.

If needed, update `register_ad_pattern(..., op_ignore_types={torch.ops.aten.mean: (int, list, tuple)})` to ignore literal dims in the pattern graph.

tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (4)
1-12: Missing NVIDIA copyright header. Add the standard header.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+#
 import operator
 from typing import Tuple
45-53: Avoid assertions for graph shape; treat as guards to skip non-matching nodes. Transform-time assertions can crash the pipeline on benign graphs. Replace with guard-and-continue to keep the pass resilient.

Apply:

-    # check if args are as expected
-    assert len(node.args) == 1 and not len(node.kwargs), (
-        "Unexpected args/kwargs for all_reduce"
-    )
+    # check if args are as expected
+    if not (len(node.args) == 1 and not node.kwargs):
+        # Unexpected signature; skip
+        continue
45-71: Iterate over a snapshot when mutating the graph. You erase nodes during iteration. Snapshot the node list to avoid iterator invalidation in future FX versions.

Apply:

- for node in gm.graph.nodes:
+ for node in list(gm.graph.nodes):
      if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
          continue
110-194: Minor consistency and cleanup improvements in RMSNorm fusion path.
- Mixed use of `graph` and `gm.graph` is confusing; use the local `graph` consistently.
- After rewiring, consider running DCE explicitly (BaseTransform cleanup will also handle it, but being explicit can help during debugging).

Apply:

- final_output_node = gm.graph.create_node(
+ final_output_node = graph.create_node(
      "call_function",
      target=operator.getitem,
      args=(fused_node, 0),
  )
- add_output_node = gm.graph.create_node(
+ add_output_node = graph.create_node(
      "call_function",
      target=operator.getitem,
      args=(fused_node, 1),
  )

Optionally, after the node rewrites:

graph.eliminate_dead_code()
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (18)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (3 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/interface.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (3 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (4 hunks)
💤 Files with no reviewable changes (5)
- tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (6)
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
- tensorrt_llm/_torch/auto_deploy/transform/interface.py
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py
- tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
- tensorrt_llm/_torch/auto_deploy/transformations/transform.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.pytensorrt_llm/_torch/auto_deploy/transform/library/fusion.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.pytensorrt_llm/_torch/auto_deploy/transform/library/collectives.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.pytests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.pytensorrt_llm/_torch/auto_deploy/transform/library/fusion.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.pytensorrt_llm/_torch/auto_deploy/transform/library/collectives.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.pytests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py
🧠 Learnings (1)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py
🧬 Code Graph Analysis (6)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (3)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2)
ADPatternMatcherPass(59-65)register_ad_pattern(97-180)tensorrt_llm/_torch/auto_deploy/transform/interface.py (9)
BaseTransform(139-376)SharedConfig(51-57)TransformConfig(60-99)TransformInfo(108-133)TransformRegistry(379-407)register(385-392)get_config_class(161-166)get_config_class(400-402)_apply(366-376)tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
torch_rmsnorm(64-76)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
ModelFactory(15-207)tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
CachedSequenceInterface(12-70)tensorrt_llm/_torch/auto_deploy/transform/interface.py (6)
BaseTransform(139-376)SharedConfig(51-57)TransformInfo(108-133)TransformRegistry(379-407)register(385-392)_apply(366-376)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
run_test_transformed_gm(68-138)tensorrt_llm/_torch/auto_deploy/export/export.py (1)
torch_export_to_gm(198-284)tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
InferenceOptimizer(15-104)
tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (5)
tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py (2)
is_trtllm_op_available(58-60)fused_allreduce_residual_rmsnorm(27-38)tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
ModelFactory(15-207)tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
CachedSequenceInterface(12-70)tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
get_op_overload_packet(173-180)get_user_if_pattern_match(280-289)is_op(183-206)tensorrt_llm/_torch/auto_deploy/transform/interface.py (5)
BaseTransform(139-376)SharedConfig(51-57)TransformInfo(108-133)TransformRegistry(379-407)register(385-392)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (5)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (5)
SequenceEmbeddingInfo(41-53)build_model(25-26)_build_model(28-29)_load_checkpoint(31-32)get_cache_config(34-35)tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (2)
CacheConfig(25-28)to(336-340)tensorrt_llm/_torch/auto_deploy/models/factory.py (6)
ModelFactory(15-207)model(43-45)build_model(52-88)_build_model(91-93)_load_checkpoint(199-207)get_cache_config(99-105)tensorrt_llm/_torch/auto_deploy/shim/interface.py (2)
CachedSequenceInterface(12-70)to(33-37)tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
InferenceOptimizer(15-104)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
run_test_transformed_gm(68-138)tensorrt_llm/_torch/auto_deploy/export/export.py (1)
torch_export_to_gm(198-284)tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
InferenceOptimizer(15-104)
🔇 Additional comments (13)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)
14-15: No import changes required – exports are correctly configured.
`torch_export_to_gm` is publicly re-exported in tensorrt_llm/_torch/auto_deploy/export/__init__.py via `from .export import *`. `InferenceOptimizer` is imported directly from its defining module tensorrt_llm/_torch/auto_deploy/transform/optimizer.py. All imports resolve successfully; no adjustments needed.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (5)
5-5: LGTM: imports SequenceEmbeddingInfo from test helpers.
9-9: LGTM: CacheConfig import from attention_interface.
94-96: LGTM: backend parametrization matches the new config surface.
11-11: InferenceOptimizer import path is correct. The repository defines two InferenceOptimizer classes:
- tensorrt_llm/_torch/auto_deploy/transform/optimizer.py accepts a plain dict (InferenceOptimizerConfig)
- tensorrt_llm/_torch/auto_deploy/transformations/transform.py expects an AutoDeployConfig
Since test_kv_cache.py passes a config dict, importing from transform.optimizer is intentional and will not cause an import error. No change required.
148-176: Per-transform and global `attn_backend` are already propagated correctly.

- In transform.py (lines 58–61), `InferenceOptimizer` sets each transform's `.attn_backend` from the top-level config:

  if "insert_cached_attention" in self.ad_config.transforms:
      self.ad_config.transforms["insert_cached_attention"].attn_backend = self.ad_config.attn_backend

- Immediately after creating the modular optimizer, it also assigns the shared config's backend (lines 83–84):

  new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend

- The `InsertCachedAttention` transform then picks up its own `attn_backend` (falling back on `shared_config.attn_backend` when its local config is `None`).

No further changes are needed here.
tensorrt_llm/_torch/auto_deploy/config/default.yaml (4)
71-88: Post-load fusion set looks sane; ensure backend option aligns with RMSNorm transform config.
`fuse_rmsnorm.backend: flashinfer` matches the new transform's config. No issues on content; just ensure the string values are validated centrally (e.g., raise on invalid backend).

89-97: insert_cached_mla_attention attn_backend is correct. The `attn_backend: MultiHeadLatentAttention` in default.yaml matches the registry key registered in custom_ops/mla.py and corresponds to `ad_config.mla_backend` in your transforms. No override by the optimizer occurs beyond this explicit assignment, so no changes are necessary.

22-41: All pattern_matcher transforms are registered. I ran a registry check against tensorrt_llm/_torch/auto_deploy/config/default.yaml and confirmed that every `match_*` key under the `pattern_matcher` stage is present in `TransformRegistry.register(...)`. No missing transforms were found.

105-106: Compile stage verified. The `compile_model` transform is registered and returns an `nn.Module` (including `CapturedGraph` for CUDA-graph backends). Downstream code and existing tests already handle `CapturedGraph`, so no further changes are required.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py (2)
76-83: Good validation of fused op presence and original op absence. The `check_transformed_graph` predicate correctly asserts the intended fusion.

16-18: InferenceOptimizer usage is correct – no changes needed. Verified that:

- `InferenceOptimizer.__call__` signature is `__call__(self, cm, gm=None)`.
- Calling `InferenceOptimizer(...)(None, gm)` passes `cm=None` and `gm`, which matches the signature.
- All existing tests invoke it with `None` for `cm`, and transforms handle a `None` cm as intended.

tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (1)

96-107: Graceful skip when TRT-LLM op not available. Nice: returns a clean/valid info when TRT-LLM fused op is unavailable. LGTM.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (1)
1-1: Add NVIDIA copyright header (required by repo guidelines). Please prepend the standard header.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)

1-1: Add NVIDIA copyright header (required by repo guidelines). Please prepend the standard header.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)

1-1: Add NVIDIA copyright header (required by repo guidelines). Please prepend the standard header.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (1)
1-1: Missing NVIDIA copyright header (required by repo guidelines). Tests are also subject to the header requirement. CI/license checks may fail without it.
Apply this diff to prepend the header:
+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🧹 Nitpick comments (7)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (3)
48-66: Defensive checks around transform configs. These assignments assume `self.ad_config.transforms[...]` entries are objects with the listed attributes. If a dict slips through (e.g., from YAML), this will raise. Consider defensive guards or conversion.

- if "load_weights" in self.ad_config.transforms:
-     self.ad_config.transforms[
-         "load_weights"
-     ].checkpoint_device = self.ad_config.checkpoint_device
-     self.ad_config.transforms["load_weights"].device = cm.device
+ if "load_weights" in self.ad_config.transforms:
+     lw = self.ad_config.transforms["load_weights"]
+     # Allow both object- and dict-style configs
+     try:
+         lw.checkpoint_device = self.ad_config.checkpoint_device
+         lw.device = cm.device
+     except AttributeError:
+         lw["checkpoint_device"] = self.ad_config.checkpoint_device
+         lw["device"] = getattr(cm, "device", None)
52-52: Guard against None cm when accessing device. If `cm` is ever None (e.g., misused externally), `cm.device` will raise. A fast-fail assertion improves diagnostics.

  def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
+     assert cm is not None, "CachedSequenceInterface (cm) must be provided"
82-87: Naming clarity: avoid duplicate "InferenceOptimizer" symbols across packages. This class delegates to `transform.optimizer.InferenceOptimizer` under the alias `ModularInferenceOptimizer`. The duplication of the public name "InferenceOptimizer" in two modules is easy to confuse. Consider renaming this facade (or the alias) to "HighLevelInferenceOptimizer" or "InferenceOrchestrator" to reduce ambiguity.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (4)
16-34: DummyFactory is fine for tests; consider light tightening (types). Works as a minimal stub. For readability and static checkers, consider adding type hints to the constructor.
Suggested tweak:
-class DummyFactory(ModelFactory): +class DummyFactory(ModelFactory): """Dummy factory to pass cache_config for testing.""" - def __init__(self, model, cache_config): + def __init__(self, model: torch.nn.Module, cache_config: CacheConfig): self._model = model self.cache_config = cache_config
36-85: Prefer reshape over view; remove unused attribute.
- Linear outputs are typically contiguous, but reshape is safer than view when contiguity assumptions change.
- num_key_value_groups is computed but never used.
Apply:
- if self.num_heads != self.num_kv_heads: - self.num_key_value_groups = self.num_heads // self.num_kv_heads - else: - self.num_key_value_groups = None + # Number of key-value groups (unused here; omit unless needed) # Reshape to [b, s, n, h_d] - q = q.view(b, s, self.num_heads, self.head_dim) - k = k.view(b, s, self.num_kv_heads, self.head_dim) - v = v.view(b, s, self.num_kv_heads, self.head_dim) + q = q.reshape(b, s, self.num_heads, self.head_dim) + k = k.reshape(b, s, self.num_kv_heads, self.head_dim) + v = v.reshape(b, s, self.num_kv_heads, self.head_dim)
149-176: Leverage InitializeCache stage to avoid manual cache init in test. Since kvcache.py provides InitializeCache, you can include it in the pipeline and drop the manual cm.initialize_caches() call, keeping tests aligned with the modular optimizer flow.
Apply this diff to extend the transform config:
"insert_cached_attention": { "stage": "cache_init", "attn_backend": attn_backend, }, + "initialize_cache": { + "stage": "cache_init", + },Then remove the explicit initialization (see a companion suggestion near Line 180).
179-181: Avoid redundant .to('cuda') and prefer pipeline-driven cache init.
- gm is already created on CUDA via the "build_model" stage’s device="cuda".
- If you adopt the InitializeCache stage, you can eliminate manual cm.initialize_caches().
Apply:
- gm.to("cuda") - cm.initialize_caches() + # Cache initialization handled by InitializeCache stage
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (18)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (3 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/interface.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (3 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (4 hunks)
💤 Files with no reviewable changes (5)
- tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py
🚧 Files skipped from review as they are similar to previous changes (7)
- tensorrt_llm/_torch/auto_deploy/transform/interface.py
- tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py
- tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py
- tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.pytensorrt_llm/_torch/auto_deploy/transformations/transform.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.pytensorrt_llm/_torch/auto_deploy/transformations/transform.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.pytests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py
🧠 Learnings (1)
📚 Learning: 2025-08-14T23:23:27.420Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.420Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Applied to files:
tensorrt_llm/_torch/auto_deploy/config/default.yaml
🧬 Code Graph Analysis (4)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (2)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
InferenceOptimizer(15-104)tensorrt_llm/_torch/auto_deploy/export/export.py (1)
torch_export_to_gm(198-284)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
run_test_transformed_gm(68-138)tensorrt_llm/_torch/auto_deploy/export/export.py (1)
torch_export_to_gm(198-284)tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
InferenceOptimizer(15-104)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (2)
count_buffers(62-65)run_test_transformed_gm(68-138)tensorrt_llm/_torch/auto_deploy/export/export.py (1)
torch_export_to_gm(198-284)tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
InferenceOptimizer(15-104)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (5)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (5)
SequenceEmbeddingInfo(41-53)build_model(25-26)_build_model(28-29)_load_checkpoint(31-32)get_cache_config(34-35)tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (2)
CacheConfig(25-28)to(336-340)tensorrt_llm/_torch/auto_deploy/models/factory.py (6)
ModelFactory(15-207)model(43-45)build_model(52-88)_build_model(91-93)_load_checkpoint(199-207)get_cache_config(99-105)tensorrt_llm/_torch/auto_deploy/shim/interface.py (2)
CachedSequenceInterface(12-70)to(33-37)tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
InferenceOptimizer(15-104)
🔇 Additional comments (20)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (3)
7-8: LGTM: Switch to GM export and modular optimizer import. Using `torch_export_to_gm` and the modular `InferenceOptimizer` aligns this test with the new GM-based fusion flow.
67-74: LGTM: Using shared helper to validate transformed GM behavior. Passing `dynamic_shapes` and the parameter-count identity function matches the helper's expected contract.
54-66: InferenceOptimizer API and “fuse_rmsnorm” registration validated
- `InferenceOptimizer.__init__` accepts `(factory: ModelFactory, config: InferenceOptimizerConfig)`
- `InferenceOptimizer.__call__` signature is `(cm: CachedSequenceInterface, gm: Optional[GraphModule] = None) → GraphModule`
- `@TransformRegistry.register("fuse_rmsnorm")` is present in tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py

No changes required.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (3)
80-83: LGTM: Validation inspects fused op presence on the transformed graph. Iterating `gm_transformed.graph.nodes` and checking `torch.ops.dist.fused_allreduce_residual_rmsnorm` is the correct assertion.
66-73: Confirmed InferenceOptimizer signature and transform registration
- InferenceOptimizer in tensorrt_llm/_torch/auto_deploy/transform/optimizer.py defines

  def __call__(self, cm: CachedSequenceInterface, gm: Optional[GraphModule] = None) -> GraphModule

  so calling `InferenceOptimizer(None, {...})(None, gm)` is valid.
- The `fuse_allreduce_residual_rmsnorm` transform is registered via `@TransformRegistry.register("fuse_allreduce_residual_rmsnorm")` in tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py.

No changes needed.
94-96: GraphModule inputs are supported by both export paths. Both `torch.export.export` and our `torch_export_to_gm` helper accept any subclass of nn.Module—including fx.GraphModule. In particular:

- `torch_export_to_gm(model: nn.Module, …)` has no strict type checks beyond `isinstance(egm, fx.GraphModule)` on the output.
- The test itself imports and successfully calls `export(gm_transformed, args=args)`, which demonstrates that `torch.export.export` handles a GraphModule just like any other nn.Module.

No changes are needed here.
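A minimal standalone illustration of that point, independent of the TensorRT-LLM helpers (the toy module and shapes are made up for the example):

```python
# An fx.GraphModule is an nn.Module, so torch.export.export accepts it
# just like the original eager module.
import torch
import torch.fx
from torch.export import export


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


gm = torch.fx.symbolic_trace(TinyModel())              # fx.GraphModule, subclass of nn.Module
exported = export(gm, args=(torch.randn(2, 8),))       # exports like any other module
print(type(gm).__name__, type(exported).__name__)      # GraphModule ExportedProgram
```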
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (3)
268-278: LGTM: Correctly validates GEMM count post-fusion and output consistency.
`is_linear_op(..., include_quantization=True)` with the `num_gemms_after_fusion` predicate plus relaxed tolerances for fp8 looks good.
258-267: Confirmed InferenceOptimizer signature and fuse_gemms registration
The `InferenceOptimizer` class implements `__call__(self, cm, gm)` and a `fuse_gemms` transform is registered via `@TransformRegistry.register("fuse_gemms")`. The test's use of `InferenceOptimizer(None, {...})(None, gm)` is valid.
10-10: Ensure helper module is importable in the test environment
The file `_graph_test_helpers.py` lives under:
• tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
but in
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (line 10)
you're doing:

from _graph_test_helpers import count_buffers, run_test_transformed_gm

This will fail unless the `_utils_test` directory is on `sys.path` (e.g., via a conftest.py or explicit sys.path.append). Please verify in your test runner that `_utils_test` is added to PYTHONPATH, or update the import to:

from _utils_test._graph_test_helpers import count_buffers, run_test_transformed_gm

and confirm that all tests still pass.
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
102-105: LGTM: Single-pass modular optimizer with cleanup. Returning the transformed module (without forcing compile-and-capture here) aligns with the PR's modularization goal.
tensorrt_llm/_torch/auto_deploy/config/default.yaml (4)
22-25: LGTM: Clear sectioning for pattern-matcher transforms. The documentation-style headers improve readability of the transform pipeline.

39-41: LGTM: Separation between standardization and optimization stages. Keeping "standardize graph" vs "transformations" distinct reflects a clean pipeline.

66-70: LGTM: Explicit weight-load stage ahead of post-load fusion. This keeps load semantics separated from graph rewrites.

89-106: LGTM: Cache-init and compile stages are cleanly delineated. Having `update_in_out_nodes`, cache insertion, initialization, and resizing grouped before an explicit `compile_model` stage matches the modular optimizer changes.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (6)
5-5: Import switch to SequenceEmbeddingInfo is appropriate. Using SequenceEmbeddingInfo aligns the inputs with embedding-shaped tensors required by the SDPA path. LGTM.

94-96: Backend parametrization reads cleanly. Lowercase "triton"/"flashinfer" matches the shared config default pattern. LGTM.

108-108: Good use of inference_mode for a deterministic test path. This prevents autograd overhead and aligns with inference semantics.

111-113: Appropriate skip for flashinfer + float32. Matches current backend limitations to avoid spurious failures.

223-224: Exportability check is a solid guard. Ensures the transformed model remains exportable with cache args wired in. LGTM.

9-13: ✅ Verified correct InferenceOptimizer import and attn_backend handling. The test is importing `InferenceOptimizer` from tensorrt_llm._torch.auto_deploy.transform.optimizer (which expects an `InferenceOptimizerConfig`), not the legacy one under transformations/transform.py.

The passed transforms dict—including the `"insert_cached_attention": { ..., "attn_backend": attn_backend }` entry—maps correctly to `InsertCachedAttentionConfig.attn_backend` and is consumed by the transform as intended. No mismatches found.
PR_Github #15813 [ run ] completed with state
…inference optimizer (#126)

* move fusions
* move rms_norm fusion
* move rms_norm fusion
* move kvcache
* move compile_model
* address review comments
* fix: correct rms_norm backend

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Force-pushed 9b1e879 to 52a964c

/bot run

PR_Github #15951 [ run ] triggered by Bot
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)
81-85: Set requires_grad=False for fused parameter (non-quantized path) and reuse existing param list. In the non-quantized branch you recreate the parameter list and leave requires_grad=True by default. For inference-only graphs this should be False to avoid unnecessary autograd state. Also reuse params_unfused to avoid a second lookup.

- elif all(q is None for q in quantization_impls):
-     param_fused = nn.Parameter(fuse_weights([gm.get_parameter(k) for k in keys_unfused]))
+ elif all(q is None for q in quantization_impls):
+     param_fused = nn.Parameter(fuse_weights(params_unfused), requires_grad=False)
♻️ Duplicate comments (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (1)
3-3: Fix helper import path so tests can discover run_test_transformed_gm. The helper lives under tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py. Importing `_graph_test_helpers` directly will likely raise ImportError unless sys.path is tweaked elsewhere. Make `_utils_test` a package and import from it, or adjust sys.path.

-from _graph_test_helpers import run_test_transformed_gm
+from _utils_test._graph_test_helpers import run_test_transformed_gm

Alternatively, add a shared conftest that appends `_utils_test` to sys.path.

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)
119-124: Avoid requiring CUDA for pattern registration. Use CPU dummies. Pattern tracing doesn't require CUDA; allocating CUDA tensors here prevents running in CPU-only CI or dev environments.

- def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
-     return [
-         torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
-         torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
-         eps,
-     ]
+ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
+     return [
+         torch.randn(bs, hidden_size, device="cpu", dtype=input_dtype),
+         torch.randn(hidden_size, device="cpu", dtype=weight_dtype),
+         eps,
+     ]
🧹 Nitpick comments (12)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)
57-57: Constrain attn_backend to known values (Literal/Enum) for safer config validation. Using a plain str risks typos propagating into downstream transforms. Recommend narrowing the type to Literal or an Enum and updating the docstring accordingly.

Apply within this file:

- attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")
+ attn_backend: "Literal['flashinfer', 'triton', 'torch']" = Field(
+     default="flashinfer",
+     description="The attention backend to use. One of: 'flashinfer', 'triton', 'torch'.",
+ )

Add this import (outside the shown range):
from typing import Literal
1-1: Add NVIDIA copyright header (repo guideline). Source files must carry the current-year NVIDIA header. Please prepend it at the very top of the file (before the module docstring).

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """The interface for all transforms.

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (4)
92-101: Ensure complete cleanup and recompile after graph surgery. You call eliminate_dead_code and delete_all_unused_submodules, which is good. Post-fusion, correctness relies on a recompile step; this is currently handled by post-cleanup in BaseTransform (since is_clean=False). If anyone flips run_graph_cleanup or is_clean later, this could regress. Consider explicitly calling gm.recompile() here to localize correctness to this transform.
Example:
  gm.graph.eliminate_dead_code()
  gm.delete_all_unused_submodules()
+ gm.recompile()

Also applies to: 101-116
148-149: Guard CUDA-specific cache flush. torch.cuda.empty_cache() is harmless on CUDA builds but can error in CPU-only environments or test stubs. Guard with is_available() to be safe.

- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
121-153: Minor: Unused parameters in _apply signature. cm, factory, and shared_config are currently unused. It's fine to keep for interface uniformity, but prefix with underscores or add a short comment to appease linters.
1-1: Add NVIDIA copyright header (repo guideline). This is a production source file and should include the standard header.

+// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import operator

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (1)
76-79: Redundant numerical equivalence assert; helper already checks outputs. run_test_transformed_gm compares outputs (unless skip_output_assert=True). The extra assert here is duplicate and can be removed to keep the test focused.

- new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
- y_transformed = gm_transformed(new_input)
- y_model = model(new_input)
- torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)
+ # Output parity is already verified by run_test_transformed_gm

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (3)
1-1: Update docstring to reflect multi-backend support and add NVIDIA header. Docstring mentions FlashInfer only, but the transform supports "flashinfer", "triton", and "torch". Also add the standard header required for source files.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
-"""Graph transform to optimize RMSNorm execution using FlashInfer."""
+"""Graph transform to optimize RMSNorm execution with pluggable backends.
+
+Supports backends: "flashinfer", "triton", and "torch".
+"""
72-75: Config doc should list all supported backends. The description omits "torch". Align with _BACKEND_OPS to avoid confusion.

- description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
+ description="Backend to use for RMSNorm computation ('flashinfer', 'triton', 'torch').",
100-112: Validate backend once and explain failure mode. The early ValueError is good. Consider including the provided backend value and available keys in the message (already done). Optionally, offer the SharedConfig.attn_backend as a fallback when config.backend is unset to reduce config duplication.
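If the fallback is adopted, the selection could be factored into a tiny helper along these lines; the function name is hypothetical, and it assumes the transform-local backend may be left unset (None) so the shared `attn_backend` can take over, mirroring the behavior described earlier in this review for InsertCachedAttention.

```python
# Hypothetical sketch of the suggested fallback; not the current implementation.
from typing import Mapping, Optional, Tuple


def resolve_rmsnorm_backend(
    local_backend: Optional[str],
    shared_attn_backend: str,
    backend_ops: Mapping[str, object],  # e.g. the transform's _BACKEND_OPS mapping
) -> Tuple[str, object]:
    """Prefer the transform-local backend; otherwise fall back to the shared one."""
    backend = local_backend or shared_attn_backend
    if backend not in backend_ops:
        raise ValueError(
            f"Unsupported RMSNorm backend '{backend}'; expected one of {sorted(backend_ops)}."
        )
    return backend, backend_ops[backend]
```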
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2)
82-87: Share attn_backend via shared_config: good; consider centralizing SharedConfig population. Today you set shared_config.attn_backend after constructing the optimizer. Long-term, consider building SharedConfig up-front (once) and passing it into the optimizer ctor to avoid accidental divergence when more shared fields are added.
1-1: Add NVIDIA copyright header (repo guideline). Please prepend the standard header at the very top.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """High-level entrypoint to transform a model into an efficient inference model."""
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (18)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (3 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/interface.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/`__init__.py` (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (3 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (4 hunks)
💤 Files with no reviewable changes (5)
- tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py
- tensorrt_llm/_torch/auto_deploy/transformations/library/`__init__.py`
- tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py
🚧 Files skipped from review as they are similar to previous changes (8)
- tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py
- tensorrt_llm/_torch/auto_deploy/config/default.yaml
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in `__init__`
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/auto_deploy/transform/interface.py
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
tensorrt_llm/_torch/auto_deploy/transformations/transform.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/auto_deploy/transform/interface.py
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
tensorrt_llm/_torch/auto_deploy/transformations/transform.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
📚 Learning: 2025-08-12T10:28:57.320Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-12T10:28:57.320Z
Learning: Applies to **/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py} : Prepend NVIDIA copyright header (current year) to all source files
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
🧬 Code Graph Analysis (4)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)
tensorrt_llm/llmapi/llm_args.py (1)
Field(67-94)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (4)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
CachedSequenceInterface(12-70)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2)
ADPatternMatcherPass(59-65)
register_ad_pattern(97-180)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (8)
BaseTransform(139-376)
SharedConfig(51-57)
TransformConfig(60-99)
TransformInfo(108-133)
TransformRegistry(379-407)
register(385-392)
get_config_class(161-166)
get_config_class(400-402)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
torch_rmsnorm(64-76)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
ModelFactory(15-207)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
CachedSequenceInterface(12-70)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (6)
BaseTransform(139-376)
SharedConfig(51-57)
TransformInfo(108-133)
TransformRegistry(379-407)
register(385-392)
_apply(366-376)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
run_test_transformed_gm(68-138)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
torch_export_to_gm(198-284)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (4)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)
290-293: Good: Skip metadata update for non-FX GraphModule outputs.
The isinstance(gm, GraphModule) guard prevents attribute access on compiled CapturedGraph objects. This aligns with the new compile stage behavior.
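For readers less familiar with the failure mode, a stripped-down illustration of why the guard matters (the function name is made up; only the isinstance check mirrors the pattern discussed here):

```python
from torch.fx import GraphModule


def count_graph_nodes(gm):
    """Return the FX node count, or None when gm is not a real GraphModule."""
    if not isinstance(gm, GraphModule):
        # After the compile stage the result can be a CapturedGraph-style wrapper
        # without an FX graph, so there is nothing to inspect here.
        return None
    return len(list(gm.graph.nodes))
```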
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (3)
48-66: Propagate cm.device into load_weights: good. Also wire backends from ad_config.
This keeps device selection and backend routing single-sourced. Looks correct.
74-81: Compile settings plumbed correctly into transform config.
Passing cuda_graph_batch_sizes and compile_backend via transforms is consistent with the modular optimizer.
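As a rough illustration of what that plumbing implies on the config side, a pydantic-style stand-in might look like this (class name, field names, and defaults are assumptions, not the actual compile_model.py config):

```python
from typing import List, Optional

from pydantic import BaseModel, Field


class CompileStageConfig(BaseModel):
    """Illustrative stand-in for a compile-stage transform config."""

    backend: str = Field(
        default="torch-simple",  # placeholder default; the real value comes from ad_config
        description="Compile backend to use for the compile_model transform.",
    )
    cuda_graph_batch_sizes: Optional[List[int]] = Field(
        default=None, description="Batch sizes to capture CUDA graphs for, if any."
    )
```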
102-105: Return egm after cleanup: LGTM.
Cache cleanup before returning the optimized module is fine.
PR_Github #15951 [ run ] completed with state |
Original PR (nv-auto-deploy#126)
Summary by CodeRabbit
New Features
Refactor
Documentation
Tests
Description
Test Coverage
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`
Provide a user friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.
`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`
Launch build/test pipelines. All previously running jobs will be killed.
- `--reuse-test (optional)pipeline-id` (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
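For example, a pre-merge run that targets a single test stage and keeps going past failures could be invoked as `/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"` (the stage name here is just the illustrative one used in this help text).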
For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.
kill
`kill`
Kill all running builds associated with pull request.
skip
`skip --comment COMMENT`
Skip testing for latest commit on pull request.
`--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
`reuse-pipeline`
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.