KEMBAR78
[OMNIML-2336][feat] add W4A8 NVFP4 FP8 fused moe by sychen52 · Pull Request #7968 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@sychen52
Copy link
Collaborator

@sychen52 sychen52 commented Sep 24, 2025

Summary by CodeRabbit

  • New Features
    • Added support for a new W4A8 NVFP4/FP8 quantization mode in Fused MoE, enabling TRT-LLM execution where applicable with fallback to the existing path if not supported.
    • Introduced a configuration property to detect whether this quantization mode is active.
    • Made scaling vector size configurable during weight creation, improving flexibility for quantized MoE setups.
    • Enhanced validation and execution paths to handle the new quantization mode seamlessly, including weight scale loading and post-load adjustments.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@sychen52 sychen52 self-assigned this Sep 24, 2025
@sychen52 sychen52 requested a review from a team as a code owner September 24, 2025 19:44
@sychen52 sychen52 requested a review from mikeiovine September 24, 2025 19:44
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 24, 2025

📝 Walkthrough

Walkthrough

Adds support for W4A8+NVFP4/FP8 quantization in Fused MoE: extends backend selection to trigger TRTLLMGen path, introduces a new W4A8NVFP4FP8 method class, updates quantization weight/scale creation and loading signatures, adds a new public capability flag, and wires a new forward path handling for the variant.

Changes

Cohort / File(s) Summary
Backend selection
tensorrt_llm/_torch/modules/fused_moe/create_moe.py
Extends TRT-LLM path condition to include quant_mode.has_w4a8_nvfp4_fp8(); otherwise falls back to Cutlass with warning as before.
TRT-LLM Gen Fused MoE implementation
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
Adds W4A8+NVFP4/FP8 branch: selection in _get_quant_method, config checks, weight setup, and forward_impl execution path with routing, optional allgather, fp8/fp4 block-scale MOE runner, and finalize handling. Imports new method class.
Interface/API
tensorrt_llm/_torch/modules/fused_moe/interface.py
Adds public property has_w4a8_nvfp4_fp8 mirroring existing quantization capability flags.
Quantization methods
tensorrt_llm/_torch/modules/fused_moe/quantization.py
- Updates NVFP4 method signatures to accept scaling_vector_size and per-SF element counts (defaults preserved).
- Introduces W4A8NVFP4FP8TRTLLMGenFusedMoEMethod with fixed 32-wide scaling, weight/scale creation, and post-load scale/alpha adjustments (÷6/×6).
- Adds corresponding load helpers delegating to base with explicit per-SF=32.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant MoE as MoE Module
  participant Selector as Backend Selector
  participant Quant as Quant Method Picker
  participant TRT as TRTLLMGen Runner

  User->>MoE: create/forward(config, quant_config)
  MoE->>Selector: choose backend
  alt has_w4a8_nvfp4_fp8
    Selector->>MoE: use TRTLLMGenFusedMoE
    MoE->>Quant: _get_quant_method()
    Quant-->>MoE: W4A8NVFP4FP8TRTLLMGenFusedMoEMethod
    MoE->>TRT: create_weights()/load_scales (32-wide)
    User->>MoE: forward(...)
    MoE->>TRT: forward_impl (route, optional allgather, fp8/fp4 block-scale run)
    TRT-->>MoE: outputs
    MoE-->>User: result
  else other modes
    Selector-->>MoE: existing path (unchanged)
  end
Loading
sequenceDiagram
  autonumber
  participant MoE as MoE.forward_impl
  participant Router as Router
  participant Comm as Allgather (optional)
  participant Kernel as FP8/FP4 MoE Kernel

  Note over MoE,Kernel: New W4A8+NVFP4/FP8 branch
  MoE->>Router: compute routing
  alt post-quant allgather enabled
    MoE->>Comm: allgather activations
    Comm-->>MoE: gathered activations
  end
  MoE->>Kernel: run block-scale MoE (nvfp4/fp8 scales, alphas)
  Kernel-->>MoE: expert outputs
  MoE-->>MoE: finalize (assemble states)
  MoE-->>User: return
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The PR description is only the unfilled template with placeholder comments and does not include any actual summary of the changes, explanation of the issue or solution, or details on test coverage, leaving critical sections empty. This fails to meet the requirement to clearly explain what was done and why, as well as to list relevant tests that validate the new feature. Please replace the placeholder template text with a concrete summary and description of the changes, detail the specific tests added or modified to cover the new W4A8 NVFP4 FP8 fused MoE paths, and ensure the PR checklist reflects any updates to dependencies, documentation, or code ownership.
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title follows the repository convention with a valid ticket identifier and type, and it succinctly summarizes the main change by indicating support for W4A8 NVFP4 FP8 fused MoE, which aligns with the code modifications. It is clear, specific, and directly relates to the primary feature added in this PR.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (6)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (2)

1481-1487: Guard against invalid scaling_vector_size; add asserts.

Hidden and intermediate sizes must be divisible by scaling_vector_size to avoid mis-shaped block scales and kernel mismatches.

Apply this diff:

 def create_weights(self,
                    module: torch.nn.Module,
                    weight_dtype,
                    weight_vec_size,
                    block_scales_dtype,
                    block_scales_vec_size,
-                   scaling_vector_size=16):
+                   scaling_vector_size=16):
 
-        module.scaling_vector_size = scaling_vector_size
+        module.scaling_vector_size = scaling_vector_size
+        assert module.hidden_size % module.scaling_vector_size == 0, (
+            f"hidden_size {module.hidden_size} must be divisible by scaling_vector_size "
+            f"{module.scaling_vector_size}"
+        )
+        assert module.intermediate_size_per_partition % module.scaling_vector_size == 0, (
+            f"intermediate_size_per_partition {module.intermediate_size_per_partition} must be divisible by "
+            f"scaling_vector_size {module.scaling_vector_size}"
+        )

2010-2058: Avoid hard-coded 32; use module.scaling_vector_size.

Reduces drift if the scaling width changes and keeps overrides consistent.

Apply this diff:

 class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):

     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
         block_scales_vec_size = 1

-        NVFP4FusedMoEMethod.create_weights(self, module, self.weight_dtype,
-                                           weight_vec_size,
-                                           self.block_scales_dtype,
-                                           block_scales_vec_size, 32)
+        NVFP4FusedMoEMethod.create_weights(self, module, self.weight_dtype,
+                                           weight_vec_size,
+                                           self.block_scales_dtype,
+                                           block_scales_vec_size,
+                                           32)  # sets module.scaling_vector_size

@@
-    def load_expert_w3_w1_weight_scale_nvfp4(
+    def load_expert_w3_w1_weight_scale_nvfp4(
             self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
             w3_weight_scale: torch.Tensor,
             dst_w3_w1_weight_scale: torch.Tensor):
         return super().load_expert_w3_w1_weight_scale_nvfp4(
-            module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale,
-            32)
+            module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale,
+            module.scaling_vector_size)

@@
-    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
                                           dst_w2_weight_scale: torch.Tensor):
         return super().load_expert_w2_weight_scale_nvfp4(
-            module, w2_weight_scale, dst_w2_weight_scale, 32)
+            module, w2_weight_scale, dst_w2_weight_scale,
+            module.scaling_vector_size)
tensorrt_llm/_torch/modules/fused_moe/create_moe.py (1)

49-53: Update warning message to include the new quant mode.

Keeps UX consistent with the actual supported set.

Apply this diff:

-            logger.warning(
-                "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8. "
+            logger.warning(
+                "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8. "
                 f"Check out details in quant_config: {quant_config}"
                 "Using CutlassFusedMoE instead.")
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (3)

115-116: Include new quant mode in the assertion message.

The condition allows it; the error text should reflect it.

Apply this diff:

-        assert self.has_deepseek_fp8_block_scales \
-            or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
-            or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
+        assert self.has_deepseek_fp8_block_scales \
+            or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
+            or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, \
+            "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."

206-209: Consider enabling post‑quant allgather for W4A8 NVFP4 FP8.

If fp8_fp4_block_scale_moe_runner supports PQ allgather like nvfp4, include it here for performance parity.

Apply this diff if supported:

-        is_post_quant_allgather_supported = self.has_nvfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8
+        is_post_quant_allgather_supported = (
+            self.has_nvfp4
+            or self.has_w4a8_nvfp4_fp8
+            or self.has_w4a8_mxfp4_fp8
+            or self.has_w4a8_mxfp4_mxfp8
+        )

Please confirm the kernel supports PQ allgather; if not, ignore.


379-410: W4A8 NVFP4 FP8 forward: minor robustness and consistency tweaks.

  • Pass router logits/bias as None under PQ allgather (matches other branches).
  • Optional: verify whether padding to kernel alignment is ever required for this path.

Apply this diff:

-            outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+            outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 hidden_states_fp8,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
                 self.w2_weight,
                 self.w2_weight_scale.view(torch.float8_e4m3fn),
                 self.fc31_scale_c.data,
                 self.fc31_alpha.data,
                 self.fc2_alpha.data,

Also please validate on a model where hidden_size is not a multiple of common tiles (e.g., 6144) that no input padding is required for this kernel. If padding is required, mirror the mxfp4-fp8 path’s padding pattern.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a65af2 and 5603f7f.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (5 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
  • tensorrt_llm/_torch/modules/fused_moe/interface.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
  • tensorrt_llm/_torch/modules/fused_moe/interface.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
  • tensorrt_llm/_torch/modules/fused_moe/interface.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
📚 Learning: 2025-08-21T02:39:12.009Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#7104
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1475-1480
Timestamp: 2025-08-21T02:39:12.009Z
Learning: The min latency mode functionality in TensorRT-LLM MOE kernels (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu) is deprecated and no longer being maintained/updated, as confirmed by djns99. Bug reports and optimization suggestions for the computeStridesTmaWarpSpecializedLowLatencyKernel and related min latency code paths should be deprioritized.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
📚 Learning: 2025-08-14T23:23:27.449Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/modules/fused_moe/create_moe.py (3)
tensorrt_llm/llmapi/llm_args.py (2)
  • quant_config (2299-2302)
  • quant_config (2305-2306)
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
  • has_w4a8_nvfp4_fp8 (305-308)
tensorrt_llm/_torch/modules/linear.py (1)
  • has_w4a8_nvfp4_fp8 (1930-1933)
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
tensorrt_llm/_torch/modules/linear.py (1)
  • has_w4a8_nvfp4_fp8 (1930-1933)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (3)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)
  • W4A8NVFP4FP8TRTLLMGenFusedMoEMethod (2010-2057)
tensorrt_llm/_torch/modules/fused_moe/interface.py (2)
  • has_w4a8_nvfp4_fp8 (305-308)
  • _ (72-98)
tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py (4)
  • _ (335-371)
  • _ (623-646)
  • _ (1688-1732)
  • fp8_fp4_block_scale_moe_runner (1611-1684)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
  • create_weights (191-192)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)
  • create_weights (344-352)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (7)
  • create_weights (204-225)
  • create_weights (394-435)
  • create_weights (673-734)
  • create_weights (1350-1357)
  • setup_quant_scales (227-228)
  • setup_quant_scales (437-443)
  • setup_quant_scales (736-746)
tensorrt_llm/_torch/modules/linear.py (10)
  • create_weights (220-223)
  • create_weights (278-288)
  • create_weights (321-344)
  • create_weights (488-508)
  • create_weights (589-612)
  • create_weights (721-761)
  • create_weights (916-956)
  • create_weights (1108-1131)
  • create_weights (1220-1239)
  • create_weights (1339-1366)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)

304-309: Property addition looks correct and consistent.

Mirrors the linear module’s API and existing quant flags. No issues.

tensorrt_llm/_torch/modules/fused_moe/create_moe.py (1)

44-46: Backend gate extension looks good.

Includes the new has_w4a8_nvfp4_fp8() in TRTLLMGen selection.

@sychen52
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19840 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19840 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14929 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19845 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19845 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14934 completed with status: 'FAILURE'

@sychen52 sychen52 requested review from DomBrown and hyukn September 25, 2025 04:33
@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19887 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19887 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14967 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19986 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19986 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15048 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19999 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19999 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15061 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20033 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20033 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15091 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20109 [ run ] triggered by Bot

Copy link
Collaborator

@mikeiovine mikeiovine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test?

@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20384 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20384 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15379 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20396 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20396 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15390 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20408 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20408 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15399 completed with status: 'FAILURE'

@sychen52
Copy link
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20425 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20425 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15412 completed with status: 'FAILURE'

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
@sychen52
Copy link
Collaborator Author

sychen52 commented Oct 1, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20438 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20438 [ run ] completed with state DISABLED
L0 testing is limited to prioritized users. User sychen52 is not in the prioritized list. L0 testing cannot be triggered.

@sychen52
Copy link
Collaborator Author

sychen52 commented Oct 1, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20442 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20442 [ run ] completed with state DISABLED
L0 testing is limited to prioritized users. User sychen52 is not in the prioritized list. L0 testing cannot be triggered.

@sychen52
Copy link
Collaborator Author

sychen52 commented Oct 1, 2025

/bot skip --comment "previous run failed for a known unrelated bug"

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20443 [ skip ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20443 [ skip ] completed with state SUCCESS
Skipping testing for commit 4b47b67

@sychen52 sychen52 merged commit ba8abea into NVIDIA:main Oct 1, 2025
5 checks passed
faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 2, 2025
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
Signed-off-by: Faradawn Yang <faradawny@gmail.com>
evezhier pushed a commit to evezhier/TensorRT-LLM that referenced this pull request Oct 3, 2025
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 3, 2025
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
Signed-off-by: Faradawn Yang <faradawny@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants