KEMBAR78
[None][feat] Revise the calculation related to TileN in routing of MOE TRTLLM backend by ChristinaZ · Pull Request #8148 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@ChristinaZ
Copy link
Collaborator

@ChristinaZ ChristinaZ commented Oct 5, 2025

Summary by CodeRabbit

  • New Features

    • Added support for configurable tile-based token routing, enabling non–power-of-two tiling alongside existing power-of-two paths.
    • Expanded routing flexibility across expert permutations, offsets, and index sizing with automatic path selection.
  • Bug Fixes

    • Replaced hard runtime failures with safer fallbacks in configuration detection, improving stability.
    • Removed overly strict padding constraints, allowing broader valid configurations.
  • Tests

    • Updated unit tests to include tile token dimension and compute capability parameters.
    • Expanded coverage to validate both power-of-two and tile-based routing paths.

Description

Before TRTLLM backend use mPaddingLog2 to accelerate related calculation with function like divUpLog2(), mulLog2() and divUpMulLog2(). However, now the tileN might be a value like 192, which is not a power of 2. So I have to replace them with functions like mulTileN(), divUpTileN(), and divUpMulTileN().

About the performance, I tried to compare the performance. In general, with this modification, its running time extended slightly. For example, with mPaddingLog2=3 (mTileN=8), kernel routingRenormalize::routingIndicesClusterKernel can observe 3% performance regression.

So I think it's better to add one more template parameter so that it can still use the previous variable mPaddingLog2.

Test Coverage

./tests/unit_tests/kernels/routingKernelsTest
pytest -v -s tests/unittest/_torch/thop/parallel/test_moe.py

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@ChristinaZ
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20647 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20647 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #15593 completed with status: 'FAILURE'

@ChristinaZ
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20700 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20700 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15638 completed with status: 'FAILURE'

@ChristinaZ
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20732 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20732 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15666 completed with status: 'FAILURE'

@ChristinaZ
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20778 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20778 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15704 completed with status: 'SUCCESS'

@ChristinaZ ChristinaZ self-assigned this Oct 9, 2025
@ChristinaZ ChristinaZ marked this pull request as ready for review October 9, 2025 02:12
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 9, 2025

📝 Walkthrough

Walkthrough

Adds tile-based tiling alongside power-of-two padding across routing kernels. Introduces mTileTokensDim and isPow2 template parameter. Implements tile arithmetic helpers and switches computations (CTA counts, limits, offsets, sizes) between pow2 and tile paths. Updates launch macros, runner propagation, and unit tests to use tileTokensDim and revised parameterization.

Changes

Cohort / File(s) Summary of changes
Launch macros and dispatch
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
Adds LAUNCH_TILEN macro keyed on mPaddingLog2; replaces LAUNCH_PDL with LAUNCH_TILEN in routing dispatch branches, affecting dtype/expW and extra-flag paths.
Kernel parameterization and data model
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
Adds mTileTokensDim to DataBase/KernelParamsBase; introduces template bool isPow2_ across KernelParams; exposes static constexpr isPow2; defaults mPaddingLog2 to -1; setBaseParams now propagates mTileTokensDim.
Tile arithmetic helpers (device headers)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh
Adds mulTileN, divUpTileN, divUpMulTileN; switches numCta, mnLimit, offsets, permutedIdxSize to pow2 vs tile branches via constexpr; preserves existing pow2 behavior.
Routing kernels: DeepSeek
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
Adds isPow2-controlled paths for numCta, mnLimit, offsets, permutedIdxSize using TileN variants; removes strict padding check; aligns exclusive-sum/index wiring with dual tiling.
Routing kernels: Llama4
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
Branches computations on isPow2 for numCta, mnLimit, offsets, permutedIdxSize; updates finalExpertOffset calculations; removes padding-log2 < 8 check; retains overall flow.
Routing kernels: Renormalize
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu
Applies isPow2-based branching for counts/limits and permutedIdxSize; updates ExclusiveSum inputs via TileN variants; removes padding check.
Runner and config propagation
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
computeLog2 now returns -1 on non-pow2; propagates routingData.mTileTokensDim for DeepSeekV3, Llama4, Renormalize paths.
Test infra: helpers and params
cpp/tests/unit_tests/kernels/routing/routingTest.h
Adds host/device mulTileN/divUpTileN/divUpMulTileN; extends RoutingKernelTestParam with tileTokensDim and requiredComputeCapability (defaulted); propagates mTileTokensDim in setCommonParams.
Test logic updates to tile path
cpp/tests/unit_tests/kernels/routing/routingTest.cpp
Replaces paddingLog2-based math with tileTokensDim-based (sizes, prefix sums, CTA counts, limits).
Unit tests: param wiring
cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp, cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
Adds tileTokensDim argument (8) to RoutingKernelTestParam calls across tests.
Unit tests: Llama4 params
cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp
Adds tileTokensDim (8) and requiredComputeCapability to RoutingKernelTestParam calls.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Host as Host (runner.cu)
  participant KP as KernelParamsBase/Kernels
  participant Math as Tile/Pow2 helpers
  participant Kern as Routing Kernels

  Host->>KP: setBaseParams(data)\n(mTileTokensDim, mPaddingLog2)
  Note over KP: KP::isPow2 (template constexpr)

  KP->>Kern: launch routing kernels
  alt KP::isPow2 == true
    Kern->>Math: divUpLog2/mulLog2 for\nnumCta, mnLimit, offsets, sizes
  else KP::isPow2 == false
    Kern->>Math: divUpTileN/mulTileN for\nnumCta, mnLimit, offsets, sizes
  end

  Kern-->>Host: results (permutedIdx, sizes)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The pull request description has a clear “Description,” “Test Coverage,” and checklist, but it still contains the template instructions and the @coderabbitai summary placeholder without an actual summary, so it does not fully follow the required template structure. Please replace the @coderabbitai summary placeholder with a concise summary of the changes, remove the template instruction block at the top, and ensure the description begins with the filled‐in summary followed by the required sections.
Docstring Coverage ⚠️ Warning Docstring coverage is 18.18% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title clearly summarizes the main change of replacing padding‐based calculations with tileN‐aware helpers in the MOE TRTLLM backend, follows the repository’s “[None][type] Summary” template, and is concise and specific enough for a reviewer to understand the primary update without extraneous details.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9298f1b and 6dbbe92.

📒 Files selected for processing (12)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h (2 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu (2 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh (7 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h (7 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu (3 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu (2 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu (4 hunks)
  • cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp (8 hunks)
  • cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp (6 hunks)
  • cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp (10 hunks)
  • cpp/tests/unit_tests/kernels/routing/routingTest.cpp (5 hunks)
  • cpp/tests/unit_tests/kernels/routing/routingTest.h (4 hunks)
🧰 Additional context used
📓 Path-based instructions (7)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh
  • cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
  • cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp
  • cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
  • cpp/tests/unit_tests/kernels/routing/routingTest.h
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh
  • cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
  • cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp
  • cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
  • cpp/tests/unit_tests/kernels/routing/routingTest.h
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh
  • cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
  • cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp
  • cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
  • cpp/tests/unit_tests/kernels/routing/routingTest.h
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.h
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.cpp
  • cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
  • cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp
  • cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.h
**/*.{h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.h
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
  • cpp/tests/unit_tests/kernels/routing/routingTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh
  • cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
  • cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp
  • cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
  • cpp/tests/unit_tests/kernels/routing/routingTest.h
🧠 Learnings (3)
📚 Learning: 2025-09-19T21:28:13.751Z
Learnt from: jhaotingc
PR: NVIDIA/TensorRT-LLM#7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.

Applied to files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
📚 Learning: 2025-08-20T07:43:36.447Z
Learnt from: ChristinaZ
PR: NVIDIA/TensorRT-LLM#7068
File: cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh:169-172
Timestamp: 2025-08-20T07:43:36.447Z
Learning: In TensorRT-LLM MOE kernels, when processing up to 128 experts across 32 threads, each thread handles at most 4 experts (N < 5 constraint), where N represents candidates per thread rather than total system capacity.

Applied to files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
📚 Learning: 2025-08-08T22:03:40.707Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.

Applied to files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu
🧬 Code graph analysis (5)
cpp/tests/unit_tests/kernels/routing/routingTest.cpp (1)
cpp/tests/unit_tests/kernels/routing/routingTest.h (5)
  • tileTokensDim (238-238)
  • topK (231-231)
  • divUpMulTileN (89-92)
  • divUpTileN (82-85)
  • mulTileN (75-78)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu (1)
cpp/tests/unit_tests/kernels/routing/routingTest.h (6)
  • divUpLog2 (59-62)
  • divUpTileN (82-85)
  • divUpMulLog2 (67-70)
  • divUpMulTileN (89-92)
  • mulLog2 (51-54)
  • mulTileN (75-78)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu (1)
cpp/tests/unit_tests/kernels/routing/routingTest.h (4)
  • divUpLog2 (59-62)
  • divUpTileN (82-85)
  • mulLog2 (51-54)
  • mulTileN (75-78)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu (1)
cpp/tests/unit_tests/kernels/routing/routingTest.h (6)
  • divUpLog2 (59-62)
  • divUpTileN (82-85)
  • mulLog2 (51-54)
  • mulTileN (75-78)
  • divUpMulLog2 (67-70)
  • divUpMulTileN (89-92)
cpp/tests/unit_tests/kernels/routing/routingTest.h (1)
cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu (1)
  • __host__ (589-592)

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
@ChristinaZ
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #21474 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #21474 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16213 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@byshiue byshiue merged commit db1c271 into NVIDIA:main Oct 16, 2025
5 checks passed
govind-ramnarayan pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Oct 21, 2025
…E TRTLLM backend (NVIDIA#8148)

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants