KEMBAR78
[TRTLLM-7027][feat] Fuse d2t to logitsBitmaskKernel and fix a race condition in one-model spec by syuoni · Pull Request #7481 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@syuoni
Copy link
Collaborator

@syuoni syuoni commented Sep 2, 2025

[TRTLLM-7027][feat] Fuse d2t to logitsBitmaskKernel and fix a race condition in one-model spec

Description

This PR introduces a new contiguousLogitsBitmaskKernel, which is more suitable for the current PyTorch GuidedDecoder:

  • Fuses d2t for EAGLE3
  • Supports token-level mask for CUDA graph

In addition, this PR fixes a race condition in one-model speculative decoding (MTP, MTP-Eagle, Eagle3):

  • In PyTorchModelEngine prepare_inputs, attn_metadata performs an async H2D copy for seq_lens -> seq_lens_cuda
  • In MTP worker, attn_metadata.seq_lens is in-place modified.

If the host code runs fast, the second operation could be earlier than the first operation. If so, the data in seq_lens_cuda is corrupted.

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@syuoni syuoni self-assigned this Sep 2, 2025
@syuoni
Copy link
Collaborator Author

syuoni commented Sep 2, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17363 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17363 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #13051 completed with status: 'FAILURE'

@syuoni
Copy link
Collaborator Author

syuoni commented Sep 2, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17367 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17367 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13053 completed with status: 'FAILURE'

@syuoni
Copy link
Collaborator Author

syuoni commented Sep 3, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17535 [ run ] triggered by Bot

@syuoni syuoni marked this pull request as ready for review September 3, 2025 23:29
@syuoni syuoni requested review from a team as code owners September 3, 2025 23:29
@syuoni
Copy link
Collaborator Author

syuoni commented Sep 3, 2025

/bot run --disable-fail-fast

@syuoni syuoni requested review from QiJune, lfr-0531, mikeiovine and yweng0828 and removed request for hlu1, hyukn, kris1025 and symphonylyh September 3, 2025 23:31
@tensorrt-cicd
Copy link
Collaborator

PR_Github #17578 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17535 [ run ] completed with state ABORTED

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 3, 2025

📝 Walkthrough

Walkthrough

Adds a contiguous logits bitmask CUDA kernel and launcher, updates C++/PyTorch bindings to use 2D batched tensors with optional token_mask and d2t, integrates the new op into guided decoding, and introduces speculative-decoding metadata save/restore helpers. Adjusts data types (int32 d2t), revises speculative flows, and updates tests accordingly.

Changes

Cohort / File(s) Summary
CUDA kernels: contiguous logits bitmask
cpp/tensorrt_llm/kernels/logitsBitmask.cu
Adds contiguousLogitsBitmaskKernel, dispatch/launcher, dynamic blocks-per-row, vectorized packing (float4/float2/float/BF16), explicit instantiations, namespace scoping tweak. Preserves masking semantics, supports tokenMask and d2t.
Kernel header
cpp/tensorrt_llm/kernels/logitsBitmask.h
Declares invokeContiguousLogitsBitmask(T*) alongside existing API.
THOP C++ op binding
cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
Switches API to 2D tensors (logits, bitmask) with optional tokenMask/d2t; validates shapes/dtypes/contiguity; dispatches by dtype; uses contiguous launcher and stream from logits.
Python fake op registration
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
Updates trtllm::logits_bitmask signature to (logits, bitmask, token_mask=None, d2t=None).
Guided decoder integration
tensorrt_llm/_torch/pyexecutor/guided_decoder.py
Adds token_mask host/device buffers; extends copy/apply to handle token_mask and optional dynamic bitmask length; replaces xgrammar path with torch.ops.trtllm.logits_bitmask; updates public/internal method signatures and call sites.
Attention metadata spec-dec helpers
tensorrt_llm/_torch/attention_backend/interface.py
Adds _saved_tensors store and methods prepare_for_spec_dec/restore_from_spec_dec for specified tensor fields.
Speculative loops: drafting
tensorrt_llm/_torch/speculative/drafting_loops.py
Uses prepare_for_spec_dec/restore_from_spec_dec; removes manual clones/copies of seq length tensors; centralizes restoration.
Speculative: Eagle3
tensorrt_llm/_torch/speculative/eagle3.py
Adopts spec-dec helpers; replaces last_tokens_idx with spec_metadata.gather_ids; updates prepare_1st_drafter_inputs signature and gather logic; aligns on_update ordering.
Speculative: MTP
tensorrt_llm/_torch/speculative/mtp.py
Replaces manual seq-lens management with prepare/restore helpers across change/restore/update paths; no public API changes.
Model dtype tweak
tensorrt_llm/_torch/models/modeling_speculative.py
Changes self.d2t dtype to torch.int32.
Tests for logits bitmask
tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py
Renames test, adds token_mask handling, adds d2t test with reference construction; adapts to 2D contiguous API and stripes logic.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Py as GuidedDecoder (PyTorch)
  participant Op as torch.ops.trtllm.logits_bitmask
  participant Cpp as THOP C++ Op
  participant Cu as CUDA Kernel

  Py->>Py: Build bitmask & token_mask (CPU)
  Py->>Py: Copy to GPU (bitmask, token_mask)
  Py->>Op: logits_bitmask(logits[batch,vocabP], bitmask[batch,B], token_mask?, d2t?)
  Op->>Cpp: Dispatch by dtype
  Cpp->>Cu: invokeContiguousLogitsBitmask(T*)
  Note over Cu: Applies per-batch tokenMask<br/>maps via d2t if provided<br/>masks logits using bitmask
  Cu-->>Cpp: logits updated in-place
  Cpp-->>Op: return
  Op-->>Py: return
Loading
sequenceDiagram
  autonumber
  participant C as Caller
  participant AM as AttentionMetadata
  participant SD as Speculative Flow

  C->>AM: prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda", "kv_lens_cuda")
  Note over AM: Save originals and clone working tensors
  C->>SD: Perform drafting / MTP steps
  SD->>AM: mutate working fields during run
  C->>AM: restore_from_spec_dec()
  Note over AM: Restore originals and clear saved state
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Suggested labels

Community want to contribute

Suggested reviewers

  • hlu1
  • yuxianq
  • hyukn
  • zhhuang-nv
  • shaharmor98
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore or @coderabbit ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py (1)

40-83: Replace unsupported out= in index_select and enforce d2t as int32 on CUDA
The call

torch.index_select(target_logits, -1, d2t_mapping, out=draft_logits)

will error—index_select doesn’t support the out= argument. Change it to:

-    d2t = torch.randint(0, 3, size=(vocab_size // 4,),
-                        device="cuda").cumsum(dim=0, dtype=torch.int32)
+    d2t = torch.randint(0, 3, size=(vocab_size // 4,), device="cuda") \
+            .cumsum(0).to(torch.int32)

@@
-    torch.index_select(target_logits, -1, d2t_mapping, out=draft_logits)
+    draft_logits.copy_(target_logits.index_select(-1, d2t_mapping))

Also add a parametrized case where d2t_mapping contains duplicate indices to validate correct handling of repeated writes/reads.

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (2)

1-1: Add NVIDIA copyright header.

Coding guideline requires a 2025 NVIDIA copyright header at the top of all source files. Please prepend the repo-standard header.


256-266: Guard against None draft_tokens to avoid TypeError.

req.draft_tokens is Optional[List[int]]. Iterating when it’s None will raise. Safe-guard the loop.

-                for i, tid in enumerate(req.draft_tokens, 1):
+                for i, tid in enumerate((req.draft_tokens or ()), 1):
🧹 Nitpick comments (15)
tensorrt_llm/_torch/models/modeling_speculative.py (1)

171-175: Register non-trainable d2t as a buffer instead of a Parameter
d2t is static; registering it as a buffer excludes it from optimizer and state-dict overhead—existing .data and tensor indexing still work.
Apply:

-            self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
-                                                dtype=torch.int32),
-                                    requires_grad=False)
+            self.register_buffer(
+                "d2t",
+                torch.empty((config.draft_vocab_size,), dtype=torch.int32),
+                persistent=False,
+            )
cpp/tensorrt_llm/kernels/logitsBitmask.h (1)

32-35: Document invokeContiguousLogitsBitmask API contract

  • logits: contiguous row-major [batchSize, vocabSizePadded]
  • bitmask: [batchSize, bitmaskSize] where bitmaskSize = ceilDiv(vocabSizePadded, 32) uint32 elements
  • tokenMask: optional (nullptr allowed); when non-null, 0 at index batchIdx skips that row
  • d2t: optional (nullptr allowed); when non-null, adds per-token index delta
  • batchSize, vocabSizePadded, bitmaskSize: int32 dimensions

Replace #pragma once in logitsBitmask.h with include guards [nitpick]

tensorrt_llm/_torch/attention_backend/interface.py (1)

338-342: Guard restore against empty state and add idempotence

Avoid no-op loops if nothing was saved.

-    def restore_from_spec_dec(self) -> None:
-        for f, v in self._saved_tensors.items():
-            setattr(self, f, v)
-        self._saved_tensors.clear()
+    def restore_from_spec_dec(self) -> None:
+        if not self._saved_tensors:
+            return
+        for f, v in self._saved_tensors.items():
+            setattr(self, f, v)
+        self._saved_tensors.clear()
tensorrt_llm/_torch/speculative/drafting_loops.py (1)

133-136: Avoid .data; use the tensor directly

.data is discouraged; it bypasses safety checks. This tensor is non-trainable, so direct usage is fine.

-        if hasattr(self.draft_model.model, "d2t"):
-            d2t = self.draft_model.model.d2t.data
-            return tokens + d2t[tokens]
+        if hasattr(self.draft_model.model, "d2t"):
+            d2t = self.draft_model.model.d2t
+            return tokens + d2t[tokens]
cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp (2)

27-32: Add validation for empty batch size edge case.

Consider adding an assertion or warning for empty batch to ensure the early return is intentional and documented.

 int32_t const batchSize = logits.size(0);
 if (batchSize == 0)
 {
+    // Early return for empty batch - no work to do
     return;
 }

33-44: Consider extracting common validation logic.

The repeated validation pattern could be refactored into a helper function for better maintainability.

+namespace
+{
+void validateTensor2D(torch::Tensor const& tensor, char const* name, int32_t expectedBatchSize)
+{
+    TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor.");
+    TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous.");
+    TORCH_CHECK(tensor.dim() == 2, name, " must be a 2D tensor.");
+    TORCH_CHECK(tensor.size(0) == expectedBatchSize, name, " must have the same batch size as logits.");
+}
+} // namespace
+
 TORCH_CHECK(bitmask.size(0) == batchSize, "bitmask must have the same batch size as logits.");
 
 int32_t vocabSizePadded = logits.size(1);
 int32_t bitmaskSize = bitmask.size(1);
-TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
-TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
-TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor.");
-TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
-TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
-TORCH_CHECK(bitmask.dim() == 2, "bitmask must be a 2D tensor.");
+validateTensor2D(logits, "logits", batchSize);
+validateTensor2D(bitmask, "bitmask", batchSize);
tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py (1)

12-37: Good coverage of token_mask path; small cleanup possible.

The ref uses slicing [::stride] to mirror token_mask semantics and validates in-place behavior. Consider adding one randomized token_mask case (not only stride-based) to catch irregular masks; also you can drop the redundant re-cast since (% stride == 0) already yields bool:

-    if stride > 1:
-        token_mask = torch.arange(batch_size, dtype=torch.int32,
-                                  device="cuda") % stride == 0
-        token_mask = token_mask.to(torch.int32)
+    if stride > 1:
+        token_mask = (torch.arange(batch_size, device="cuda") % stride == 0).to(torch.int32)
tensorrt_llm/_torch/speculative/eagle3.py (1)

299-299: Guard prepare_for_spec_dec for optional CUDA fields
prepare_for_spec_dec asserts each field is a Tensor; since _seq_lens_cuda defaults to None until seq_lens is set, calling prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") can trigger an AssertionError. Replace with a guarded call:

-        attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
+        fields = ["_seq_lens"]
+        if isinstance(attn_metadata._seq_lens_cuda, torch.Tensor):
+            fields.append("_seq_lens_cuda")
+        attn_metadata.prepare_for_spec_dec(*fields)
cpp/tensorrt_llm/kernels/logitsBitmask.cu (2)

257-292: Static SM count caching may be wrong across multi-GPU contexts.

static int const smCount = getMultiProcessorCount(); is process-static. If device changes between calls, grid sizing can be off. Prefer querying per-call or caching per-device id.

-    static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
+    int const smCount = tensorrt_llm::common::getMultiProcessorCount();

If you want caching, key it by cudaGetDevice() id.


51-63: Packed reinterpret_cast relies on alignment; consider safer packing.

reinterpret_cast<PackedT*> of stack arrays can violate alignment rules and hamper SASS generation on some archs. Use a pack helper that writes through a PackedT lvalue:

-template <typename T, typename PackedT>
-__device__ PackedT packedNegativeInfinity()
+template <typename T, typename PackedT>
+__device__ PackedT packedNegativeInfinity()
 {
-    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
-    T packed[kAlignment];
-#pragma unroll
-    for (int i = 0; i < kAlignment; i++)
-    {
-        packed[i] = negativeInfinity<T>();
-    }
-    return *reinterpret_cast<PackedT*>(packed);
+    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
+    PackedT v;
+    T* elems = reinterpret_cast<T*>(&v);
+#pragma unroll
+    for (int i = 0; i < kAlignment; i++) { elems[i] = negativeInfinity<T>(); }
+    return v;
 }

Similarly, consider using the same pattern where you load/store via PackedT to avoid potential misalignment penalties.

Also applies to: 100-114

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (5)

140-141: Confirm token_mask dtype contract (int32) with the custom op.

If the kernel accepts narrower or bool masks, consider switching to torch.bool/torch.int8 to cut memory bandwidth. Otherwise, add a brief comment stating int32 is required by torch.ops.trtllm.logits_bitmask.


273-282: Add fast-path and sanity checks in _copy_bitmask.

Short-circuit when there’s nothing to copy and assert basic shape/dtype invariants to catch integration errors early.

 def _copy_bitmask(self,
                   requests: GuidedRequests,
                   num_bitmask_tokens: Optional[int] = None) -> None:
     if num_bitmask_tokens is None:
         num_bitmask_tokens = requests.num_bitmask_tokens
+    if num_bitmask_tokens == 0:
+        return
+    # Basic invariants (cheap and helpful in CI)
+    assert self.bitmask.dtype == self.bitmask_dtype
+    assert self.token_mask.dtype == self.token_mask_dtype
+    assert self.bitmask_host.size(0) >= num_bitmask_tokens
+    assert self.token_mask_host.size(0) >= num_bitmask_tokens
     self.bitmask[:num_bitmask_tokens].copy_(
         self.bitmask_host[:num_bitmask_tokens], non_blocking=True)
     self.token_mask[:num_bitmask_tokens].copy_(
         self.token_mask_host[:num_bitmask_tokens], non_blocking=True)

287-299: Gracefully handle op availability and empty work; validate shapes.

Add early return for zero rows, verify the custom op exists, and assert batch-row alignment to fail fast on mismatches.

 @torch.inference_mode()
 def _apply_bitmask(self,
                    requests: GuidedRequests,
                    logits: torch.Tensor,
                    d2t: Optional[torch.Tensor] = None,
                    num_bitmask_tokens: Optional[int] = None) -> None:
     if num_bitmask_tokens is None:
         num_bitmask_tokens = requests.num_bitmask_tokens
+    if num_bitmask_tokens == 0:
+        return
+    # Ensure op is registered
+    if not hasattr(torch.ops, "trtllm") or not hasattr(torch.ops.trtllm, "logits_bitmask"):
+        raise RuntimeError("trtllm.logits_bitmask custom op is not available. Ensure extensions are built/loaded.")
+    # Basic shape checks
+    assert logits.size(0) >= num_bitmask_tokens, (
+        f"logits rows {logits.size(0)} < num_bitmask_tokens {num_bitmask_tokens}")
+    assert self.bitmask.size(0) >= num_bitmask_tokens
+    assert self.token_mask.size(0) >= num_bitmask_tokens
     torch.ops.trtllm.logits_bitmask(
         logits[:num_bitmask_tokens],
         self.bitmask[:num_bitmask_tokens],
         token_mask=self.token_mask[:num_bitmask_tokens],
         d2t=d2t)

If you want, I can push a follow-up commit with these guards.


311-313: Public API: briefly document num_bitmask_tokens override.

Add a one-line docstring explaining when/why callers should pass num_bitmask_tokens (e.g., CUDA stream timing during drafting).


317-323: Public API: include d2t/num_bitmask_tokens in docstring.

The method’s behavior changed; reflect the new parameters in the docstring for discoverability.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 5ff3a65 and f3baa56.

📒 Files selected for processing (11)
  • cpp/tensorrt_llm/kernels/logitsBitmask.cu (3 hunks)
  • cpp/tensorrt_llm/kernels/logitsBitmask.h (1 hunks)
  • cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp (2 hunks)
  • tensorrt_llm/_torch/attention_backend/interface.py (3 hunks)
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py (7 hunks)
  • tensorrt_llm/_torch/speculative/drafting_loops.py (2 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (4 hunks)
  • tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (7)
**/*

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Filenames compiled into a target must be case-insensitively unique

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • cpp/tensorrt_llm/kernels/logitsBitmask.h
  • tensorrt_llm/_torch/speculative/mtp.py
  • tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • cpp/tensorrt_llm/kernels/logitsBitmask.cu
**/*.{h,hpp,hh,hxx,cc,cpp,cxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use spaces, not tabs; indent 4 spaces

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • cpp/tensorrt_llm/kernels/logitsBitmask.h
  • tensorrt_llm/_torch/speculative/mtp.py
  • tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • cpp/tensorrt_llm/kernels/logitsBitmask.cu
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs (Python)
Maintain module namespace on import: prefer from package.subpackage import foo; use foo.Symbol()
Python filenames use snake_case
Python class names use PascalCase
Python functions and methods use snake_case
Python local variables use snake_case; if starting with a number concept, prefix with k (e.g., k_99th_percentile)
Python global variables use G_ prefix with UPPER_SNAKE_CASE
Python constants use UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes
Initialize all externally visible class members in init
For public interfaces, prefer docstrings over comments; comments should be for in-function or file-local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes and variables inline with docstrings immediately after assignment
Avoid reflection when a non-reflective approach suffices
Limit except clauses to specific exceptions where possible
When using try/except for duck-typing, keep try body minimal and move logic to else

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
**/*.{cpp,cc,cxx,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • cpp/tensorrt_llm/kernels/logitsBitmask.h
  • tensorrt_llm/_torch/speculative/mtp.py
  • tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/pyexecutor/guided_decoder.py
  • cpp/tensorrt_llm/kernels/logitsBitmask.cu
**/*.{h,hpp,hh,hxx,cc,cpp,cxx,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cc,cpp,cxx,cu,cuh}: Closing braces of C++ namespaces must include a comment naming the namespace (e.g., } // namespace foo)
Avoid using literals (except 0, nullptr, true, false) directly in logic; use named constants for comparisons
Use Allman brace style in C++
Place semicolon of empty for/while loop on its own line
Use brace-delimited statements for bodies of switch/while/do/for and always brace if/else bodies
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Non-static, externally visible globals use g prefix with lowerCamelCase (e.g., gDontUseGlobalFoos)
Static or anonymous-namespace globals use s prefix with lowerCamelCase (e.g., sMutableStaticGlobal)
Locally visible static variables use s prefix (e.g., static std::once_flag sFlag)
Member variables use m prefix with CamelCase (public may omit but encouraged)
Constants (enums, globals, static consts, function-scope magic numbers) use k prefix with UPPER_SNAKE (e.g., kDIGIT_NUM)
Function-scope non-literal, non-magic constants use normal non-const naming (e.g., const bool pass)
If macros are necessary, name them in UPPER_SNAKE_CASE
Avoid Hungarian notation except allowed app’s hungarian like nb for counts
Constructor parameters conflicting with member names get a trailing underscore (e.g., foo_)
Use uppercase literal suffixes (e.g., 1234L not 1234l)
Format C++ with clang-format (LLVM style), max line length 120; justify any exceptions with clang-format off/on blocks
Use C++-style comments; C comments not allowed except special inline cases; single-line comments use //
Use inline parameter comments in calls when arguments aren’t obvious (e.g., /* checkForErrors = / false)
Disable code with #if/#endif (optionally mnemonic conditions or no-op macros); do not comment out code; avoid dead code
Use the least forceful C++ cast; avoid removing const/volatile; avoid C-style and functional casts (except explicit constructors); cast void
to T* with static_cas...

Files:

  • cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
  • cpp/tensorrt_llm/kernels/logitsBitmask.h
  • cpp/tensorrt_llm/kernels/logitsBitmask.cu
**/*.{cc,cpp,cxx,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cc,cpp,cxx,cu}: Prefer const or constexpr variables over #define for constants in C++
Declare variables const if not modified after initialization
Use smart pointers for heap allocation; prefer unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only exceptionally; avoid deprecated smart pointers
Avoid declaring large functions inline unless there’s a quantifiable benefit; remember in-class definitions are implicitly inline
Every defined function must be referenced at least once; avoid unused methods

Files:

  • cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
  • cpp/tensorrt_llm/kernels/logitsBitmask.cu
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx}: Prefer const or constexpr over #define for constants in C++ headers
Use Doxygen for documenting interfaces; use //! for comments and //!< for member annotations in C++
Use include guards in headers with symbol format TRTLLM__H (no underscores prefix/suffix; filename only)

Files:

  • cpp/tensorrt_llm/kernels/logitsBitmask.h
🧠 Learnings (1)
📚 Learning: 2025-09-02T13:42:44.885Z
Learnt from: pcastonguay
PR: NVIDIA/TensorRT-LLM#7455
File: tensorrt_llm/_torch/pyexecutor/py_executor.py:1852-1860
Timestamp: 2025-09-02T13:42:44.885Z
Learning: In MPI communication within TensorRT-LLM pipeline parallelism, different communication types (tokens, logits, termination sync) must use disjoint tag namespaces to avoid message routing collisions when using the same source/destination patterns.

Applied to files:

  • cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
  • cpp/tensorrt_llm/kernels/logitsBitmask.h
  • cpp/tensorrt_llm/kernels/logitsBitmask.cu
🧬 Code graph analysis (9)
cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp (1)
cpp/tensorrt_llm/kernels/logitsBitmask.cu (5)
  • invokeContiguousLogitsBitmask (295-319)
  • invokeContiguousLogitsBitmask (295-296)
  • invokeContiguousLogitsBitmask (321-322)
  • invokeContiguousLogitsBitmask (323-324)
  • invokeContiguousLogitsBitmask (327-329)
tensorrt_llm/_torch/speculative/drafting_loops.py (1)
tensorrt_llm/_torch/attention_backend/interface.py (2)
  • prepare_for_spec_dec (330-336)
  • restore_from_spec_dec (338-341)
tensorrt_llm/_torch/attention_backend/interface.py (1)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • clear (299-310)
cpp/tensorrt_llm/kernels/logitsBitmask.h (1)
cpp/tensorrt_llm/kernels/logitsBitmask.cu (8)
  • T (34-37)
  • void (66-115)
  • void (185-255)
  • invokeContiguousLogitsBitmask (295-319)
  • invokeContiguousLogitsBitmask (295-296)
  • invokeContiguousLogitsBitmask (321-322)
  • invokeContiguousLogitsBitmask (323-324)
  • invokeContiguousLogitsBitmask (327-329)
tensorrt_llm/_torch/speculative/mtp.py (2)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • attn_metadata (68-69)
tensorrt_llm/_torch/attention_backend/interface.py (3)
  • prepare_for_spec_dec (330-336)
  • restore_from_spec_dec (338-341)
  • on_update (157-167)
tests/unittest/_torch/thop/parallel/test_logits_bitmask_op.py (1)
tensorrt_llm/functional.py (3)
  • arange (1498-1569)
  • cumsum (2411-2536)
  • index_select (2216-2275)
tensorrt_llm/_torch/speculative/eagle3.py (2)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)
  • attn_metadata (68-69)
  • spec_metadata (60-61)
tensorrt_llm/_torch/attention_backend/interface.py (5)
  • prepare_for_spec_dec (330-336)
  • num_contexts (208-209)
  • num_contexts (212-215)
  • restore_from_spec_dec (338-341)
  • on_update (157-167)
tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1)
tensorrt_llm/_utils.py (1)
  • nvtx_range (850-869)
cpp/tensorrt_llm/kernels/logitsBitmask.cu (1)
cpp/include/tensorrt_llm/common/cudaUtils.h (1)
  • getMultiProcessorCount (393-400)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (18)
tensorrt_llm/_torch/attention_backend/interface.py (3)

6-7: Import update is fine


144-146: Private stash for saved tensors looks good


330-337: Torch isn’t available in this environment—please verify manually that tensor.clone() preserves pinned memory. For example, in a Python REPL with PyTorch installed:

import torch
x = torch.empty(4, pin_memory=True)
print("pinned_before", x.is_pinned())  # expect True
y = x.clone()
print("pinned_after", y.is_pinned())   # should also be True
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1)

164-169: LGTM!

The signature change from lists of tensors to single tensors with optional parameters is appropriate and aligns with the new contiguous kernel implementation.

tensorrt_llm/_torch/speculative/mtp.py (4)

883-883: LGTM!

Using the new prepare_for_spec_dec method properly encapsulates the tensor preservation logic for speculative decoding.


907-908: LGTM!

Proper use of the corresponding restore_from_spec_dec method to restore saved tensors.


1171-1171: LGTM!

Consistent application of the spec-dec preparation pattern for the MTP Eagle worker flow.


1283-1284: LGTM!

Properly restores the saved tensors and updates the metadata state after speculative decoding operations.

cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp (5)

24-26: LGTM!

The function signature properly uses at::optional for optional parameters, following PyTorch conventions correctly.


45-54: LGTM!

Proper handling of optional tokenMask parameter with appropriate validation.


56-65: LGTM!

Correct validation and handling of the optional d2t parameter.


69-93: LGTM!

Clean dtype dispatching with appropriate error handling for unsupported types.


100-100: LGTM!

The Torch binding definition correctly reflects the new function signature with optional parameters.

tensorrt_llm/_torch/speculative/eagle3.py (1)

387-389: Restore-on-update ordering is correct.

Calling restore_from_spec_dec() before on_update() maintains metadata invariants for downstream consumers.

cpp/tensorrt_llm/kernels/logitsBitmask.cu (1)

184-197: Early-exit on tokenMask is efficient and correct.

Skipping entire rows when token is disabled saves bandwidth and matches semantics.

tensorrt_llm/_torch/pyexecutor/guided_decoder.py (3)

178-185: Good addition: dedicated token_mask device/host buffers.

Shape matches bitmask rows and pinned host memory enables async copies. Looks correct.


213-214: Zeroing only the active span is efficient.

Clearing token_mask_host[:num_bitmask_tokens] up front avoids stale flags without touching the full buffer.


537-546: Verify the override to len(self.requests) during drafting.

Confirm that len(self.requests) always equals the active row count produced by build() on requests_hostfunc in all draft steps. If there’s any chance of divergence (e.g., filtered/invalid requests), prefer using len(self.requests_hostfunc) here for consistency with the builder.

-            self.copy_bitmask(num_bitmask_tokens=len(self.requests))
+            self.copy_bitmask(num_bitmask_tokens=len(self.requests_hostfunc))
...
-        self.apply_bitmask(logits,
-                           d2t=d2t,
-                           num_bitmask_tokens=len(self.requests))
+        self.apply_bitmask(logits,
+                           d2t=d2t,
+                           num_bitmask_tokens=len(self.requests_hostfunc))

@syuoni syuoni changed the title [TRTLLM-7027][feat] Fuse d2t to logitsBitmaskKernel [TRTLLM-7027][feat] Fuse d2t to logitsBitmaskKernel and fix a race condition in one-model spec Sep 4, 2025
Copy link
Collaborator

@lfr-0531 lfr-0531 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attn_metadata changes LGTM~

@syuoni syuoni requested a review from litaotju September 4, 2025 03:31
syuoni and others added 7 commits September 4, 2025 04:12
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
@syuoni
Copy link
Collaborator Author

syuoni commented Sep 4, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17614 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17578 [ run ] completed with state ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17614 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13241 completed with status: 'SUCCESS'

Copy link
Collaborator

@mikeiovine mikeiovine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the data race; Python side changes look good to me

@syuoni syuoni merged commit 1745102 into NVIDIA:main Sep 4, 2025
5 checks passed
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…ndition in one-model spec (NVIDIA#7481)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants