[None][feat] support JIT mha.cu for SPEC_DEC in runtime #6078
Conversation
/bot run --disable-fail-fast

PR_Github #11986 [ run ] triggered by Bot

PR_Github #11986 [ run ] completed with state
6e60be7 to bf8aaa8 (Compare)
📝 Walkthrough

The changes introduce support for a new HMMA kernel type in the decoder masked multi-head attention logic, update the kernel selection criteria to include this kernel, and refine the masking logic in the attention kernel to use a more explicit lowest floating-point value. No public interfaces or exported entity declarations were altered.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Runner as DecoderXQARunner
    participant ImplJIT as DecoderXQAImplJIT
    participant Kernel as HMMA Kernel
    Runner->>Runner: getImplFromXQAParams(params)
    alt Supported by Hopper/MLA/Ampere XQA
        Runner->>ImplJIT: Use JIT Implementation
    else Not Supported
        Runner->>ImplJIT: Use Precompiled Implementation
    end
    ImplJIT->>ImplJIT: runImpl(...)
    alt Speculative Decoding & HMMA Kernel
        ImplJIT->>Kernel: Prepare and launch HMMA kernel with parameters
    else Other Kernel Types
        ImplJIT->>Kernel: Launch other kernel types as before
    end
```
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
Actionable comments posted: 0
🔭 Outside diff range comments (1)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp (1)
81-90: Clarify and tighten Ampere XQA support in decoderXQARunner

The `supportedByAmpereXqa` condition currently only tests `!xqaParams.isMLA()`, enabling JIT for all non-MLA cases (regardless of spec-dec, SM version, or data type). Please:

- Update the comment above to reflect actual Ampere XQA support (pre-compiled cubins vs. JIT).
- Restrict `supportedByAmpereXqa` to spec-dec mode: `xqaParams.multi_query_tokens`.
- Limit to Ampere SM versions (e.g. 80, 86, 87): `(smVersion == 80 || smVersion == 86 || smVersion == 87)`.
- (If needed) Restrict `kv_cache_data_type` similarly to Hopper's E4M3 requirement.

Suggested diff:

```diff
-    bool const supportedByAmpereXqa = (!xqaParams.isMLA());
+    bool const supportedByAmpereXqa =
+        (xqaParams.multi_query_tokens
+            && (smVersion == 80 || smVersion == 86 || smVersion == 87)
+            /* optional: && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 */);
```

File: cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
Lines: ~84–90
🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp (2)
404-406: Fix formatting of ternary operator

The ternary operator formatting is inconsistent with the rest of the codebase, and the comment placement is confusing.

```diff
-    unsigned int maxQSeqLen = xqaParams.spec_decoding_is_generation_length_variable ? // true for ReDrafter
-        xqaParams.spec_decoding_max_generation_length
-        : qSeqLen;
+    // true for ReDrafter
+    unsigned int maxQSeqLen = xqaParams.spec_decoding_is_generation_length_variable
+        ? xqaParams.spec_decoding_max_generation_length
+        : qSeqLen;
```
395-445: Consider extracting HMMA kernel launch logic into a separate method

The new code block for the HMMA kernel launch is quite large and complex. For better maintainability and readability, consider extracting this logic into a separate method, similar to how the MLA kernel launch is handled.

This would improve:

- Code organization and readability
- Testability of the HMMA kernel launch logic in isolation
- Consistency with the existing code structure
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)

- cpp/kernels/xqa/mha.cu (1 hunks)
- cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
cpp/kernels/xqa/mha.cu (1)
508-509: Change avoids NaN, but double-check soft-max semantics when an entire row is masked

Switching from the literal `-INFINITY` to `mha::numeric_limits<float>::lowest()` prevents the `(-inf) - (-inf)` pattern that previously produced `NaN` during `elem * log2e - bias` (good). However, because `lowest()` is finite, a fully-masked row will now receive a finite maximum and will come out of the soft-max with uniform probability mass (each element becomes `exp2(0) = 1`). If the intent is to force such rows to contribute zero to the attention output, you will need an extra guard (e.g., detect `maskFlag == false` for all columns and set the row sum to zero afterwards).

Action items:

- Confirm that "all-masked" rows cannot occur, or that the new behaviour is acceptable.
- Ensure `mha::numeric_limits` is specialised with `__device__ __host__` to avoid host-only instantiation issues inside device code.

Would you like a small unit test that exercises the fully-masked-row edge case?
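A minimal host-side sketch of the edge case above (hypothetical helper, not the actual kernel code): with `-INFINITY`, an all-masked row makes the running maximum `-inf`, so `elem - rowMax` is `(-inf) - (-inf) = NaN`; with `lowest()` the subtraction stays finite but each masked element contributes `exp2(0) = 1`, giving exactly the uniform mass the comment warns about.

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

// Hypothetical helper: one soft-max row sum with masked elements replaced by maskValue.
float softmaxRowSum(float const* logits, bool const* maskFlag, int n, float maskValue)
{
    float rowMax = maskValue;
    for (int i = 0; i < n; ++i)
        rowMax = std::fmax(rowMax, maskFlag[i] ? logits[i] : maskValue);
    float sum = 0.f;
    for (int i = 0; i < n; ++i)
        sum += std::exp2((maskFlag[i] ? logits[i] : maskValue) - rowMax);
    return sum;
}

int main()
{
    float logits[4] = {1.f, 2.f, 3.f, 4.f};
    bool allMasked[4] = {false, false, false, false}; // fully-masked row
    // With -INFINITY: rowMax == -inf, each term is exp2((-inf) - (-inf)) == NaN.
    std::printf("-INFINITY : %f\n", softmaxRowSum(logits, allMasked, 4, -INFINITY));
    // With lowest(): subtraction stays finite, each term is exp2(0) == 1, so the sum is 4.
    std::printf("lowest()  : %f\n",
        softmaxRowSum(logits, allMasked, 4, std::numeric_limits<float>::lowest()));
    return 0;
}
```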
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp (2)
212-217: LGTM! Proper support for HMMA kernel in speculative decoding

The addition of `isHMMAKernel` detection and its inclusion in the spec-dec support check is implemented correctly. This aligns with the PR objective of supporting JIT mha.cu for SPEC_DEC in runtime.
359-445: Review HMMA kernel launch parameters and dimensions

The `kernel_mha` definition in `cpp/kernels/xqa/mha.cu` (SPEC_DEC path) declares arguments in this order:
- qSeqLen
- num_k_heads
- headGrpSize
- SeqLenDataType const* qCuSeqLens
- (optional) uint32_t slidingWinSize
- float qScale
- OutputHead* output
- (optional) float const* rcpOutScale
- IOHead const* q
- MaskType const* mask
- KVCacheList cacheList
- (optional) BeamSearchParams beamSearchParams
- uint32_t batchSize
- float const* kvCacheScale
- uint32_t* semaphores
- void* scratch
In the JIT path (`else if (isSpecDec && isHMMAKernel)` in `decoderXQAImplJIT.cpp`):

- Ensure that every `appendParam(&…)` call lines up exactly with one of the entries above.
- Confirm you are pushing exactly one pointer per non-default argument, in the same order.
- Verify that you account for optional parameters only when their corresponding compile-time flags or runtime conditions match (e.g. slidingWinSize, rcpOutScale, beamSearchParams).
- The blockDim `(128, 1, 2)` yields 256 threads per CTA, matching `__launch_bounds__(256, …)`, and the gridDim `{multi_block, num_kv_heads, batch_size}` should mirror the device code's use of `nbCtaPerSM` and CTA distribution.

Please manually cross-check the `appendParam` sequence in the HMMA branch against the device kernel signature to guarantee parameter count, order, and launch geometry are correct.
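As a sketch of what that cross-check guards (illustrative names only, not the actual TensorRT-LLM helpers): with the CUDA driver API, each `appendParam`-style call must push the address of exactly one kernel argument, and `cuLaunchKernel` reads them positionally, so a missing or reordered optional parameter silently corrupts every argument after it.

```cpp
#include <cuda.h>
#include <vector>

// Illustrative argument packer: stores the address of each kernel argument in
// declaration order; cuLaunchKernel dereferences these positionally. The
// referenced variables must stay alive until the launch is issued.
struct KernelArgs
{
    std::vector<void*> ptrs;
    template <typename T>
    void append(T& arg)
    {
        ptrs.push_back(&arg); // one pointer per kernel parameter, in order
    }
};

// Launch geometry from the review: blockDim (128, 1, 2) = 256 threads per CTA,
// gridDim {multi_block, num_kv_heads, batch_size}.
CUresult launchSpecDecMha(CUfunction fn, KernelArgs& args, unsigned multiBlock,
    unsigned numKvHeads, unsigned batchSize, CUstream stream)
{
    return cuLaunchKernel(fn, multiBlock, numKvHeads, batchSize, // gridDim
        128, 1, 2,                                               // blockDim
        /*sharedMemBytes=*/0, stream, args.ptrs.data(), /*extra=*/nullptr);
}
```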
/bot run --disable-fail-fast

PR_Github #12103 [ run ] triggered by Bot

PR_Github #12103 [ run ] completed with state
bf8aaa8 to 25d2b7d (Compare)
25d2b7d to 2187f6d (Compare)
/bot run --disable-fail-fast

PR_Github #17019 [ run ] triggered by Bot

PR_Github #17019 [ run ] completed with state
c9a91ed to 523c539 (Compare)
/bot run

PR_Github #17402 [ run ] triggered by Bot

PR_Github #17402 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #17413 [ run ] triggered by Bot

PR_Github #17413 [ run ] completed with state
523c539 to 3acdb10 (Compare)
/bot run

PR_Github #17557 [ run ] triggered by Bot

PR_Github #17557 [ run ] completed with state
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.cpp (outdated review comment, resolved)
274577b to aa2f61a (Compare)
/bot run --disable-fail-fast

PR_Github #19216 [ run ] triggered by Bot

PR_Github #19216 [ run ] completed with state
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.cpp (outdated review comment, resolved)
aa2f61a to 9d4909e (Compare)
/bot run --disable-fail-fast

PR_Github #19250 [ run ] triggered by Bot
LGTM😀
PR_Github #19250 [ run ] completed with state
9d4909e to 1e89b71 (Compare)
/bot run --disable-fail-fast

PR_Github #19385 [ run ] triggered by Bot

PR_Github #19385 [ run ] completed with state
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
38c5c63 to ef8f199 (Compare)
/bot run --disable-fail-fast

PR_Github #19717 [ run ] triggered by Bot

PR_Github #19717 [ run ] completed with state
Description
Port the precompiled XQA SPEC_DEC kernel to JIT for faster development, skipping the separate to-cubin compilation step.
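For context, a minimal sketch of the JIT idea (an assumed NVRTC flow with hypothetical option choices; the actual TensorRT-LLM JIT infrastructure is more involved): the kernel source is compiled at runtime, so no offline cubin artifact has to be built and shipped per configuration.

```cpp
#include <cuda.h>
#include <nvrtc.h>
#include <vector>

// Compile CUDA C++ source at runtime and load it as a module.
// Error handling omitted for brevity.
CUmodule jitCompileAndLoad(char const* source)
{
    nvrtcProgram prog;
    nvrtcCreateProgram(&prog, source, "mha.cu", 0, nullptr, nullptr);

    char const* opts[] = {"--gpu-architecture=compute_80"}; // assumed target arch
    nvrtcCompileProgram(prog, 1, opts);

    size_t ptxSize = 0;
    nvrtcGetPTXSize(prog, &ptxSize);
    std::vector<char> ptx(ptxSize);
    nvrtcGetPTX(prog, ptx.data());
    nvrtcDestroyProgram(&prog);

    CUmodule mod;
    cuModuleLoadData(&mod, ptx.data()); // JIT-compiled by the driver, no cubin file
    return mod;
}
```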
Test Coverage
GitHub Bot Help

`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.

run

`run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests. Will also run the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-[Post-Merge]-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".
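For example, a run that disables fail-fast and is restricted to a single test stage, using the example stage name from the help text above:

```
/bot run --disable-fail-fast --stage-list "A10-1"
```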
docs/source/reference/ci-overview.md.kill
killKill all running builds associated with pull request.
skip
skip --comment COMMENTSkip testing for latest commit on pull request.
--comment "Reason for skipping build/test"is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipelineReuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit

New Features

- Support for a new HMMA kernel type in decoder masked multi-head attention, enabling JIT mha.cu for speculative decoding.

Refactor

- Kernel selection criteria updated to include the new HMMA kernel.

Bug Fixes

- Attention masking now uses an explicit lowest floating-point value rather than -INFINITY.