KEMBAR78
[None][fix] enable NvFP4/FP8 quantization for Nemotron-H architecture by tomeras91 · Pull Request #7589 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@tomeras91
Copy link
Collaborator

@tomeras91 tomeras91 commented Sep 7, 2025

Summary by CodeRabbit

  • New Features

    • Added a configuration option to defer creating certain layer weights during initialization, reducing startup memory and enabling faster model setup (applies to Mamba2-based models).
  • Bug Fixes

    • Ensured all “scale” parameters are included during weight import.
    • Made module-exclusion patterns more flexible to correctly target intended modules.
    • Improved quantization exclusion logic to correctly handle fused projection layers by mapping them to their underlying components.

Description

This PR ensures support for NvFP4 / FP8 quantization for the Nemotron-H architecture, specifically, Nemotron-Nano-v2. The main contribution here is to allow quantization of specific Linear modules and not necessarily all of them. It also includes a general fix for ModelOpt quantized models that didn't quantize packed modules.

main changes:

  • make sure packed modules (qkv weights , gate_up_proj for gated MLP) are dealt with correctly when parsing ModelOpt's exclude_modules. This is a fix relevant for all ModelOpt quantized models that didn't quantize these modules.
  • pass skip_create_weights_in_init from the config to all mamba2 Linear layers
  • make Nemotron-H weight loading more robust to different checkpoint formats.

Test Coverage

I validated that internal Nemotron-Nano-v2 NVFP4 / FP8 checkpoints, some not quantizing all Linear layers, can be loaded and inferenced correctly.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This reverts commit de2c6e4.

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
…es don't have to be 0-dimension tensors

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
…llow for multiple formats

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
…ar layers

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
… modules from quantization

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
@tomeras91 tomeras91 requested review from a team as code owners September 7, 2025 10:16
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 7, 2025

📝 Walkthrough

Walkthrough

Adjusts Nemotron-H weight preprocessing to copy any key containing "_scale" regardless of tensor dimensionality; normalizes module exclude names with regex in NemotronH; remaps names for fused Linear variants during quant-exclusion; and adds a skip_create_weights_in_init flag passed from ModelConfig into Mamba2Mixer Linear initializers.

Changes

Cohort / File(s) Summary
Nemotron-H HF weight mapping
tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
Keys containing "_scale" are now copied regardless of tensor dimensionality in preprocess_weights; other mapping branches unchanged.
Nemotron-H model path normalization
tensorrt_llm/_torch/models/modeling_nemotron_h.py
Replaced direct string replace with re.sub(r'(model\.layers\.)?backbone','model', k) to normalize exclude module names (handles optional model.layers. prefix).
Quant exclusion for fused Linear
tensorrt_llm/_torch/models/modeling_utils.py
When evaluating exclude_modules for Linear modules with fused weight modes, generate candidate names by remapping fused projection names (e.g., gate_up_projgate_proj/up_proj, qkv_projq_proj/k_proj/v_proj) and treat any match as excluded before applying quant_config.
Deferred Linear weight creation in Mamba2
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py, tensorrt_llm/_torch/modules/linear.py, .../ModelConfig
Linear.__init__ gains skip_create_weights_in_init kwarg; Mamba2Mixer passes config.skip_create_weights_in_init to its in_proj/out_proj (and related) Linear initializations to defer weight creation per config.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant MU as modeling_utils.apply_quant_config_exclude_modules
  participant M as Module
  participant Q as QuantConfig

  Caller->>MU: apply_quant_config_exclude_modules(modules, exclude_list)
  loop each module
    MU->>M: inspect type, name, weight_mode
    alt Linear & FUSED_GATE_UP_LINEAR
      MU->>MU: candidates = [name, name with gate_up_proj→gate_proj, gate_up_proj→up_proj]
    else Linear & FUSED_QKV_LINEAR
      MU->>MU: candidates = [name, qkv_proj→q_proj, →k_proj, →v_proj]
    else
      MU->>MU: candidates = [name]
    end
    MU->>MU: is_excluded = any(match(c, exclude_list) for c in candidates)
    alt is_excluded and M.quant_config
      MU->>M: set M.quant_config = new Q(...)
    end
  end
Loading
sequenceDiagram
  autonumber
  participant User
  participant MX as Mamba2Mixer.__init__
  participant CFG as ModelConfig
  participant L as Linear.__init__

  User->>MX: new Mamba2Mixer(config=CFG)
  MX->>CFG: read skip_create_weights_in_init
  MX->>L: in_proj(..., skip_create_weights_in_init=CFG.value)
  L-->>MX: returns (weights may be deferred)
  MX->>L: out_proj(..., skip_create_weights_in_init=CFG.value)
  L-->>MX: returns
  MX-->>User: Mamba2Mixer constructed
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • Wanli-Jiang
  • omera-nv
  • danielafrimi
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (8)
tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py (2)

37-39: TP-safety for “_scale” tensors

Blindly copying scales may mis-shard when scales are per-output-channel and need TP splitting (similar to A/D/dt_bias). Suggest splitting on dim 0 when divisible.

-            if "_scale" in key:
-                new_weights[key] = weights[name]
+            if "_scale" in key:
+                w = weights[name]
+                # Respect TP sharding for per-channel scales
+                if w.ndim > 0 and (w.shape[0] % tp_size == 0):
+                    w = split(w, tp_size, tp_rank)
+                new_weights[key] = w

1-1: Add NVIDIA header (2025)

This file is missing the required SPDX header with current year.

Add to the top of the file (non-diff snippet):

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/models/modeling_nemotron_h.py (2)

16-16: Make exclude_modules normalization token-safe

Current regex may replace any substring “backbone” inside a longer token. Limit to whole-token to avoid accidental renames.

-                re.sub(r'(model\.layers\.)?backbone', 'model', K)
+                re.sub(r'(?<!\w)backbone(?!\w)', 'model', k)

If you specifically intend only top-level “backbone” or “model.layers.backbone”, consider asserting boundaries via path tokenization instead.

Also applies to: 259-261


1-15: Update header year to 2025

Header shows 2022–2024; update to 2022–2025 per guidelines.

-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
tensorrt_llm/_torch/models/modeling_utils.py (2)

485-500: Broaden fused-module exclusion and collapse logic

Cover k/v synonyms for fused QKV and up_proj synonym for fused gate/up; reduce false negatives when exclude lists use different base names. Keep behavior identical for non-Linear modules.

-                    if isinstance(module, Linear):
-                        weight_mode = module.weights_loading_config.weight_mode
-                        if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
-                            # gate_proj and up_proj share the same exclusion rule
-                            is_excluded = quant_config.is_module_excluded_from_quantization(
-                                name.replace('gate_up_proj', 'gate_proj'))
-                        elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
-                            # q_proj, k_proj and v_proj share the same exclusion rule
-                            is_excluded = quant_config.is_module_excluded_from_quantization(
-                                name.replace('qkv', 'q'))
-                        else:
-                            is_excluded = quant_config.is_module_excluded_from_quantization(
-                                name)
+                    if isinstance(module, Linear):
+                        weight_mode = module.weights_loading_config.weight_mode
+                        candidates = [name]
+                        if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
+                            # gate_proj and up_proj share the same exclusion rule
+                            candidates += [
+                                name.replace('gate_up_proj', 'gate_proj'),
+                                name.replace('gate_up_proj', 'up_proj'),
+                            ]
+                        elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
+                            # q_proj, k_proj and v_proj share the same exclusion rule
+                            q_name = name.replace('qkv', 'q')
+                            candidates += [
+                                q_name,
+                                q_name.replace('.q_proj', '.k_proj'),
+                                q_name.replace('.q_proj', '.v_proj'),
+                            ]
+                        is_excluded = any(
+                            quant_config.is_module_excluded_from_quantization(n)
+                            for n in candidates
+                        )
                     else:
                         is_excluded = quant_config.is_module_excluded_from_quantization(
                             name)

1-1: Add NVIDIA header (2025)

Missing required SPDX header.

Add at top (non-diff snippet):

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (2)

92-102: Defensive default for skip_create_weights_in_init

Accessing config.skip_create_weights_in_init directly can break older configs. Use getattr with False default.

         self.in_proj = Linear(
             d_model,
             d_in_proj,
             bias=bias,
             dtype=dtype,
             mapping=self.mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             quant_config=config.get_quant_config(),
-            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            skip_create_weights_in_init=getattr(config, "skip_create_weights_in_init", False),
             allreduce_strategy=config.allreduce_strategy)
...
         self.out_proj = Linear(
             d_inner,
             d_model,
             bias=bias,
             dtype=dtype,
             mapping=self.mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
-            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            skip_create_weights_in_init=getattr(config, "skip_create_weights_in_init", False),
             allreduce_strategy=config.allreduce_strategy)

Also apply the same to conv1d init for consistency:

# conv1d ...
skip_create_weights_in_init=getattr(config, "skip_create_weights_in_init", False),

Also applies to: 143-153


1-15: Update header year to 2025

Adjust to 2022–2025.

-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aea8ac1 and c6f896f.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_nemotron_h.py (2 hunks)
  • tensorrt_llm/_torch/models/modeling_utils.py (1 hunks)
  • tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/models/modeling_nemotron_h.py
  • tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/models/modeling_nemotron_h.py
  • tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/models/modeling_nemotron_h.py
  • tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17924 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17924 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #13431 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17926 [ run ] triggered by Bot

Copy link
Collaborator

@danielafrimi danielafrimi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17926 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #13433 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17938 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17938 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13443 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17947 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17947 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13450 completed with status: 'FAILURE'

… are packed in the checkpoint as well

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
@tomeras91
Copy link
Collaborator Author

/bot run

@tomeras91
Copy link
Collaborator Author

@coderabbitai help

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 7, 2025

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

CodeRabbit commands

These commands are invoked using PR/Issue comments.

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore or @coderabbit ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Status, support, documentation and community

  • Visit our status page to check the current availability of CodeRabbit.
  • Create a ticket on our support page for assistance with any issues or questions.
  • Visit our documentation site for detailed information on how to use CodeRabbit.
  • Join our Discord community to connect with other users and get help from the community.
  • Follow us on X/Twitter for updates and announcements.

@tomeras91
Copy link
Collaborator Author

@coderabbitai review

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 7, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17953 [ run ] triggered by Bot

@tomeras91
Copy link
Collaborator Author

@coderabbitai full review

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 7, 2025

✅ Actions performed

Full review triggered.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
tensorrt_llm/_torch/models/modeling_nemotron_h.py (1)

259-261: Tighten the regex to avoid accidental substring replacements

Current pattern can replace "backbone" inside longer tokens (e.g., "backbone_extra"). Add word boundaries to target path segments only.

-                re.sub(r'(model\.layers\.)?backbone', 'model', k)
+                re.sub(r'(?<!\w)(?:model\.layers\.)?backbone(?!\w)', 'model', k)
tensorrt_llm/_torch/models/modeling_utils.py (1)

485-505: Avoid duplicate candidate names and guard replacements by substring

Minor cleanup: using a set prevents duplicates; gating replacements on substring presence avoids no-op strings.

-                    candidates = [name]
+                    candidates = {name}
                     if isinstance(module, Linear):
                         weight_mode = module.weights_loading_config.weight_mode
                         if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
                             # sometimes gate and up proj are not packed in the checkpoint,
                             # but they still share the same exclusion rule
-                            candidates += [
-                                name.replace('gate_up_proj', 'gate_proj'),
-                                name.replace('gate_up_proj', 'up_proj')
-                            ]
+                            if 'gate_up_proj' in name:
+                                candidates.update({
+                                    name.replace('gate_up_proj', 'gate_proj'),
+                                    name.replace('gate_up_proj', 'up_proj'),
+                                })
                         elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
                             # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
                             # but they still share the same exclusion rule
-                            candidates += [
-                                name.replace('qkv_proj', 'q_proj'),
-                                name.replace('qkv_proj', 'k_proj'),
-                                name.replace('qkv_proj', 'v_proj')
-                            ]
-                    is_excluded = any(
+                            if 'qkv_proj' in name:
+                                candidates.update({
+                                    name.replace('qkv_proj', 'q_proj'),
+                                    name.replace('qkv_proj', 'k_proj'),
+                                    name.replace('qkv_proj', 'v_proj'),
+                                })
+                    is_excluded = any(
-                        quant_config.is_module_excluded_from_quantization(n)
-                        for n in candidates)
+                        quant_config.is_module_excluded_from_quantization(n)
+                        for n in candidates)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (2)

143-152: Mirror the getattr fallback here; consider aligning conv1d initializer too.

Make out_proj resilient the same way; and (optional) apply the same change to self.conv1d for consistency.

Apply this diff in-place:

-            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            skip_create_weights_in_init=getattr(config, "skip_create_weights_in_init", False),

Optional (outside this hunk) for conv1d:

# in self.conv1d = Linear(...):
skip_create_weights_in_init=getattr(config, "skip_create_weights_in_init", False),

92-101: Add backward-compatible fallback for skip_create_weights_in_init

Linear.init already accepts skip_create_weights_in_init and ModelConfig sets skip_create_weights_in_init: bool = False by default; use getattr to guard against configs that lack this field:

-           skip_create_weights_in_init=config.skip_create_weights_in_init,
+           skip_create_weights_in_init=getattr(config, "skip_create_weights_in_init", False),
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0fee8cd and 62e56ac.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_nemotron_h.py (2 hunks)
  • tensorrt_llm/_torch/models/modeling_utils.py (1 hunks)
  • tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (2 hunks)

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17953 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13456 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17988 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17988 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13482 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18017 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18017 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13504 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18064 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18064 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13538 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run

1 similar comment
@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18092 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18092 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13559 completed with status: 'FAILURE'

@tomeras91
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18155 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18155 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13603 completed with status: 'SUCCESS'

@tomeras91 tomeras91 enabled auto-merge (squash) September 9, 2025 08:42
@tomeras91 tomeras91 merged commit 6e712dd into NVIDIA:main Sep 9, 2025
7 checks passed
gergely-magyar pushed a commit to gergely-magyar/TensorRT-LLM that referenced this pull request Sep 9, 2025
…NVIDIA#7589)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Gergely Magyar <gergely.magyar@visma.com>
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…NVIDIA#7589)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants