[TRTLLM-5966][feat] Helix: make softmax stats pointer available to attention gen by MatthiasKohl · Pull Request #6865 · NVIDIA/TensorRT-LLM · GitHub

Conversation

@MatthiasKohl
Collaborator

@MatthiasKohl MatthiasKohl commented Aug 13, 2025

Summary by CodeRabbit

  • New Features

    • Enabled optional collection of attention softmax statistics across generation and prefill modes, providing max and normalization metrics per token/head.
  • Bug Fixes

    • Added validation for softmax stats tensor shape with clear error messages, preventing misconfiguration issues.
  • Refactor

    • Unified softmax statistics handling through a single parameter across paths for consistent behavior.
    • Aligned public parameter naming for clarity and consistency.

Description

This PR makes the softmax stats available for tllm-gen attention in generation mode, not just in context mode.
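
For background: the stats in question are the per-(token, head) running max and log-sum-exp (LSE) of the attention softmax, which are exactly what is needed to merge partial attention outputs computed over different chunks or Helix ranks. A minimal sketch of that standard merge, illustrative only and not code from this PR:

#include <algorithm>
#include <cmath>

// One partial attention result for a single (token, head) pair, carrying its
// softmax stats. Names are hypothetical; real kernels track vectors of values.
struct PartialAttn
{
    float maxLogit; // running max of the attention logits
    float sumExp;   // sum of exp(logit - maxLogit), the softmax denominator
    float out;      // normalized partial attention output (one element shown)
};

// Merge two partials into the result that full attention would have produced.
PartialAttn merge(PartialAttn const& a, PartialAttn const& b)
{
    float const m = std::max(a.maxLogit, b.maxLogit);
    float const sa = a.sumExp * std::exp(a.maxLogit - m); // rescale denominators
    float const sb = b.sumExp * std::exp(b.maxLogit - m); // to the common max
    return {m, sa + sb, (a.out * sa + b.out * sb) / (sa + sb)};
}

(Kernels often store log(sumExp) rather than sumExp for numerical range; the merge is equivalent after exponentiating.)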

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
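
For example, a typical invocation that reruns a single stage with fail-fast disabled might look like this (the stage name is illustrative):

/bot run --stage-list "A10-PyTorch-1" --disable-fail-fast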

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@coderabbitai
Contributor

coderabbitai bot commented Aug 13, 2025

📝 Walkthrough

Walkthrough

Renamed and unified the softmax statistics parameter to softmax_stats, added it to shared enqueue parameters, and wired it through context and generation paths. Adjusted validation and printing, and updated FMHA/MLA code to consume the new field.

Changes

  • API and struct updates (cpp/tensorrt_llm/common/attentionOp.h):
    Added float2* softmax_stats to EnqueueParams; removed void* softmaxStatsPtr from EnqueueContextParams; updated the string representation to reference softmax_stats.
  • Runtime wiring (cpp/tensorrt_llm/common/attentionOp.cpp):
    Propagated softmax_stats through the generation and FMHA paths; switched to underscore field naming; forwarded the stats pointer in MLA generation and FMHA contexts.
  • THOP bindings and validation (cpp/tensorrt_llm/thop/attentionOp.cpp):
    Consumed the optional softmax_stats_tensor; validated that its last dim == 2; stored it as float2* in common_enqueue_params.softmax_stats; removed per-stage softmaxStatsPtr usage.
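
In code terms, the new shared field presumably looks something like this (a sketch assuming the layout documented in the review below, not the verbatim header):

#include <cuda_runtime.h> // float2

// Sketch of the shared enqueue-parameter struct after this change (simplified).
struct EnqueueParams
{
    // Optional device buffer of shape [num_tokens, num_heads_q, 2]; the last
    // dim packs the softmax max and log-sum-exp per (token, head).
    // nullptr when stats collection is disabled.
    float2* softmax_stats = nullptr;
};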

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant THOP AttentionOp
  participant Common AttentionOp
  participant Kernel (FMHA/MLA)

  Caller->>THOP AttentionOp: enqueue(..., softmax_stats_tensor?)
  THOP AttentionOp->>THOP AttentionOp: validate tensor shape (last dim == 2)
  THOP AttentionOp->>Common AttentionOp: common_enqueue_params.softmax_stats = float2*
  Common AttentionOp->>Kernel (FMHA/MLA): launch with params.softmax_stats
  Kernel (FMHA/MLA)-->>Common AttentionOp: read/write softmax stats
  Common AttentionOp-->>Caller: completion
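Continuing the struct sketch above, the validate-then-cast step in the diagram corresponds to code along these lines (a sketch: setSoftmaxStats is a hypothetical helper, and the actual checks in cpp/tensorrt_llm/thop/attentionOp.cpp may be stricter):

#include <cuda_runtime.h> // float2
#include <optional>
#include <torch/extension.h>

// Validate the optional stats tensor and store its pointer on the shared params.
void setSoftmaxStats(EnqueueParams& params, std::optional<torch::Tensor> const& t)
{
    if (t.has_value())
    {
        TORCH_CHECK(t->size(-1) == 2, "softmax_stats_tensor last dim must be 2");
        params.softmax_stats = reinterpret_cast<float2*>(t->data_ptr<float>());
    }
}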

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • zhou-yuxin
  • bobboli
  • jinyangyuan-nvidia


@MatthiasKohl MatthiasKohl force-pushed the user/mjoux/helix-attention-gen branch from 3951b54 to da4a317 Compare August 13, 2025 11:27
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
cpp/tensorrt_llm/common/attentionOp.h (2)

111-115: Clarify softmax_stats layout and constraints; the current comment is misleading

The buffer is validated as having a last dim of size 2 in the THOP bridge, but the comment here says shape [num_tokens, num_heads_q] and mentions “LSE/denominator,” which is ambiguous. Please document the expected shape and dtype to prevent misuse.

Apply this diff to tighten the documentation:

-        // optional when compute attention stats (MLA chunked prefill or Helix parallelism)
-        // this is a buffer of size [num_tokens, num_heads_q] with each element
-        // representing the max and LSE/denominator of the softmax values
+        // Optional: compute attention stats (MLA chunked prefill or Helix parallelism).
+        // Device buffer with shape [num_tokens, num_heads_q, 2], where the last dimension packs:
+        //   [0] = max, [1] = log-sum-exp (LSE) of the softmax denominator.
+        // The tensor must be float32 and contiguous with the last dimension exactly 2.
         float2* softmax_stats = nullptr;
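
To make that layout concrete: with shape [num_tokens, num_heads_q, 2] packed as float2, a producer would index the buffer per (token, head) roughly as follows (illustrative only, not repo code):

#include <cuda_runtime.h> // float2, make_float2

// Write one (max, LSE) pair into the [num_tokens, num_heads_q, 2] buffer.
inline void writeSoftmaxStats(
    float2* softmaxStats, int token, int head, int numHeadsQ, float rowMax, float logSumExp)
{
    // The two leading dims flatten to token * numHeadsQ + head; float2 covers the last dim.
    softmaxStats[token * numHeadsQ + head] = make_float2(rowMax, logSumExp);
}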

182-182: Rename print label to match the new field name

The debug string label still uses the old camelCase name. Update it to keep diagnostics consistent.

Apply this diff:

-            ss << "softmaxStatsPtr: " << this->softmax_stats << std::endl;
+            ss << "softmax_stats: " << this->softmax_stats << std::endl;
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0fad602 and da4a317.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/common/attentionOp.cpp (2 hunks)
  • cpp/tensorrt_llm/common/attentionOp.h (2 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (1 hunks)
🔇 Additional comments (2)
cpp/tensorrt_llm/common/attentionOp.cpp (2)

1773-1774: Good: Context FMHA now receives softmax stats

Passing params.softmax_stats into fmhaParams.softmaxStatsPtr for context FMHA is consistent with the new field.


1044-1046: FMHA-based runner already wires softmaxStatsPtr

The FMHA‐variant generation path already exposes and sets softmaxStatsPtr, so parity is ensured. No further changes needed:

  • In fmhaRunnerParams.h, TllmGenFmhaRunnerParams defines float2* softmaxStatsPtr;.
  • In fmhaDispatcher.cpp (line 203), tllmRunnerParams.softmaxStatsPtr = reinterpret_cast<float2*>(runnerParams.softmax_stats);

@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #15134 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #15134 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11429 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

Collaborator

@brb-nv brb-nv left a comment


LGTM.

@brb-nv brb-nv requested review from hlu1, liji-nv and yuxianq and removed request for liji-nv and yuxianq August 14, 2025 15:20
@hlu1 hlu1 requested review from PerkzZheng and yuxianq August 14, 2025 16:10
@MatthiasKohl MatthiasKohl force-pushed the user/mjoux/helix-attention-gen branch from da4a317 to 258082a Compare September 16, 2025 16:22
@MatthiasKohl
Collaborator Author

@yuxianq @PerkzZheng I believe I addressed your comments, can you quickly review again, please?

@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #18812 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #18812 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #14105 completed with status: 'FAILURE'

@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #18913 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #18913 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #14179 completed with status: 'FAILURE'

@MatthiasKohl MatthiasKohl force-pushed the user/mjoux/helix-attention-gen branch from 258082a to 925c6be Compare September 17, 2025 10:01
@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #18979 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #18979 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #14228 completed with status: 'FAILURE'

@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #19003 [ run ] triggered by Bot

Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
@MatthiasKohl MatthiasKohl force-pushed the user/mjoux/helix-attention-gen branch from b7de48f to 98778f9 Compare September 17, 2025 14:35
@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #19012 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #19003 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #14249 (Blue Ocean) completed with status: ABORTED

@MatthiasKohl
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #19019 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #19012 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #14256 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd
Collaborator

PR_Github #19019 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14262 completed with status: 'FAILURE'

@brb-nv
Collaborator

brb-nv commented Sep 17, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #19036 [ run ] triggered by Bot

@brb-nv
Collaborator

brb-nv commented Sep 17, 2025

Previous unrelated failure waived an hour ago.
#7812

@brb-nv
Collaborator

brb-nv commented Sep 17, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #19040 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #19036 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #14277 (Blue Ocean) completed with status: ABORTED

@brb-nv brb-nv enabled auto-merge (squash) September 17, 2025 19:12
@tensorrt-cicd
Collaborator

PR_Github #19040 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14283 completed with status: 'SUCCESS'

@brb-nv brb-nv merged commit 022d778 into NVIDIA:main Sep 17, 2025
5 checks passed
liji-nv added a commit to liji-nv/TensorRT-LLM that referenced this pull request Sep 19, 2025
liji-nv added a commit to liji-nv/TensorRT-LLM that referenced this pull request Sep 19, 2025
…le to attention gen (NVIDIA#6865)"

This reverts commit 022d778.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…tention gen (NVIDIA#6865)

Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com>
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
…tention gen (NVIDIA#6865)

Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com>
MatthiasKohl added a commit to MatthiasKohl/TensorRT-LLM that referenced this pull request Sep 30, 2025
…tention gen (NVIDIA#6865)

Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com>
