[None][feat] Apply AutoTuner to fp8_block_scale_deep_gemm to trigger JIT ahead of time. by hyukn · Pull Request #7113 · NVIDIA/TensorRT-LLM

Conversation

@hyukn
Collaborator

@hyukn hyukn commented Aug 21, 2025

Because deep_gemm.fp8_gemm_nt will trigger many JIT compilations during the inference phase, we need to sweep these shapes ahead of time. Apply the AutoTuner framework to achieve this, and retain the option to tune the swap_ab flag in the future.
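
For context, here is a minimal warm-up sketch of the intended usage. The op name and the autotune() context come from this PR; the import path, problem sizes, and scale-factor layout below are assumptions for illustration only, not code from the diff.

import torch
from tensorrt_llm._torch.autotuner import autotune  # import path assumed

# Illustrative GEMM problem sizes and pre-quantized weight/scale tensors.
m, k, n = 128, 7168, 2112
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b_fp8 = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)           # quantized weight (placeholder values)
b_sf = torch.ones(n // 128, k // 128, device="cuda", dtype=torch.float32)  # block scales (layout assumed)

# Calling the fused op once per shape under autotune() sweeps the tuning
# buckets and triggers the DeepGEMM JIT ahead of time instead of at inference.
with autotune():
    out = torch.ops.trtllm.fp8_swap_ab_gemm(a, b_fp8, b_sf)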

Summary by CodeRabbit

  • New Features

    • Added a fused FP8 GEMM operation exposed as a Torch custom op with autotuning, dynamic-shape support, and configurable casting; returns outputs in the requested dtype.
  • Performance

    • Streamlines FP8 linear execution by removing intermediate buffers and fusing quantization + GEMM, reducing memory and improving efficiency.
  • Refactor

    • Replaced multi-step FP8 path in linear layers with the new fused operation.
  • Tests

    • Updated unit tests to validate the fused FP8 path under autotune.

@hyukn hyukn requested review from a team as code owners August 21, 2025 06:52
@hyukn hyukn requested review from QiJune and yizhang-nv August 21, 2025 06:52
@coderabbitai
Contributor

coderabbitai bot commented Aug 21, 2025

📝 Walkthrough

Walkthrough

Adds a tunable Torch custom op trtllm::fp8_swap_ab_gemm (runner, tuning buckets, and fake registration), integrates it into FP8BlockScalesLinearMethod in place of the previous per-token quant + deep_gemm path, and updates a unit test to exercise the new op within an autotune context. Note: the op definitions currently appear duplicated in the file.

Changes

  • Torch custom op & tuner (tensorrt_llm/_torch/custom_ops/torch_custom_ops.py): Adds fp8SwapABGemmRunner (a TunableRunner), fp8_swap_ab_gen_tuning_buckets, the public op fp8_swap_ab_gemm, and a register_fake path. Implements per-token FP8 quantization via fp8_utils, calls deep_gemm.fp8_gemm_nt, and integrates AutoTuner-driven tactic selection and output-dtype handling. Note: identical definitions appear duplicated. (A simplified registration sketch follows this list.)
  • Linear module integration (tensorrt_llm/_torch/modules/linear.py): Replaces the explicit per-token FP8 quant + deep_gemm.fp8_gemm_nt sequence with the fused torch.ops.trtllm.fp8_swap_ab_gemm(..., disable_ue8m0_cast=True). Removes the fp8_utils import and the intermediate FP8 buffers.
  • Unit test update (tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py): Replaces the direct deep_gemm call with torch.ops.trtllm.fp8_swap_ab_gemm inside with autotune(). Removes the per_token_cast_to_fp8_e8m0 import/usage and relies on the op to allocate the output.
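
For readers less familiar with the custom-op plumbing, the sketch below shows the general pattern the new op follows: a Torch custom op plus a fake (meta) registration for tracing. It is a simplified illustration under a demo namespace using the standard torch.library API; the actual code in torch_custom_ops.py additionally wires in AutoTuner tactic selection and the per-token FP8 quantization + deep_gemm.fp8_gemm_nt call, which are only indicated by comments here.

from typing import Optional

import torch

@torch.library.custom_op("trtllm_demo::fp8_swap_ab_gemm", mutates_args=())
def fp8_swap_ab_gemm_demo(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    output_dtype: Optional[torch.dtype] = None,
    disable_ue8m0_cast: bool = False,
) -> torch.Tensor:
    # Real op: pick a tactic via AutoTuner, quantize `input` per token with
    # fp8_utils, then run deep_gemm.fp8_gemm_nt((a, a_sf), (weight, weight_scale), output, ...).
    dtype = output_dtype if output_dtype is not None else torch.bfloat16
    output = torch.empty((input.size(0), weight.size(0)), device=input.device, dtype=dtype)
    # ... kernel call omitted in this sketch; `output` is left uninitialized ...
    return output

@fp8_swap_ab_gemm_demo.register_fake
def _(input, weight, weight_scale, output_dtype=None, disable_ue8m0_cast=False):
    # Tracing/compile path: only shape and dtype matter; no kernel is launched.
    dtype = output_dtype if output_dtype is not None else torch.bfloat16
    return input.new_empty((input.size(0), weight.size(0)), dtype=dtype)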

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller as Linear/Test
  participant AT as AutoTuner
  participant OP as trtllm::fp8_swap_ab_gemm
  participant Q as fp8_utils
  participant G as deep_gemm.fp8_gemm_nt

  Caller->>OP: call fp8_swap_ab_gemm(input, weight, weight_scale, opts)
  OP->>AT: gen tuning buckets / request tactic
  AT-->>OP: selected tactic (e.g., 0)

  rect rgba(220,235,255,0.25)
    OP->>Q: per-token quantize input -> A_fp8, A_sf
    Q-->>OP: A_fp8, A_sf
  end

  rect rgba(220,255,220,0.25)
    OP->>G: deep_gemm.fp8_gemm_nt(A_fp8, weight, weight_scale, ...)
    G-->>OP: output (respected output_dtype)
  end

  OP-->>Caller: return output

  alt Graph/tracing path
    note right of OP: register_fake returns placeholder tensor with synthetic shape
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Suggested reviewers

  • lfr-0531
  • limin2021


@hyukn
Collaborator Author

hyukn commented Aug 21, 2025

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (5)
tensorrt_llm/_torch/modules/linear.py (1)

593-598: Switch to fused fp8_swap_ab_gemm looks good; consider surfacing output dtype and verify accuracy parity

  • The fused call will help pre-warm kernels via AutoTuner. Two nits:
    • Consistency: elsewhere we honor module.dtype or input.dtype for output. Here we rely on the op’s default (bfloat16). If module.dtype is set (e.g., fp16 scenarios), consider passing it through explicitly.
    • Semantics: This replaces 1x128 activation quantization with per-token quantization. Please confirm expected accuracy parity on SM100 across the swept shapes.

Suggested change to preserve dtype behavior:

-                output = torch.ops.trtllm.fp8_swap_ab_gemm(
-                    input,
-                    module.weight,
-                    module.weight_scale,
-                    disable_ue8m0_cast=True,
-                )
+                output = torch.ops.trtllm.fp8_swap_ab_gemm(
+                    input,
+                    module.weight,
+                    module.weight_scale,
+                    output_dtype=(module.dtype or input.dtype),
+                    disable_ue8m0_cast=True,
+                )
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2)

25-26: Good: autotune context added to pre-trigger JIT

Wrapping the call in autotune() enables ahead-of-time tactic selection and compilation; this aligns with the PR goal of reducing inference-time JIT. Consider setting TRTLLM_AUTOTUNE_LOG=1 in CI for visibility of tuned buckets.


54-60: Test now exercises the new fused op; add a dtype assertion and broaden shapes if you want stronger coverage

The check is correct. Two small suggestions:

  • Assert output.dtype to guard against accidental default changes.
  • Optionally parametrize disable_ue8m0_cast to exercise both branches.

Example adjustments:

 with autotune():
     output = torch.ops.trtllm.fp8_swap_ab_gemm(
         a,
         act_b_fp8,
         act_b_sf,
     )
+assert output.dtype == torch.bfloat16

Optionally:

-    with autotune():
-        output = torch.ops.trtllm.fp8_swap_ab_gemm(a, act_b_fp8, act_b_sf)
+    with autotune():
+        output = torch.ops.trtllm.fp8_swap_ab_gemm(a, act_b_fp8, act_b_sf)
+    with autotune():
+        output2 = torch.ops.trtllm.fp8_swap_ab_gemm(a, act_b_fp8, act_b_sf, disable_ue8m0_cast=True)
+    torch.testing.assert_close(output, output2, atol=2e-2, rtol=2e-2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

920-926: Class naming is inconsistent with the rest; add runner caching and round rule to tuning_config

  • Naming: other runners use PascalCase (FP8RowwiseGemmRunner, FP4GemmRunner, …). Please rename to FP8SwapABGemmRunner for consistency.
  • Caching: other runners maintain a runner_dict; add the same to avoid re-instantiation overhead.
  • Tuning config: supply the round rule introduced above.

Proposed update:

-class fp8SwapABGemmRunner(TunableRunner):
-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
-        tune_max_num_tokens=4096,
-    )
+class FP8SwapABGemmRunner(TunableRunner):
+    runner_dict = dict()
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, fp8_swap_ab_gen_tuning_buckets, _fp8_swap_ab_round_rule), ),
+        tune_max_num_tokens=4096,
+    )
 
-    def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
-        self.output_dtype = output_dtype
-        self.disable_ue8m0_cast = disable_ue8m0_cast
+    def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
+        self.output_dtype = output_dtype
+        self.disable_ue8m0_cast = disable_ue8m0_cast
+        key = (output_dtype, disable_ue8m0_cast)
+        if key not in FP8SwapABGemmRunner.runner_dict:
+            FP8SwapABGemmRunner.runner_dict[key] = self

Note: subsequent references to fp8SwapABGemmRunner need to be updated to FP8SwapABGemmRunner.


962-987: Update to renamed runner and propagate tune_max_num_tokens cleanly

After renaming the class and adding round rule, update the references:

-    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
+    fp8_swap_ab_gemm_runner = FP8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
     )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+    FP8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens

Also consider guarding tune_max_num_tokens to be ≥8 and multiple of 8 for SM100-friendly bucketing.
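
If that guard were adopted, it could look roughly like this (illustrative helper, not part of the PR):

def _clamp_tune_max_num_tokens(n: int) -> int:
    # Illustrative guard: at least 8, floored to a multiple of 8.
    return max(8, (n // 8) * 8)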

📥 Commits

Reviewing files that changed from the base of the PR and between 2d40e87 and c7d8c80.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
  • tensorrt_llm/_torch/modules/linear.py (1 hunks)
  • tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2 hunks)
🔇 Additional comments (3)
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (1)

22-22: Import update is fine

Using per_block_cast_to_fp8_e8m0 for the weight path aligns with DeepGEMM’s expectations. No concerns here.

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

7-8: New imports are appropriate

Importing fp8_utils and deep_gemm is expected for the fused FP8 path. No issues spotted.


989-999: Fake kernel shape/dtype is correct

The fake registration returns [M, N] with the requested dtype, matching execution semantics. Looks good.

@tensorrt-cicd
Collaborator

PR_Github #16019 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #16019 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12041 completed with status: 'FAILURE'

@hyukn
Collaborator Author

hyukn commented Aug 22, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #16122 [ run ] triggered by Bot

…JIT ahead of time.

Because deep_gemm.fp8_gemm_nt will trigger many JIT compilations during the inference phase, we need to sweep these shapes ahead of time. Apply the AutoTuner framework to achieve this and retain the potential capability to tune the swap_ab flag.
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
@hyukn hyukn force-pushed the feat/fp8_deep_gemm_autotuning branch from c7d8c80 to 188e2fa on August 22, 2025 05:33
@hyukn
Collaborator Author

hyukn commented Aug 22, 2025

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (4)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4)

895-900: Include upper-bound bucket and add deterministic rounding rule (pre-JIT coverage + determinism).

As previously noted, buckets exclude x itself (e.g., x=4096 → max 3968) and there’s no round rule for arbitrary M. This can under-tune at caps and cause non-deterministic selection near small M. Add an upper-bound-inclusive step and wire a floor-to-bucket round rule. Also cache results and add a precise return type.

Apply:

-def fp8_swap_ab_gen_tuning_buckets(x: int):
-    buckets = tuple(range(8, 128, 8))
-    if x >= 128:
-        buckets += tuple(range(128, x, 128))
-    return buckets
+@lru_cache(maxsize=None)
+def fp8_swap_ab_gen_tuning_buckets(x: int) -> Tuple[int, ...]:
+    # Base buckets: 8, 16, ..., 120
+    buckets = tuple(range(8, 128, 8))
+    if x >= 128:
+        # Include x when exactly divisible by 128 to avoid under-tuning at the cap
+        buckets += tuple(range(128, x + 1, 128))
+    return buckets
+
+def _fp8_swap_ab_round_rule(x: int) -> int:
+    # Floor to bucket:
+    #  - x < 128: nearest lower multiple of 8 (min 8)
+    #  - x >= 128: nearest lower multiple of 128 (min 128)
+    if x < 128:
+        return max(8, (x // 8) * 8)
+    return max(128, (x // 128) * 128)
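
As a quick sanity check of the suggested helpers (names and behavior taken from the suggestion above, not from merged code):

assert fp8_swap_ab_gen_tuning_buckets(100) == tuple(range(8, 128, 8))  # below 128: only 8-step buckets
assert fp8_swap_ab_gen_tuning_buckets(4096)[-1] == 4096                # the cap is now included
assert _fp8_swap_ab_round_rule(50) == 48                               # floor to a multiple of 8
assert _fp8_swap_ab_round_rule(200) == 128                             # floor to a multiple of 128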

903-907: Wire the round rule into the DynamicTensorSpec.

Without a round rule, AutoTuner may map small M unevenly. Use the rule introduced above.

-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
-        tune_max_num_tokens=4096,
-    )
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, fp8_swap_ab_gen_tuning_buckets, _fp8_swap_ab_round_rule), ),
+        tune_max_num_tokens=4096,
+    )

913-920: Make swap_ab truly tunable (align with PR intent).

The PR claims swap_ab remains tunable, but only [0] is returned, so it isn’t. Return both tactics.

-        # Encode swap_ab as False (0) and True (1). Currently only add one tactic here.
-        return [0]
+        # 0 => swap_ab=False, 1 => swap_ab=True
+        return [0, 1]

921-941: Propagate tactic to deep_gemm and fix typo.

Forward ignores tactic and never sets swap_ab; also a minor typo “detemrmine”. Derive swap_ab from tactic and pass into deep_gemm.

         output = torch.empty(
             (input.size(0), weight.size(0)),
             device=input.device,
             dtype=self.output_dtype,
         )
-        # TODO: add swap_ab=tactic == 0 to detemrmine the swap_ab value
-        # Treat the default tactic=-1 as swap_ab=False
+        # Treat the default tactic=-1 as swap_ab=False; tactic==1 => swap_ab=True
+        swap_ab = (tactic == 1)
         deep_gemm.fp8_gemm_nt(
             (a, a_sf),
             (weight, weight_scale),
             output,
+            swap_ab=swap_ab,
             disable_ue8m0_cast=self.disable_ue8m0_cast,
         )

To confirm the argument is supported across the repo, run:

#!/bin/bash
# Verify deep_gemm.fp8_gemm_nt supports 'swap_ab' and check all call sites.
set -euo pipefail

echo "== Search for fp8_gemm_nt definition or bindings =="
rg -nP -C3 '\bfp8_gemm_nt\b' --type cpp --type c --type py  # the cpp/c types already cover header extensions

echo "== Search for call sites of fp8_gemm_nt =="
rg -nP -C2 'fp8_gemm_nt\s*\('
🧹 Nitpick comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

909-912: Rename class to PascalCase for consistency with guidelines.

Python class names should use PascalCase. Rename and update references.

-class fp8SwapABGemmRunner(TunableRunner):
+class FP8SwapABGemmRunner(TunableRunner):
     tuning_config = TuningConfig(
         ...
     )

Also update references below (outside this hunk):

-    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
+    fp8_swap_ab_gemm_runner = FP8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
     )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+    FP8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens

945-969: Minor: keep class name change consistent and set tune_max before choose_one.

Use the PascalCase class name consistently; also setting tune_max before choose_one is good—keep it that way.

-    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
+    fp8_swap_ab_gemm_runner = FP8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
     )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+    FP8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens

📥 Commits

Reviewing files that changed from the base of the PR and between c7d8c80 and 188e2fa.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
  • tensorrt_llm/_torch/modules/linear.py (1 hunks)
  • tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py
  • tensorrt_llm/_torch/modules/linear.py
🔇 Additional comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

7-8: Imports look correct and align with downstream usage.

fp8_utils and deep_gemm are both used in the new FP8 path; no import leaks or shadowing.


971-981: Fake registration signature and shape inference LGTM.

Signature mirrors the real op; output [M, N] with N=weight.size(0) matches fp8_gemm_nt(nt path) expectation.

@tensorrt-cicd
Collaborator

PR_Github #16131 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #16122 [ run ] completed with state ABORTED

@tensorrt-cicd
Collaborator

PR_Github #16131 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #12134 completed with status: 'ABORTED'

@hyukn
Collaborator Author

hyukn commented Aug 22, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #16138 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #16138 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12139 completed with status: 'SUCCESS'
Pipeline passed with automatically retried tests. Check the rerun report for details.

Member

@yizhang-nv yizhang-nv left a comment

LGTM

@hyukn hyukn merged commit 9c5b464 into NVIDIA:main Aug 25, 2025
5 checks passed