KEMBAR78
[https://nvbugs/5434424][fix] A quick fix for the wrong output issue of SM89 blocked scaling batched GEMM when the input tensor is non-contiguous. by StudyingShao · Pull Request #7615 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@StudyingShao
Copy link
Collaborator

@StudyingShao StudyingShao commented Sep 8, 2025

Summary by CodeRabbit

  • Bug Fixes
    • Prevents stale data during FP8 GEMM prefetching, improving numerical correctness on some GPUs.
    • Aligns and reshapes FP8 batched quantization scales without changing interfaces.
    • Uses a temporary buffer for FP8 block-scaled BMM on SM 89/90 to avoid incorrect writes.
    • Enforces tensor contiguity when loading FP8 MoE weights to prevent intermittent CPU unpack issues.
  • Stability
    • Enhances reliability and determinism across supported GPU architectures with no public API changes.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
@StudyingShao StudyingShao requested review from a team as code owners September 8, 2025 11:28
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 8, 2025

📝 Walkthrough

Walkthrough

Adds shared-memory zeroing for an additional tile in an FP8 GEMM kernel, reshapes a scale tensor to a 3D, M-aligned layout in a quantization path, introduces a temporary output buffer for SM89/90 in FP8 block-scaling BMM, and enforces tensor contiguity before CPU-side unpacking in fused MoE quantized weight loading.

Changes

Cohort / File(s) Summary
FP8 GEMM kernel init
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh
Clears shared-memory tile tBsSFB during prefetch, alongside existing tAsA, tBsB, tAsSFA zeroing; no API or broader control-flow changes.
FP8 quantize scale layout
cpp/tensorrt_llm/thop/fp8Quantize.cpp
In fp8_batched_quantize_1x128_permute102, aligns M to 4 and reshapes scaleFP8SF from 1D to 3D {b, m_4_align, elementSize/b/m_4_align}; element count unchanged; kernel interface and return type shape semantics updated (scale now 3D).
Attention FP8 BMM temp buffer (SM89/90)
tensorrt_llm/_torch/modules/attention.py
For SM 89/90, allocates a temporary output, invokes TRT op to write into it, then copies back to out; SM100 path unchanged; function signature unchanged.
Fused MoE FP8 unpack contiguity
tensorrt_llm/_torch/modules/fused_moe/quantization.py
Ensures .contiguous() after .cpu() for w31 and w2 shards before unpacking; no signature or downstream logic changes.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Attn as fp8_block_scaling_bmm_out
  participant TRT as TRT FP8 BMM
  participant Tmp as Temp Output

  Caller->>Attn: call(mat1, mat2_fp8, mat2_scale, out, ...)
  alt SM 89/90 path
    Attn->>Attn: allocate Tmp = out.new_empty(...)
    Attn->>TRT: bmm(mat1, mat2_fp8, mat2_scale, Tmp)
    TRT-->>Attn: Tmp filled
    Attn->>Attn: out.copy_(Tmp)
  else other SM paths
    Attn->>TRT: bmm(..., out)
    TRT-->>Attn: out filled
  end
  Attn-->>Caller: out
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10–15 minutes

Suggested labels

Community want to contribute

Suggested reviewers

  • PerkzZheng
  • litaotju
  • chenfeiz0326
  • Barry-Delaney
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)

1108-1115: Ensure contiguity for both W31 and W2 CPU-side unpacking (SM89 path).

You added .contiguous() for w2_weight_shard.cpu() before unpacking, but w31_weight_shard.cpu() above still lacks it. Non-contiguous CPU tensors can cause incorrect unpack or extra copies. Mirror the fix for W31.

-            w31_weight_shard = packer(
-                unpacker(w31_weight_shard.cpu()).T.contiguous()).to(
+            w31_weight_shard = packer(
+                unpacker(w31_weight_shard.cpu().contiguous()).T.contiguous()).to(
                     w31_weight_shard.device)
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/attention.py (1)

571-579: Avoid alias/stride issues by using a temporary output; consider gating to reduce overhead.

The temp buffer fixes wrong results when out (or inputs) are strided/aliased. For performance, allocate the temp only when necessary (e.g., if out is non-contiguous or aliases inputs).

-        output = out.new_empty(out.shape, dtype=out.dtype, device=out.device)
-        torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
-                                                   mat1_scale, mat2_scale, output)
-        out.copy_(output)
+        need_temp = (not out.is_contiguous()) or (out.data_ptr() in (mat1.data_ptr(), mat2_fp8.data_ptr()))
+        target = out.new_empty(out.shape, dtype=out.dtype, device=out.device) if need_temp else out
+        torch.ops.trtllm.fp8_block_scaling_bmm_out(
+            mat1_fp8, mat2_fp8, mat1_scale, mat2_scale, target)
+        if need_temp:
+            out.copy_(target)
cpp/tensorrt_llm/thop/fp8Quantize.cpp (1)

1-15: Update header year per repo guidelines.

Please include the current year in the copyright range.

-/* 
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+/*
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dd9627d and 6eed0cc.

📒 Files selected for processing (4)
  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh (1 hunks)
  • cpp/tensorrt_llm/thop/fp8Quantize.cpp (1 hunks)
  • tensorrt_llm/_torch/modules/attention.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (7)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh
  • cpp/tensorrt_llm/thop/fp8Quantize.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh
  • cpp/tensorrt_llm/thop/fp8Quantize.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
  • cpp/tensorrt_llm/thop/fp8Quantize.cpp
**/*.{h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).

Files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
  • cpp/tensorrt_llm/thop/fp8Quantize.cpp
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/thop/fp8Quantize.cpp
🧠 Learnings (2)
📓 Common learnings
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4616-4626
Timestamp: 2025-08-19T03:35:20.866Z
Learning: In the MOE profiler TMA workspace preparation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu), the overlapping of TMA WS regions for NONE and FINALIZE variants is deliberate design to save memory space, as confirmed by djns99. The comment "reuse the same pointers to save space" reflects this intentional behavior.
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
📚 Learning: 2025-08-19T03:35:20.866Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4616-4626
Timestamp: 2025-08-19T03:35:20.866Z
Learning: In the MOE profiler TMA workspace preparation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu), the overlapping of TMA WS regions for NONE and FINALIZE variants is deliberate design to save memory space, as confirmed by djns99. The comment "reuse the same pointers to save space" reflects this intentional behavior.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh (1)

262-267: Good catch: zeroing SFB tile avoids stale scale data.

Initializing tBsSFB alongside tAsA, tBsB, and tAsSFA addresses potential garbage read before the first cp.async completes.

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
@StudyingShao StudyingShao force-pushed the jiangs/1.1.0rc4/fix_sm89_fp8bmm branch from e3a4628 to fe4f7fc Compare September 8, 2025 11:45
@StudyingShao StudyingShao changed the title [TRTLLM-6874][https://nvbugs/5434424][fix] A quick fix for the wrong output issue of SM89 blocked scaling batched GEMM when the input tensor is non-contiguous. [https://nvbugs/5434424][fix] A quick fix for the wrong output issue of SM89 blocked scaling batched GEMM when the input tensor is non-contiguous. Sep 8, 2025
@StudyingShao
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18040 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18040 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13519 completed with status: 'SUCCESS'

@StudyingShao
Copy link
Collaborator Author

/bot reuse-pipeline

@byshiue byshiue enabled auto-merge (squash) September 9, 2025 07:54
@tensorrt-cicd
Copy link
Collaborator

PR_Github #18188 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18188 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #18040 for commit b07f2cb

@byshiue byshiue disabled auto-merge September 9, 2025 12:25
@byshiue byshiue enabled auto-merge (squash) September 9, 2025 12:26
@byshiue
Copy link
Collaborator

byshiue commented Sep 9, 2025

/bot reuse-pipeline

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18219 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18219 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #18040 for commit 444f449

@byshiue byshiue merged commit cc75939 into NVIDIA:main Sep 9, 2025
5 checks passed
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…of SM89 blocked scaling batched GEMM when the input tensor is non-contiguous. (NVIDIA#7615)

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants