[TRTLLM-7153] [feat] Move stop_criteria to sample_async by netanel-haber · Pull Request #7041 · NVIDIA/TensorRT-LLM · GitHub


@netanel-haber
Collaborator

@netanel-haber netanel-haber commented Aug 19, 2025

Moves the computation of finish reasons for a request batch from the CPU-side handle_requests to the GPU-side sample_async in TorchSampler, for all but one code path.

Perf results: `Llama-3.1-8B-Instruct`

bench.sh
profiles.zip
Snapshot of profiles.zip (note the similarity between the two profiles, that the overlap is preserved in branch.nsys-rep, and that there are no syncs in sample_async):
image

Model:                  llama-3.1-model/Llama-3.1-8B-Instruct/
Model Path:             /home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct
Number of requests:             256
Number of concurrent requests:  206.0179
Average Input Length (tokens):  40.0000
Average Output Length (tokens): 2048.0000
Max Runtime Batch Size: 256
Max Runtime Tokens:     131072
Scheduling Policy:      GUARANTEED_NO_EVICT
KV Memory Percentage:   90.00%

BRANCH:
  Request Throughput (req/sec):                     3.0419
  Total Output Throughput (tokens/sec):             6229.7218
  Total Token Throughput (tokens/sec):              6351.3960
  Total Latency (ms):                               84159.1355
  Average request latency (ms):                     67727.6734
  Per User Output Throughput [w/ ctx] (tps/user):   30.5880
  Per GPU Output Throughput (tps/gpu):              6229.7218

MAIN:
  Request Throughput (req/sec):                     3.0536
  Total Output Throughput (tokens/sec):             6253.8716
  Total Token Throughput (tokens/sec):              6376.0176
  Total Latency (ms):                               83834.1480
  Average request latency (ms):                     67517.2946
  Per User Output Throughput [w/ ctx] (tps/user):   30.6809
  Per GPU Output Throughput (tps/gpu):              6253.8716
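
To make the comparison between the two runs easier to read, the following sketch computes the relative change of BRANCH against MAIN. The figures are copied from the output above; the helper name `relative_change_pct` and the dictionary layout are illustrative assumptions, not part of bench.sh.

```python
# Figures copied verbatim from the BRANCH and MAIN blocks above.
BRANCH = {
    "Request Throughput (req/sec)": 3.0419,
    "Total Output Throughput (tokens/sec)": 6229.7218,
    "Total Latency (ms)": 84159.1355,
    "Average request latency (ms)": 67727.6734,
    "Per User Output Throughput (tps/user)": 30.5880,
}
MAIN = {
    "Request Throughput (req/sec)": 3.0536,
    "Total Output Throughput (tokens/sec)": 6253.8716,
    "Total Latency (ms)": 83834.1480,
    "Average request latency (ms)": 67517.2946,
    "Per User Output Throughput (tps/user)": 30.6809,
}


def relative_change_pct(branch_value: float, main_value: float) -> float:
    """Signed change of the branch value relative to main, in percent."""
    return (branch_value - main_value) / main_value * 100.0


deltas = {name: relative_change_pct(BRANCH[name], MAIN[name]) for name in MAIN}

for name, pct in deltas.items():
    print(f"{name:<40s} {pct:+.2f}%")
```

With these numbers every metric moves by less than half a percent, which is consistent with the claim that the two profiles are similar.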

@coderabbitai
Contributor

coderabbitai bot commented Aug 19, 2025

📝 Walkthrough


Introduces Torch-backed stores (MTPStore/TorchStore), refactors MTPSampler/TorchSampler to use external stores, adds single-beam stop utilities in sampler_utils, threads FinishReason tensors through sampling, and adds a unit test for finish-reason writing.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Speculative sampler refactor: tensorrt_llm/_torch/speculative/mtp.py | MTPSampler now uses external MTPStore created with max_draft_len, max_num_sequences, max_beam_width; removed inner Store dataclass and super().__init__ usage; added assertion for request.py_seq_slot; replaced per-beam logic with BEAM_0 and handle_stop_1_beam; sampler stores max_seq_len. |
| Torch backend sampler & store: tensorrt_llm/_torch/pyexecutor/sampler.py | Added TorchStore (token buffers + finish_reasons), removed inner Store and create_store; introduced SampleStateTensorsHostTorch / SampleStateTorch; sampler now constructs TorchStore, computes seq_slots_host/seq_slots, uses SINGLE_BEAM_WIDTH/BEAM_0, threads finish-reason computation through sampling, and returns host finish_reasons. |
| Sampling utilities (single-beam): tensorrt_llm/_torch/pyexecutor/sampler_utils.py | New module adding BEAM_0, SINGLE_BEAM_WIDTH and functions max_token_criteria_1_beam, stop_token_criteria, handle_stop_1_beam to evaluate stop/finish conditions and set FinishReason for single-beam sampling. |
| Tests for finish reasons: tests/unittest/_torch/test_torch_sampler.py | New unit test test_write_finish_reasons constructing varied LlmRequest cases and asserting TorchSampler._write_finish_reasons populates store.finish_reasons as expected (uses CUDA tensors, checks BEAM_0). |
| Public imports / symbols updated: tensorrt_llm/_torch/speculative/mtp.py, tensorrt_llm/_torch/pyexecutor/sampler.py | Added/used imports Sampler, TorchStore, MTPStore, BEAM_0, SINGLE_BEAM_WIDTH, handle_stop_1_beam; adjusted indexing/shape conventions to single-beam semantics across modules. |
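
As a rough illustration of what a single-beam stop check of this kind might look like, here is a minimal pure-Python sketch. It is not the code in sampler_utils.py: `ToyRequest` is a hypothetical stand-in for LlmRequest, the `FinishReason` values are illustrative, and the check order (END_ID, then LENGTH, then STOP_WORDS) is an assumption.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class FinishReason(Enum):
    # Stand-in for the real bindings enum; numeric values are illustrative.
    NOT_FINISHED = 0
    END_ID = 1
    STOP_WORDS = 2
    LENGTH = 3


@dataclass
class ToyRequest:
    # Hypothetical, simplified stand-in for LlmRequest (single beam only).
    prompt_len: int
    max_new_tokens: int
    end_id: int
    stop_words: List[List[int]] = field(default_factory=list)
    tokens: List[int] = field(default_factory=list)  # prompt + generated tokens


def max_token_criteria(req: ToyRequest, max_seq_len: int) -> bool:
    """True once the generation budget or the sequence limit is reached."""
    num_tokens = len(req.tokens)
    return (num_tokens - req.prompt_len >= req.max_new_tokens
            or num_tokens >= max_seq_len)


def stop_token_criteria(req: ToyRequest) -> bool:
    """True if the token sequence ends with any configured stop word."""
    return any(
        len(word) <= len(req.tokens) and req.tokens[-len(word):] == word
        for word in req.stop_words)


def handle_stop(req: ToyRequest, new_token: int,
                max_seq_len: int) -> FinishReason:
    """Append new_token, then check END_ID, LENGTH, STOP_WORDS in order."""
    req.tokens.append(new_token)
    if new_token == req.end_id:
        return FinishReason.END_ID
    if max_token_criteria(req, max_seq_len):
        return FinishReason.LENGTH
    if stop_token_criteria(req):
        return FinishReason.STOP_WORDS
    return FinishReason.NOT_FINISHED


# Two steps for one request: the second token is the end id and also
# exhausts max_new_tokens, so END_ID is reported rather than LENGTH.
req = ToyRequest(prompt_len=3, max_new_tokens=2, end_id=7, tokens=[1, 2, 3])
reasons = [handle_stop(req, tok, max_seq_len=16) for tok in (5, 7)]
```

In the PR itself the per-step reasons are written into a finish_reasons tensor rather than a Python list; the list here only stands in for that buffer.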

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Client
  participant MTPS as MTPSampler
  participant Store as MTPStore
  participant Req as LlmRequest
  participant Utils as sampler_utils
  Note right of Store #ddeeff: Store holds token buffers\nand finish_reasons tensors

  Client->>MTPS: init(..., max_draft_len, max_num_sequences, max_beam_width)
  MTPS->>Store: MTPStore.__init__(...)
  Client->>MTPS: sample_async(requests)
  loop per-step
    MTPS->>MTPS: generate new_token(s)
    MTPS->>Store: write next token(s)
    MTPS->>Req: add_token(new_token)
    MTPS->>Utils: handle_stop_1_beam(req, new_token, max_seq_len)
    alt stop == True
      MTPS->>Store: write finish_reasons
      MTPS-->>Client: return finished SampleState (host.finish_reasons)
    else continue
    end
  end
sequenceDiagram
  autonumber
  participant User
  participant TorchSampler
  participant TorchStore
  participant Utils as sampler_utils
  Note right of TorchStore #ddeeff: TorchStore: new_tokens, finish_reasons

  User->>TorchSampler: init(args...)
  TorchSampler->>TorchStore: TorchStore.__init__(...)
  User->>TorchSampler: sample_async(requests)
  loop per-step
    TorchSampler->>TorchSampler: compute new_tokens
    TorchSampler->>TorchStore: store new_tokens (device tensors)
    TorchSampler->>Utils: handle_stop_1_beam(request, new_token, max_seq_len)
    alt stop
      TorchSampler->>TorchStore: write finish_reasons
      TorchSampler-->>User: return SampleStateTorch (host.finish_reasons)
    else continue
    end
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

Community want to contribute

Suggested reviewers

  • dcampora
  • kris1025
  • DomBrown
  • achartier
  • yuxianq

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
tensorrt_llm/_torch/speculative/mtp.py (5)

1-1: Add mandatory NVIDIA copyright header (2025) at file top.

Per coding guidelines, all source files must prepend the current-year NVIDIA copyright header.

Apply at the very top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
 from dataclasses import dataclass

613-616: Bug: spec_metadata.slot_ids no longer exists; use manager.slot_ids.

After migrating slot-id storage to the manager, this reference will raise AttributeError and break updates.

-            slot_ids = spec_metadata.slot_ids[:batch_size]
+            slot_ids = spec_metadata.mtp_hidden_states_manager.slot_ids[:batch_size]
             mtp_tokens = mtp_past_tokens_pool[slot_ids]
             mtp_hidden_states = mtp_past_hidden_states_pool[slot_ids]

808-823: Bug: Use manager.slot_ids for relaxed acceptance; spec_metadata.slot_ids is gone.

Both the delta write and the op call must consume the manager-managed slot-ids.

-            ctx_slot_ids = spec_metadata.slot_ids[:num_contexts]
+            ctx_slot_ids = spec_metadata.mtp_hidden_states_manager.slot_ids[:num_contexts]
             mtp_relaxed_delta_pool.index_copy_(0, ctx_slot_ids, ctx_delta)
@@
-            accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op(
-                spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens,
+            accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op(
+                spec_metadata.mtp_hidden_states_manager.slot_ids, topk_value, topk_indices, draft_tokens,
                 mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens,
                 mtp_num_modules, batch_size, num_contexts,
                 self.spec_config.relaxed_topk, self.spec_config.relaxed_delta,
                 self.spec_config.BEGIN_THINKING_PHASE_TOKEN,
                 self.spec_config.END_THINKING_PHASE_TOKEN)

972-980: Bug: THOP path still references spec_metadata.* pointer tensors; switch to manager.*

These attributes were migrated; calling the op with stale fields will fail at runtime.

             (return_input_ids, return_hidden_states
              ) = torch.ops.trtllm.mtp_prepare_drafter_inputs_op(
                  input_ids, attn_metadata.seq_lens_cuda,
-                 spec_metadata.mtp_hidden_states_ptrs,
-                 spec_metadata.mtp_past_tokens_ptrs, hidden_states,
+                 spec_metadata.mtp_hidden_states_manager.hidden_states_ptrs,
+                 spec_metadata.mtp_hidden_states_manager.past_tokens_ptrs, hidden_states,
                  accepted_tokens, num_accepted_tokens, return_input_ids,
                  return_hidden_states, mtp_num_modules, batch_size,
                  num_contexts, hidden_size)

1001-1012: Bug: Use manager.slot_ids when assembling generation drafter inputs.

This path selects per-request history by slot; spec_metadata no longer owns slot_ids.

-                slot_ids = spec_metadata.slot_ids[num_contexts:batch_size]
+                slot_ids = spec_metadata.mtp_hidden_states_manager.slot_ids[num_contexts:batch_size]
                 gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens]
                 gen_token_idx = num_accepted_tokens[num_contexts:] - 1
                 accepted_tokens_gen = accepted_tokens[num_contexts:, :]
                 input_ids_gen = accepted_tokens_gen[gen_batch_idx,
                                                     gen_token_idx].unsqueeze(1)
                 input_ids_gen = torch.concat(
                     [mtp_past_tokens_pool[slot_ids][:, 1:], input_ids_gen],
                     dim=1)
                 hidden_states_gen = mtp_past_hidden_states_pool[
                     slot_ids].flatten(0, 1)
🧹 Nitpick comments (1)
tensorrt_llm/_torch/speculative/mtp.py (1)

266-311: PR objective not reflected: stop_criteria still enforced in update_requests, not sample_async.

Title says "Move stop_criteria to sample_async" but stop handling remains in update_requests via _handle_stop_criteria. If the intent is to shift stop checks into sample_async (for earlier host-side decision-making or overlap), please confirm the desired behavior and update accordingly to avoid double-processing.

Would you like me to sketch the minimal move (calling _handle_stop_criteria per-request as soon as device tensors are placed, using outputs['new_tokens']/['new_tokens_lens']), or a fuller refactor that returns per-request stop flags from sample_async to drive the scheduler?

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 1966730 and ceafaf2.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/speculative/mtp.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/mtp.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/mtp.py
🔇 Additional comments (3)
tensorrt_llm/_torch/speculative/mtp.py (3)

600-607: Good: Centralized pointer usage guarded by manager presence.

The assertion plus passing manager.hidden_states_ptrs and manager.past_tokens_ptrs into the op aligns with the migration to MTPHiddenStatesManager.


784-785: Good: Guard relaxed-acceptance path behind manager presence.

This ensures the delta pool is available before use.


953-957: Good: Manager asserted and pools sourced from manager.

This is consistent with the centralization effort.

@netanel-haber netanel-haber force-pushed the user/nhaber/feature/trtllm-7153-move-stop-critera-to-sample-async branch from ceafaf2 to e3d1d9d Compare August 19, 2025 13:50
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/speculative/mtp.py (2)

271-280: Range expects int, not a tensor

num_new_tokens is a 0-d tensor; range(num_new_tokens) will raise TypeError. Convert to int.

Apply:

-            num_new_tokens = new_tokens_lens[req.py_seq_slot]
+            num_new_tokens = int(new_tokens_lens[req.py_seq_slot].item())
             for i in range(num_new_tokens):
                 new_token = add_token(req, new_tokens, beam=beam_idx, step=i)
                 if self._handle_stop_criteria(req, new_token):
                     break

169-198: Critical: Align mtp_slot_ids dtype with manager.slot_ids (torch.long)

Found that in tensorrt_llm/_torch/speculative/mtp.py the temporary mtp_slot_ids tensor is created with dtype=torch.int (32-bit) but is later copied into manager.slot_ids (64-bit). This mismatch will break copy_ or cause silent truncation.

• File: tensorrt_llm/_torch/speculative/mtp.py
Lines 194–196: change dtype from torch.int to torch.long

Suggested diff:

- mtp_slot_ids = torch.tensor(mtp_slot_ids,
-                             dtype=torch.int,
-                             pin_memory=True)
+ mtp_slot_ids = torch.tensor(mtp_slot_ids,
+                             dtype=torch.long,
+                             pin_memory=True)

Please update this and verify any other producers/consumers (e.g., in custom ops or host-side tensor builds) use torch.long for slot IDs.

♻️ Duplicate comments (1)
tensorrt_llm/_torch/speculative/mtp.py (1)

169-198: Unify slot_ids dtype when copying into manager.slot_ids

You build mtp_slot_ids with dtype=torch.int (int32) and then copy_ into manager.slot_ids (int64). copy_ requires matching dtypes and will error.

Apply:

-            mtp_slot_ids = torch.tensor(mtp_slot_ids,
-                                        dtype=torch.int,
-                                        pin_memory=True)
+            mtp_slot_ids = torch.tensor(mtp_slot_ids,
+                                        dtype=torch.long,
+                                        pin_memory=True)
             manager.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

1-1: Missing NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to this source file.

Would you like me to generate the header block?

tensorrt_llm/_torch/speculative/mtp.py (2)

1-1: Missing NVIDIA copyright header

Per repository guidelines, prepend the current-year NVIDIA copyright header to this source file.


200-212: Consolidate stop-criteria logic to avoid duplication across samplers

Consider hoisting shared stop-criteria helpers into the Sampler base (with a required max_seq_len property) to avoid code duplication between TorchSampler and MTPSampler.

I can prepare a follow-up patch that introduces default implementations in Sampler and updates both subclasses.

Also applies to: 221-229, 229-246, 271-280

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between ceafaf2 and e3d1d9d.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (12 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (6)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

338-345: Samplers now delegate storage to TorchStore — good direction

Using a public store object and exposing sizing/beam constants from it simplifies responsibilities and improves reusability.

tensorrt_llm/_torch/speculative/mtp.py (5)

10-14: Importing TorchStore/torch sampler utilities here is fine, but watch for layering

This introduces a dependency from speculative/ to pyexecutor.sampler. Keep an eye on cycles and public API stability if these move again.

If you want, I can scan the tree for cycles that include this module.


63-78: Pointer/slot-id buffers: correct dtypes for pointers and indices

  • hidden_states_ptrs/past_tokens_ptrs use int64 on CUDA — correct for data_ptr() consumption in custom ops.
  • slot_ids is torch.long on CUDA — appropriate for indexing and passing to ops.

LGTM.


289-307: Indexing into preallocated store tensors — looks correct

  • Slots are long on CUDA
  • Shapes match: (max_tokens, max_num_sequences, 1) -> squeeze(-1).T => (max_num_sequences, max_tokens)

LGTM.
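
The shape bookkeeping in the second bullet can be followed with nested Python lists standing in for the store tensor. This is only an illustration: the sizes and token values are made up, and plain lists replace the CUDA tensors used in the PR.

```python
# Toy sizes; the real values come from the sampler arguments.
max_tokens, max_num_sequences = 3, 2

# new_tokens[step][slot][beam] with beam width 1,
# i.e. shape (max_tokens, max_num_sequences, 1).
new_tokens = [[[10 * slot + step] for slot in range(max_num_sequences)]
              for step in range(max_tokens)]

# squeeze(-1): drop the single-beam axis -> (max_tokens, max_num_sequences)
squeezed = [[beams[0] for beams in step_row] for step_row in new_tokens]

# .T: transpose -> (max_num_sequences, max_tokens), one row per sequence slot
per_sequence = [list(column) for column in zip(*squeezed)]

print(per_sequence)
```

After the transpose, each row holds the tokens of one sequence slot in step order, which is the layout a per-request loop would consume.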


616-623: Manager-backed pointers fed to THOP op — correct types and order

Passing manager.hidden_states_ptrs/past_tokens_ptrs (int64) aligns with .data_ptr() producers. Good.


799-839: Relaxed-acceptance path correctly sources slot_ids/delta from manager

The acceptance kernel receives slot_ids and per-slot relaxed deltas from the manager — matches the new centralized storage design.

@netanel-haber netanel-haber force-pushed the user/nhaber/feature/trtllm-7153-move-stop-critera-to-sample-async branch from beb339f to ed7dfb8 Compare August 25, 2025 15:22
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (4)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

311-321: TorchStore: persist sizing fields and fix stray attribute docstring.

Downstream code (e.g., MTPStore) expects max_draft_len/max_num_sequences/max_beam_width on the store. Also, the triple-quoted string after new_tokens is a stray no-op string literal.

 class TorchStore:
 
     def __init__(self, *, max_draft_len: int, max_num_sequences: int,
                  max_beam_width: int):
-        self.max_tokens = max_draft_len + 1
+        # Persist sizing for downstream stores (e.g., MTPStore)
+        self.max_draft_len = max_draft_len
+        self.max_num_sequences = max_num_sequences
+        self.max_beam_width = max_beam_width
+        self.max_tokens = max_draft_len + 1
         assert max_beam_width == SINGLE_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
         self.new_tokens = int_tensor(
             (self.max_tokens, max_num_sequences, max_beam_width))
-        """Shape: See cpp DecoderState.getAllNewTokens()"""
+        # Shape: see cpp DecoderState.getAllNewTokens()
         self.finish_reasons = int_tensor(self.new_tokens.shape)
tensorrt_llm/_torch/speculative/mtp.py (3)

169-198: slot_ids dtype mismatch (int32 -> int64) will break copy_.

manager.slot_ids is torch.long. mtp_slot_ids is created as torch.int (int32) and then copy_ called → runtime error. Produce mtp_slot_ids as long.

-            mtp_slot_ids = torch.tensor(mtp_slot_ids,
-                                        dtype=torch.int,
-                                        pin_memory=True)
+            mtp_slot_ids = torch.tensor(mtp_slot_ids,
+                                        dtype=torch.long,
+                                        pin_memory=True)
             manager.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)

221-227: Persist max_seq_len on MTPSampler for stop criteria.

_handle_stop_criteria (below) references max_seq_len via helper; store args.max_seq_len here.

     def __init__(self, args: TorchSampler.Args, *, nextn: int):
         self.mapping = None
         self.draft_len = nextn
         self.store = MTPStore(max_draft_len=nextn,
                               max_num_sequences=args.max_num_sequences,
                               max_beam_width=args.max_beam_width)
+        self.max_seq_len = args.max_seq_len

228-245: Undefined helpers in MTPSampler._handle_stop_criteria → AttributeError.

This method calls self._meet_max_token_stop_criteria and self._meet_stop_token_criteria, neither defined in MTPSampler nor its base. Implement them mirroring TorchSampler.

 class MTPSampler(Sampler):
@@
     def _handle_stop_criteria(self, request: LlmRequest,
                               new_token: int) -> bool:
@@
         return False
+
+    def _meet_max_token_stop_criteria(self, request: LlmRequest) -> bool:
+        num_tokens = request.get_num_tokens(BEAM_0)
+        return ((num_tokens - request.py_orig_prompt_len) >= request.py_max_new_tokens
+                or (num_tokens >= self.max_seq_len))
+
+    @staticmethod
+    def _meet_stop_token_criteria(py_stop_words_list, tokens: list[int]) -> bool:
+        if py_stop_words_list:
+            assert isinstance(py_stop_words_list, list), \
+                "request.py_stop_words_list should be a list"
+            stop_words_list, prefix_sum = py_stop_words_list
+            offset = 0
+            for i, offset_end in enumerate(prefix_sum):
+                if i > 0:
+                    offset = prefix_sum[i - 1]
+                stop_word = stop_words_list[offset:offset_end]
+                if len(stop_word) <= len(tokens) and tokens[-len(stop_word):] == stop_word:
+                    return True
+        return False
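
For readers unfamiliar with the layout the suggested helper iterates over, here is a standalone sketch of the same suffix match. It assumes the stop words arrive as a pair of a flattened token list and a prefix-sum list of end offsets, with no -1 padding; the function name and the sample data are illustrative only.

```python
from typing import List, Optional, Tuple

# (flattened stop-word tokens, prefix sum of word end offsets)
StopWords = Tuple[List[int], List[int]]


def meets_stop_token_criteria(stop_words_list: Optional[StopWords],
                              tokens: List[int]) -> bool:
    """Return True if `tokens` ends with any of the encoded stop words."""
    if not stop_words_list:
        return False
    flat_words, prefix_sum = stop_words_list
    start = 0
    for end in prefix_sum:
        stop_word = flat_words[start:end]  # tokens of one stop word
        start = end
        if (stop_word and len(stop_word) <= len(tokens)
                and tokens[-len(stop_word):] == stop_word):
            return True
    return False


# Two stop words, [5, 6] and [9], stored back to back.
stop_words = ([5, 6, 9], [2, 3])
hit_two_token_word = meets_stop_token_criteria(stop_words, [1, 5, 6])
hit_one_token_word = meets_stop_token_criteria(stop_words, [4, 9])
miss_wrong_order = meets_stop_token_criteria(stop_words, [1, 6, 5])
```

The match is anchored at the end of the sequence, so a stop word that appears earlier in `tokens` but is not its suffix does not count.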
🧹 Nitpick comments (11)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)

1-1: Add NVIDIA copyright header (2025) at file top.

All .py/.cpp/.cu files must carry the 2025 NVIDIA copyright header per guidelines.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

479-496: Consistently use py_seq_slot in Python path.

These lines use request.seq_slot while the rest of the module (and add_token) use py_seq_slot. For consistency and to avoid surprises if the C++-backed attribute differs, prefer py_seq_slot here too.

-            new_tokens[i, request.seq_slot, BEAM_0] = new_token
+            new_tokens[i, request.py_seq_slot, BEAM_0] = new_token
             request.add_new_token(new_token, BEAM_0)
@@
-            new_tokens[num_accepted, request.seq_slot, BEAM_0] = new_token
+            new_tokens[num_accepted, request.py_seq_slot, BEAM_0] = new_token

647-661: Avoid numpy for trivial stop-words length; use pure Python to reduce deps.

np.diff/np.max on short lists adds import overhead, and no CPU<->GPU copies are involved here. A simple Python scan is enough.

-    def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
-        max_stop_word_len = 0
-        for req in requests:
-            if req.py_stop_words_list is None:
-                continue
-            _, cumsum = req.py_stop_words_list
-            if -1 in cumsum:
-                cumsum = cumsum[:cumsum.index(-1)]
-            request_max_stop_word_len = np.max(np.diff(cumsum, prepend=0),
-                                               initial=0)
-            max_stop_word_len = max(max_stop_word_len,
-                                    request_max_stop_word_len)
-        return max_stop_word_len
+    def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
+        max_len = 0
+        for req in requests:
+            if req.py_stop_words_list is None:
+                continue
+            _, cumsum = req.py_stop_words_list
+            if -1 in cumsum:
+                cumsum = cumsum[:cumsum.index(-1)]
+            prev = 0
+            for end in cumsum:
+                max_len = max(max_len, end - prev)
+                prev = end
+        return max_len
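
The arithmetic behind the suggested scan can be checked on plain data. The sketch below takes the (tokens, cumsum) pairs directly instead of request objects, which is an assumption made to keep it self-contained; the sample values are invented.

```python
from typing import Iterable, List, Optional, Tuple

StopWords = Tuple[List[int], List[int]]  # (flattened tokens, cumulative ends)


def longest_stop_word_len(entries: Iterable[Optional[StopWords]]) -> int:
    """Length of the longest stop word across all entries (0 if none)."""
    max_len = 0
    for entry in entries:
        if entry is None:
            continue
        _, cumsum = entry
        if -1 in cumsum:
            # Everything from the first -1 onward is treated as padding.
            cumsum = cumsum[:cumsum.index(-1)]
        prev = 0
        for end in cumsum:
            # Consecutive differences of the cumulative ends are word lengths.
            max_len = max(max_len, end - prev)
            prev = end
    return max_len


entries = [
    ([1, 2, 3, 4, 5], [2, 5, -1]),  # word lengths 2 and 3, then padding
    None,                           # request without stop words
    ([9], [1]),                     # a single one-token stop word
]
longest = longest_stop_word_len(entries)
```

Here the cumulative ends [2, 5] decode to word lengths 2 and 3, so the result is 3.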

756-756: Lint: break long boolean expression.

Ruff E501 flagged this line. Simple wrap improves readability.

-        fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
+        fast_path = (
+            not self.enable_mixed_sampler
+            and no_draft_tokens
+            and gen_logits_host is None
+            and log_probs_host is None
+        )

361-375: Deterministic generator per device is good; consider multi-device guard.

get_generator caches a single Generator; if logits arrive from a different device later, you’ll reuse a generator on the wrong device. Optional: memoize by device.

-        if self._generator is None:
-            # Fallback to a default seed if not set
-            self._generator = torch.Generator(device=device)
-            self._generator.manual_seed(self._global_seed)
-        return self._generator
+        if self._generator is None or self._generator.device != device:
+            self._generator = torch.Generator(device=device)
+            self._generator.manual_seed(self._global_seed)
+        return self._generator
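
One way to read "memoize by device" is a cache keyed by device that creates each generator once. The sketch below is a hypothetical illustration: `GeneratorCache` is not a class from the PR, and `random.Random` stands in for `torch.Generator(device=device)` so the example runs without a GPU.

```python
import random
from typing import Dict


class GeneratorCache:
    """Hypothetical helper: one seeded generator per device key."""

    def __init__(self, global_seed: int):
        self._global_seed = global_seed
        self._generators: Dict[str, random.Random] = {}

    def get_generator(self, device: str) -> random.Random:
        gen = self._generators.get(device)
        if gen is None:
            # Stand-in for torch.Generator(device=device) + manual_seed(seed).
            gen = random.Random(self._global_seed)
            self._generators[device] = gen
        return gen


cache = GeneratorCache(global_seed=0)
gen_a = cache.get_generator("cuda:0")
gen_b = cache.get_generator("cuda:1")
same_again = cache.get_generator("cuda:0")  # returns the cached instance
```

Unlike the diff above, which replaces the single cached generator whenever the device changes, a keyed cache keeps each device's generator state across calls. Which behavior is preferable depends on the reproducibility guarantees the sampler needs.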
tests/unittest/_torch/test_torch_sampler.py (2)

1-1: Add NVIDIA copyright header (2025) at file top.

Per guidelines, include the header in test modules too.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+...

41-106: Nice targeted test for finish reasons; consider adding precedence case.

You already cover END_ID, LENGTH, STOP_WORDS, NOT_FINISHED. Add a case where multiple criteria could fire at the same step (e.g., both LENGTH and END_ID) to assert END_ID wins.

@@
 def test_write_finish_reasons():
@@
     for actual, request in zip(actual_finish_reasons, requests, strict=True):
         expected = request.finish_reasons
         msg = f"""\
 actual={[FinishReason(reason) for reason in actual]} != expected={expected}
 For request: {request.request.request_id=}, {request.input_tokens=}, {request.new_tokens=}
 """
         assert actual == [reason.value for reason in expected], msg
+
+def test_finish_reason_precedence_end_id_over_length():
+    sampler_args = TorchSampler.Args(max_seq_len=4,
+                                     max_draft_len=1,
+                                     max_num_sequences=1,
+                                     max_beam_width=1,
+                                     enable_mixed_sampler=False)
+    sampler = TorchSampler(args=sampler_args)
+    END_ID = 7
+    req = Request([1, 2, 3],  # prompt_len = 3
+                  end_id=END_ID,
+                  max_new_tokens=1,   # LENGTH would trigger after 1 new token
+                  new_tokens=[END_ID, 0],
+                  finish_reasons=[FinishReason.END_ID, FinishReason.NOT_FINISHED])
+    new_tokens = torch.tensor([req.new_tokens], dtype=torch.int32, device="cuda").T.unsqueeze(-1)
+    seq_slots = torch.tensor([req.request.py_seq_slot], device="cuda", dtype=torch.long)
+    sampler._write_finish_reasons([req.request],
+                                  finish_reasons=sampler.store.finish_reasons,
+                                  new_tokens=new_tokens,
+                                  seq_slots=seq_slots)
+    actual = sampler.store.finish_reasons[:, seq_slots, BEAM_0].T.tolist()[0]
+    assert actual[0] == FinishReason.END_ID.value
tensorrt_llm/_torch/speculative/mtp.py (4)

1-1: Add NVIDIA copyright header (2025) at file top.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+...

10-14: Import FinishReason from bindings for consistency; also import SINGLE_BEAM_WIDTH.

Keep enums from bindings.executor to avoid accidental re-exports and import SINGLE_BEAM_WIDTH used by stores.

-from ..pyexecutor.llm_request import FinishReason, LlmRequest, LlmRequestState
+from tensorrt_llm.bindings.executor import FinishReason
+from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.sampler import (BEAM_0, Sampler, SampleState,
-                                  SampleStateTensors, TorchSampler, TorchStore,
+                                  SampleStateTensors, TorchSampler, TorchStore,
+                                  SINGLE_BEAM_WIDTH,
                                   add_token, int_tensor)

87-98: Use fill_/zero_ instead of copy_ with a scalar.

copy_ expects a Tensor; use .fill_(0) or .zero_() for clarity and to avoid dtype/device pitfalls.

-                    self.mtp_relaxed_delta_pool[slot_id].copy_(
-                        0, non_blocking=True)
+                    self.mtp_relaxed_delta_pool[slot_id].fill_(0)
@@
-            self.mtp_relaxed_delta_pool[free_slot_id].copy_(0,
-                                                            non_blocking=True)
+            self.mtp_relaxed_delta_pool[free_slot_id].fill_(0)

688-707: torch.compile usage: consider guarding by availability and env flags.

Compiling small kernels is fine, but if users run without torch 2.x or with inductor disabled, this can regress. Optional: add a module-level flag or env check to bypass compile in debug/tests.

Would you like a follow-up patch to add a simple feature flag (e.g., TRTLLM_ENABLE_COMPILE) around these @torch.compile sites?
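
A rough sketch of what such a flag could look like, for illustration only. TRTLLM_ENABLE_COMPILE is just the name floated in this comment and is not an existing setting; maybe_compile, compile_enabled, and the env parameter are hypothetical helpers. The sketch returns the eager function when the flag is off, when torch cannot be imported, or when torch.compile is missing (torch older than 2.0).

```python
import os


def compile_enabled(env=None):
    """Return True unless the (hypothetical) TRTLLM_ENABLE_COMPILE flag disables compilation."""
    env = os.environ if env is None else env
    value = env.get("TRTLLM_ENABLE_COMPILE", "1").strip().lower()
    return value not in ("0", "false", "off", "no")


def maybe_compile(fn, *, env=None):
    """Wrap fn with torch.compile when enabled and available; otherwise return fn unchanged."""
    if not compile_enabled(env):
        return fn
    try:
        import torch  # imported lazily so the flag check works without torch
    except ImportError:
        return fn
    compile_fn = getattr(torch, "compile", None)  # absent before torch 2.0
    return compile_fn(fn) if compile_fn is not None else fn


def add_one(x):
    return x + 1


# With the flag off, the original eager function is returned untouched.
eager = maybe_compile(add_one, env={"TRTLLM_ENABLE_COMPILE": "0"})
```

Passing env explicitly keeps the helper easy to exercise in tests without mutating os.environ; production call sites would simply omit it.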

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between e3d1d9d and 981ef52.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (13 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (14 hunks)
  • tests/unittest/_torch/test_torch_sampler.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in __init__
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tests/unittest/_torch/test_torch_sampler.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/speculative/mtp.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/test_torch_sampler.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/speculative/mtp.py
🧠 Learnings (1)
📚 Learning: 2025-08-13T16:20:37.987Z
Learnt from: dcampora
PR: NVIDIA/TensorRT-LLM#6867
File: tensorrt_llm/_torch/pyexecutor/sampler.py:67-72
Timestamp: 2025-08-13T16:20:37.987Z
Learning: In TensorRT-LLM sampler code, performance is prioritized over additional validation checks. The beam_width helper method intentionally returns the first request's beam_width without validating consistency across all requests to avoid performance overhead from iterating through the entire batch.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧬 Code graph analysis (3)
tests/unittest/_torch/test_torch_sampler.py (2)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • convert_wordlist (414-448)
  • LlmRequest (271-411)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
  • TorchSampler (333-818)
  • Args (336-341)
  • _write_finish_reasons (668-697)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
tensorrt_llm/_torch/pyexecutor/llm_request.py (3)
  • finish_by (382-385)
  • LlmRequest (271-411)
  • get_draft_token_length (545-556)
tensorrt_llm/_torch/speculative/mtp.py (4)
cpp/include/tensorrt_llm/executor/types.h (1)
  • FinishReason (502-597)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • LlmRequest (271-411)
  • finish_by (382-385)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)
  • SlotManager (934-963)
  • get_slot (941-942)
tensorrt_llm/_torch/pyexecutor/sampler.py (4)
  • Sampler (51-68)
  • TorchSampler (333-818)
  • TorchStore (311-320)
  • int_tensor (303-304)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py

756-756: Line too long (124 > 120)

(E501)

🔇 Additional comments (3)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

668-699: Finish-reason precedence is sensible; END_ID > LENGTH > STOP_WORDS.

Writing STOP_WORDS, then LENGTH, then END_ID ensures the highest-priority reason wins for the same step. Matches expectations and maps to TRTLLM semantics.
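
The write-order argument can be illustrated with a small pure-Python sketch. Reason is a stand-in enum whose numeric values are illustrative, and resolve_finish_reason is a hypothetical helper; the real code writes into GPU tensors per sequence slot rather than a single variable. The point is only that applying checks from lowest to highest priority makes the last write win.

```python
from enum import IntEnum


class Reason(IntEnum):
    """Stand-in for FinishReason; the numeric values are illustrative only."""
    NOT_FINISHED = 0
    END_ID = 1
    STOP_WORDS = 2
    LENGTH = 3


def resolve_finish_reason(hit_stop_words, hit_length, hit_end_id):
    """Apply checks from lowest to highest priority; the last write wins."""
    reason = Reason.NOT_FINISHED
    if hit_stop_words:
        reason = Reason.STOP_WORDS
    if hit_length:
        reason = Reason.LENGTH
    if hit_end_id:
        reason = Reason.END_ID
    return reason


# LENGTH and END_ID fire on the same step: END_ID is written last, so it wins.
both = resolve_finish_reason(hit_stop_words=False, hit_length=True, hit_end_id=True)
```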

tensorrt_llm/_torch/speculative/mtp.py (2)

302-306: Index dtypes look good here; ensure slots is long on CUDA.

slots comes from torch.as_tensor([...]) and remains int64 after to(cuda), which is correct for index_copy_. LGTM.


616-686: Good use of manager-owned pools; slot_ids indexing is correct.

Using manager.slot_ids for gather/scatter keeps ownership centralized and avoids stale per-request arrays. Nice.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
tests/unittest/_torch/test_torch_sampler.py (1)

123-125: Use torch.long for seq_slots indices on CUDA.

Index-based ops such as index_fill_ expect int64 indices, and int64 is the conventional index dtype in PyTorch. Using int32 here can error or behave inconsistently depending on the op.

Apply this diff:

-    seq_slots = torch.tensor([req.request.py_seq_slot for req in requests],
-                             device="cuda",
-                             dtype=torch.int32)
+    seq_slots = torch.tensor(
+        [req.request.py_seq_slot for req in requests],
+        device="cuda",
+        dtype=torch.long,
+    )
🧹 Nitpick comments (5)
tests/unittest/_torch/test_torch_sampler.py (5)

6-7: Import LlmRequest from its defining module (avoid relying on re-exports).

LlmRequest is defined in llm_request.py; importing it from sampler can be brittle if sampler stops re-exporting it.

Apply this diff:

-from tensorrt_llm._torch.pyexecutor.sampler import (BEAM_0, LlmRequest,
-                                                    TorchSampler)
+from tensorrt_llm._torch.pyexecutor.sampler import BEAM_0, TorchSampler
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest

16-23: Use Optional[...] for nullable annotations (Python 3.8 compatibility).

Type hints with default None should be Optional[...] to reflect nullability and keep 3.8-friendly syntax.

Apply this diff:

+from typing import Optional
@@
     def __init__(self,
                  *,
                  prompt: list[int],
                  new_tokens: list[int],
-                 finish_reasons: list[FinishReason],
+                 finish_reasons: list[FinishReason],
                  max_new_tokens: int = MAX_NEW_TOKENS,
-                 end_id: int = None,
-                 stop_words_list: list[list[int]] = None):
+                 end_id: Optional[int] = None,
+                 stop_words_list: Optional[list[list[int]]] = None):

45-52: Make the string a real function docstring and keep lines <=120.

The triple-quoted string is not a docstring where it is placed. Move it to immediately after def and reflow to satisfy E501.

Apply this diff:

-def test_write_finish_reasons():
-    NOT_FINISHED = FinishReason.NOT_FINISHED
-    STOP_WORDS = FinishReason.STOP_WORDS
-    END_ID = FinishReason.END_ID
-    LENGTH = FinishReason.LENGTH
-    """We don't really care about the finish reason past the first infraction, because we're not going to use it, although in some instance it is written anyway."""
+def test_write_finish_reasons():
+    """Validate _write_finish_reasons early-exit and precedence behavior.
+
+    We do not care about finish reasons past the first infraction for usage,
+    though later writes may still occur in some cases.
+    """
+    NOT_FINISHED = FinishReason.NOT_FINISHED
+    STOP_WORDS = FinishReason.STOP_WORDS
+    END_ID = FinishReason.END_ID
+    LENGTH = FinishReason.LENGTH

104-107: Reflow long comment (E501).

Line exceeds 120 chars. Reflow for readability.

Apply this diff:

-            # The latest infraction check overrides the earlier infraction checks, hence the first finish_reason is END_ID
+            # The latest infraction check overrides earlier checks,
+            # hence the first finish_reason is END_ID.

118-121: Match new_tokens dtype to store to avoid unintended type promotions.

Safer to bind the dtype to sampler.store.new_tokens.dtype instead of hard-coding int32.

Apply this diff:

-    new_tokens = torch.tensor([req.new_tokens for req in requests],
-                              dtype=torch.int32,
-                              device="cuda").T.unsqueeze(-1)
+    new_tokens = torch.tensor(
+        [req.new_tokens for req in requests],
+        dtype=sampler.store.new_tokens.dtype,
+        device="cuda",
+    ).T.unsqueeze(-1)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 981ef52 and 37fb3a6.

📒 Files selected for processing (1)
  • tests/unittest/_torch/test_torch_sampler.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Files:

  • tests/unittest/_torch/test_torch_sampler.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/test_torch_sampler.py
🧬 Code graph analysis (1)
tests/unittest/_torch/test_torch_sampler.py (2)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • convert_wordlist (414-448)
  • LlmRequest (271-411)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
  • TorchSampler (333-818)
  • Args (336-341)
  • _write_finish_reasons (668-697)
🪛 Ruff (0.12.2)
tests/unittest/_torch/test_torch_sampler.py

42-42: Line too long (178 > 120)

(E501)


50-50: Line too long (164 > 120)

(E501)


105-105: Line too long (122 > 120)

(E501)

🔇 Additional comments (1)
tests/unittest/_torch/test_torch_sampler.py (1)

52-108: Nice coverage of finish-reason scenarios.

Good matrix across NOT_FINISHED, STOP_WORDS (including lookback), END_ID, and LENGTH with precedence/override cases. This test should guard the new stop-criteria logic well.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

1-1: Add NVIDIA copyright header (2025).

Project guidelines require the NVIDIA header at the top of all source files.

Apply:

+ # Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
♻️ Duplicate comments (4)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

313-321: TorchStore: persist sizing fields and fix stray triple-quoted string.

Persist max_draft_len/max_num_sequences/max_beam_width; convert the no-op string literal to a comment.

Apply:

 class TorchStore:
@@
     def __init__(self, *, max_draft_len: int, max_num_sequences: int,
                  max_beam_width: int):
-        self.max_tokens = max_draft_len + 1
+        # Persist sizing for downstream stores
+        self.max_draft_len = max_draft_len
+        self.max_num_sequences = max_num_sequences
+        self.max_beam_width = max_beam_width
+        self.max_tokens = max_draft_len + 1
         assert max_beam_width == SINGLE_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
         self.new_tokens = int_tensor(
             (self.max_tokens, max_num_sequences, max_beam_width))
-        """Shape: See cpp DecoderState.getAllNewTokens()"""
+        # Shape: See cpp DecoderState.getAllNewTokens()
         self.finish_reasons = int_tensor(self.new_tokens.shape)
tests/unittest/_torch/test_torch_sampler.py (3)

43-45: Fix __repr__ to use py_* fields and wrap line (prevents AttributeError and E501).

Apply:

-    def __repr__(self):
-        return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})"
+    def __repr__(self):
+        return (
+            "RequestCase("
+            f"prompt={self.prompt}, "
+            f"new_tokens={self.new_tokens}, "
+            f"finish_reasons={self.finish_reasons}, "
+            f"py_max_new_tokens={self.request.py_max_new_tokens}, "
+            f"py_end_id={self.request.py_end_id}, "
+            f"py_stop_words_list={self.request.py_stop_words_list}"
+            ")"
+        )

146-148: Replace zip(..., strict=True) for Python 3.8+.

Apply:

-    for actual, request in zip(actual_finish_reasons, requests, strict=True):
+    assert len(actual_finish_reasons) == len(requests)
+    for actual, request in zip(actual_finish_reasons, requests):
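
If the same check is needed in more than one place, it could be factored into a small helper. This is a sketch only; zip_equal is a hypothetical name, and it assumes the inputs are sized sequences (it calls len on each), which holds for the lists used in this test.

```python
def zip_equal(*sequences):
    """Zip sized sequences after checking they all have the same length.

    A stand-in for zip(..., strict=True), which needs Python 3.10+.
    """
    lengths = [len(seq) for seq in sequences]
    if len(set(lengths)) > 1:
        raise ValueError(f"zip_equal() arguments have different lengths: {lengths}")
    return zip(*sequences)


pairs = list(zip_equal([1, 2], ["a", "b"]))
```

Unlike the built-in strict mode, this fails up front rather than partway through iteration, which is usually what a test wants.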

1-1: Add NVIDIA copyright header (2025).

Tests are also source files and must carry the header.

Apply:

+ # Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
🧹 Nitpick comments (5)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)

479-481: Use py_seq_slot for consistency with Python wrapper API.

request.seq_slot mixes CPP and Python fields; the rest of this file uses py_seq_slot. Align for consistency.

Apply:

-            new_tokens[i, request.seq_slot, BEAM_0] = new_token
+            new_tokens[i, request.py_seq_slot, BEAM_0] = new_token
@@
-            new_tokens[num_accepted, request.seq_slot, BEAM_0] = new_token
+            new_tokens[num_accepted, request.py_seq_slot, BEAM_0] = new_token

Also applies to: 488-490


361-375: Generator is not per-device; may break on multi-GPU.

Caching a single torch.Generator tied to the first device can misfire on later devices. Cache per device.

Apply:

-        if self._generator is None:
-            # Fallback to a default seed if not set
-            self._generator = torch.Generator(device=device)
-            self._generator.manual_seed(self._global_seed)
-        return self._generator
+        if not hasattr(self, "_generators"):
+            self._generators = {}
+        dev = torch.device(device)
+        gen = self._generators.get(dev)
+        if gen is None:
+            gen = torch.Generator(device=dev)
+            gen.manual_seed(self._global_seed)
+            self._generators[dev] = gen
+        return gen

776-777: Wrap long assertion to satisfy E501.

Line exceeds 120 chars.

Apply:

-            assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling"
+            assert "d2t" not in model_outputs, (
+                "eagle3 does not yet support non-greedy sampling"
+            )
tests/unittest/_torch/test_torch_sampler.py (2)

1-3: Seed RNG for deterministic tests.

Random seq_slots make tests flaky across runs.

Apply:

 import random
+random.seed(0)
@@
-    seq_slots = random.sample(range(MAX_NUM_SEQUENCES), MAX_NUM_SEQUENCES)
+    seq_slots = random.sample(range(MAX_NUM_SEQUENCES), MAX_NUM_SEQUENCES)

Also applies to: 16-17
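
An alternative to seeding the global random module at import time is a local generator, which keeps the permutation reproducible without affecting other tests. A minimal sketch, where shuffled_seq_slots is a hypothetical helper and the value of MAX_NUM_SEQUENCES is illustrative rather than taken from the test module:

```python
import random

MAX_NUM_SEQUENCES = 8  # illustrative value, not taken from the test module


def shuffled_seq_slots(num_slots, seed=0):
    """Return a reproducible permutation of range(num_slots)."""
    rng = random.Random(seed)  # local generator; the global random state is untouched
    return rng.sample(range(num_slots), num_slots)


seq_slots = shuffled_seq_slots(MAX_NUM_SEQUENCES)
```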


52-52: Wrap long lines flagged by Ruff E501.

Keep under 120 chars for doc/comment strings and assertions.

Apply:

-    """We don't really care about the finish reason past the first infraction, because we're not going to use it, although in some instance it is written anyway."""
+    """We don't care about reasons past first infraction; they may still be written in some cases."""
-    # fill with garbage value so we can observe that finish reasons are filled with NOT_FINISHED before we write to them.
+    # Fill with a garbage value so we can observe that finish_reasons are first set to NOT_FINISHED.
-            # The latest infraction check overrides the earlier infraction checks, hence the first finish_reason is END_ID
+            # Latest infraction overrides earlier checks; first finish_reason is END_ID.

Also applies to: 65-67, 122-124

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 37fb3a6 and ddba022.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (13 hunks)
  • tests/unittest/_torch/test_torch_sampler.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tests/unittest/_torch/test_torch_sampler.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tests/unittest/_torch/test_torch_sampler.py
🧠 Learnings (2)
📚 Learning: 2025-08-13T16:20:37.987Z
Learnt from: dcampora
PR: NVIDIA/TensorRT-LLM#6867
File: tensorrt_llm/_torch/pyexecutor/sampler.py:67-72
Timestamp: 2025-08-13T16:20:37.987Z
Learning: In TensorRT-LLM sampler code, performance is prioritized over additional validation checks. The beam_width helper method intentionally returns the first request's beam_width without validating consistency across all requests to avoid performance overhead from iterating through the entire batch.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-25T22:42:47.587Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-25T22:42:47.587Z
Learning: Applies to **/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py} : Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Applied to files:

  • tests/unittest/_torch/test_torch_sampler.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
tensorrt_llm/_torch/pyexecutor/llm_request.py (4)
  • finish_by (382-385)
  • LlmRequest (271-411)
  • log_probs (220-221)
  • get_draft_token_length (545-556)
tests/unittest/_torch/test_torch_sampler.py (3)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • convert_wordlist (414-448)
  • LlmRequest (271-411)
tensorrt_llm/_torch/pyexecutor/sampler.py (4)
  • TorchSampler (333-822)
  • sample (276-288)
  • Args (336-341)
  • _write_finish_reasons (671-701)
cpp/include/tensorrt_llm/executor/types.h (1)
  • FinishReason (502-597)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py

760-760: Line too long (124 > 120)

(E501)

tests/unittest/_torch/test_torch_sampler.py

44-44: Line too long (178 > 120)

(E501)


52-52: Line too long (164 > 120)

(E501)


65-65: Line too long (121 > 120)

(E501)


122-122: Line too long (122 > 120)

(E501)

🔇 Additional comments (1)
tests/unittest/_torch/test_torch_sampler.py (1)

129-132: LGTM: seq_slots dtype int64 for CUDA indexing.

Index tensor uses torch.int64, matching index_fill_/advanced indexing requirements.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (5)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)

31-35: Python 3.8 compatibility: Replace PEP 604 union syntax and kw_only dataclasses

The dataclass definitions use kw_only=True and PEP 604 union syntax (torch.Tensor | None) which are incompatible with Python 3.8. The coding guidelines specify Python 3.8+ compatibility.

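Apply this fix for Python 3.8 compatibility (note that without kw_only=True, a required subclass field such as finish_reasons can no longer follow inherited fields that have defaults, so it needs a default or the fields need reordering):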

-from typing import Literal, Optional
+from typing import Literal, Optional, List, Dict, Union

-@dataclass(kw_only=True)
+@dataclass()
 class SampleStateTensors:
     new_tokens: torch.Tensor
-    logits: torch.Tensor | None = None
-    log_probs: torch.Tensor | None = None
+    logits: Optional[torch.Tensor] = None
+    log_probs: Optional[torch.Tensor] = None

-@dataclass(kw_only=True)
+@dataclass()
 class SampleStateTensorsHostTorch(SampleStateTensors):
     finish_reasons: torch.Tensor

-@dataclass(kw_only=True)
+@dataclass()
 class SampleStateTorch(SampleState):
     host: SampleStateTensorsHostTorch

Also applies to: 323-326, 328-331
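
When kw_only=True is dropped, dataclass field-ordering rules apply across inheritance: a field without a default may not follow inherited fields that have one. A minimal sketch of the failure and one possible workaround, using hypothetical stand-in classes (Tensors, HostTensorsCompat) with Any in place of torch.Tensor; the workaround shown is an assumption, not what the PR does.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Tensors:
    new_tokens: Any  # stands in for torch.Tensor
    logits: Optional[Any] = None
    log_probs: Optional[Any] = None


def define_naive_subclass():
    """Adding a required field after inherited defaults fails without kw_only."""
    @dataclass
    class HostTensors(Tensors):
        finish_reasons: Any
    return HostTensors


try:
    define_naive_subclass()
    naive_error = None
except TypeError as exc:  # raised at class-definition time
    naive_error = exc


@dataclass
class HostTensorsCompat(Tensors):
    # One 3.8-compatible workaround: give the field a default and validate it.
    finish_reasons: Optional[Any] = None

    def __post_init__(self):
        if self.finish_reasons is None:
            raise ValueError("finish_reasons is required")


host = HostTensorsCompat(new_tokens=[1], finish_reasons=[0])
```

Callers that already pass every field by keyword keep working with this workaround; the trade-off is that a missing finish_reasons is reported at construction time by the explicit check rather than by the generated __init__ signature.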


279-289: Python 3.8 compatibility: Replace match statement with if/elif chain

The match statement is a Python 3.10+ feature and violates the Python 3.8+ requirement in the coding guidelines.

Replace the match statement with an if/elif chain:

-    match strategy:
-        case ("top_k", top_k):
-            return top_k_sampling_batch(logits, top_k, generator)
-        case ("top_p", top_p, temperature):
-            return top_p_sampling_batch(logits, top_p, temperature, generator)
-        case ("top_k_top_p", top_k, top_p, temperature):
-            return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
-                                              generator)
-        case ("greedy", None):
-            return greedy_search_sampling_batch(logits)
+    if strategy[0] == "top_k":
+        _, top_k = strategy
+        return top_k_sampling_batch(logits, top_k, generator)
+    elif strategy[0] == "top_p":
+        _, top_p, temperature = strategy
+        return top_p_sampling_batch(logits, top_p, temperature, generator)
+    elif strategy[0] == "top_k_top_p":
+        _, top_k, top_p, temperature = strategy
+        return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
+                                          generator)
+    elif strategy[0] == "greedy":
+        return greedy_search_sampling_batch(logits)
+    else:
+        raise ValueError(f"Unknown sampling strategy: {strategy}")

383-399: Python 3.8 compatibility: Replace PEP 585 generic types in function signatures

Multiple function signatures use built-in generic types like list[list[int]] which require Python 3.9+. These need to be replaced with typing module equivalents.

Update function signatures to use typing module generics:

     @staticmethod
-    def _meet_stop_token_criteria(py_stop_words_list: list[list[int]] | None,
-                                  tokens: list[int]) -> bool:
+    def _meet_stop_token_criteria(py_stop_words_list: Optional[List[List[int]]],
+                                  tokens: List[int]) -> bool:

-    def _write_finish_reasons(self, requests: list[LlmRequest], *,
+    def _write_finish_reasons(self, requests: List[LlmRequest], *,
                               finish_reasons: torch.Tensor,
                               seq_slots: torch.Tensor,
                               new_tokens: torch.Tensor) -> None:

-    def _are_end_id(self, requests: list[LlmRequest],
+    def _are_end_id(self, requests: List[LlmRequest],
                     tokens: torch.Tensor) -> torch.Tensor:

-    def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
+    def _are_max_length(self, requests: List[LlmRequest]) -> torch.Tensor:

-    def _are_stop_words(self, requests: list[LlmRequest], tokens: torch.Tensor,
+    def _are_stop_words(self, requests: List[LlmRequest], tokens: torch.Tensor,
                         longest_stop_word_len: int) -> torch.Tensor:

     def _process_requests(self,
-                          requests: list[LlmRequest],
-                          model_outputs: dict[str, torch.Tensor],
+                          requests: List[LlmRequest],
+                          model_outputs: Dict[str, torch.Tensor],
                           new_tokens: torch.Tensor,

Also applies to: 671-747


58-59: Python 3.8 compatibility: Replace PEP 604 union syntax in return type

The return type annotation uses PEP 604 syntax which is incompatible with Python 3.8.

-    def get_cache_indirection(self) -> torch.Tensor | None:
+    def get_cache_indirection(self) -> Optional[torch.Tensor]:

756-757: Python 3.8 compatibility: Replace PEP 604 union syntax in function parameters

The function parameters use PEP 604 union syntax (torch.Tensor | None) which is only available in Python 3.10+.

Update the function signature:

-                          gen_logits_host: torch.Tensor | None = None,
-                          log_probs_host: torch.Tensor | None = None):
+                          gen_logits_host: Optional[torch.Tensor] = None,
+                          log_probs_host: Optional[torch.Tensor] = None):
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)

311-321: LGTM: TorchStore class design with minor documentation fix

The TorchStore design is well-structured with proper initialization and validation. However, there's a documentation formatting issue.

Fix the stray docstring after the new_tokens attribute:

         self.new_tokens = int_tensor(
             (self.max_tokens, max_num_sequences, max_beam_width))
-        """Shape: See cpp DecoderState.getAllNewTokens()"""
+        # Shape: See cpp DecoderState.getAllNewTokens()
         self.finish_reasons = int_tensor(self.new_tokens.shape)

704-724: Review performance impact of finish reason calculations

The finish reason detection methods (_are_end_id, _are_max_length, _are_stop_words) create and transfer multiple tensors to CUDA for each sampling step. While functionally correct, this could impact performance in high-throughput scenarios.

Consider batching tensor operations or caching frequently accessed tensors if profiling shows that these operations are bottlenecks. That said, given the "Chill" review setting and the learning that TensorRT-LLM sampler code prioritizes performance, the current form may be acceptable as-is.

Also applies to: 726-747


762-762: Minor: Line length exceeds 120 characters

Static analysis detected a line that exceeds the 120-character limit.

-        fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
+        fast_path = (not self.enable_mixed_sampler and no_draft_tokens and 
+                     gen_logits_host is None and log_probs_host is None)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between ddba022 and b02ec58.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (13 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (2)
📚 Learning: 2025-08-25T22:42:47.587Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-25T22:42:47.587Z
Learning: Applies to **/*.py : Code must target Python 3.8+

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-13T16:20:37.987Z
Learnt from: dcampora
PR: NVIDIA/TensorRT-LLM#6867
File: tensorrt_llm/_torch/pyexecutor/sampler.py:67-72
Timestamp: 2025-08-13T16:20:37.987Z
Learning: In TensorRT-LLM sampler code, performance is prioritized over additional validation checks. The beam_width helper method intentionally returns the first request's beam_width without validating consistency across all requests to avoid performance overhead from iterating through the entire batch.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
tensorrt_llm/_torch/pyexecutor/llm_request.py (3)
  • finish_by (382-385)
  • LlmRequest (271-411)
  • get_draft_token_length (545-556)
cpp/include/tensorrt_llm/executor/types.h (1)
  • FinishReason (502-597)
tensorrt_llm/_torch/speculative/mtp.py (1)
  • _handle_stop_criteria (228-245)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py

762-762: Line too long (124 > 120)

(E501)

🔇 Additional comments (8)
tensorrt_llm/_torch/pyexecutor/sampler.py (8)

6-6: LGTM: numpy import for stop word length calculations

The numpy import is used appropriately for the np.max and np.diff functions in the _longest_stop_word_len method.


307-308: LGTM: Beam constants renamed for clarity

The renaming from BEAM to BEAM_0 and introduction of SINGLE_BEAM_WIDTH makes the beam indexing more explicit and readable.


351-356: LGTM: TorchStore integration in TorchSampler

The integration of TorchStore into TorchSampler is clean and properly handles the inference mode context for tensor allocation.


377-380: LGTM: Consistent BEAM_0 usage in stop criteria

The update to use BEAM_0 instead of hardcoded 0 improves code consistency and readability.


650-663: LGTM: Efficient stop word length calculation

The _longest_stop_word_len method efficiently calculates the maximum stop word length using numpy operations, which is appropriate for performance-sensitive sampling code.


665-702: LGTM: Well-structured finish reason writing logic

The finish reason writing implementation is well-designed with:

  • Clear separation of concerns across helper methods
  • Proper precedence ordering (NOT_FINISHED → STOP_WORDS → LENGTH → END_ID)
  • Good documentation explaining the precedence logic

559-574: Fixed: Correct dtype for seq_slots tensor

The seq_slots tensor now uses torch.int64 dtype which is appropriate for index operations like index_copy_ and index_fill_. This addresses the previous review feedback.


571-590: LGTM: Comprehensive finish reason integration

The finish reason pipeline is well-integrated into the sampling flow:

  • Proper tensor transfers to CPU with non_blocking=True
  • Clean return of SampleStateTorch with finish_reasons in host state
  • Maintains existing sampling event synchronization

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
@netanel-haber netanel-haber force-pushed the user/nhaber/feature/trtllm-7153-move-stop-critera-to-sample-async branch from b02ec58 to d9d1e61 Compare August 27, 2025 12:43
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (13)
tensorrt_llm/_torch/pyexecutor/sampler.py (13)

1-10: Add NVIDIA copyright header (2025) at file top.

Project guidelines require the NVIDIA copyright header on all source files.

Apply this diff:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from abc import ABC, abstractmethod

35-39: Replace PEP 604 unions for Python 3.8 compatibility.

The torch.Tensor | None syntax requires Python 3.10+. Project targets Python 3.8+.

Apply this diff:

+from typing import Optional
 
 @dataclass(kw_only=True)
 class SampleStateTensors:
     new_tokens: torch.Tensor
-    logits: torch.Tensor | None = None
-    log_probs: torch.Tensor | None = None
+    logits: Optional[torch.Tensor] = None
+    log_probs: Optional[torch.Tensor] = None

35-46: Remove kw_only from dataclasses for Python 3.8 compatibility.

The kw_only=True parameter requires Python 3.10+. Project targets Python 3.8+.

Apply this diff:

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTensors:
     new_tokens: torch.Tensor
     logits: Optional[torch.Tensor] = None
     log_probs: Optional[torch.Tensor] = None

-@dataclass(kw_only=True)
+@dataclass
 class SampleState:
     scheduled_requests: ScheduledRequests
     device: SampleStateTensors = None
     host: SampleStateTensors = None
     sampler_event: torch.cuda.Event = None

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTensorsHostTorch(SampleStateTensors):
     finish_reasons: torch.Tensor

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTorch(SampleState):
     host: SampleStateTensorsHostTorch

Also applies to: 323-331
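
One caveat is worth checking before dropping kw_only=True: without it, dataclass fields become positional, so a subclass that adds a required field after a base class with defaulted fields is rejected when the class is defined. The sketch below reproduces this with stand-in classes (Base, Child and ChildWithDefault are hypothetical, not the sampler's classes) and shows one possible Python 3.8-compatible workaround, assuming a None default plus a __post_init__ check is acceptable.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Base:
    new_tokens: int
    logits: Optional[int] = None  # defaulted field in the base class


def define_positional_child() -> Optional[str]:
    """Try to add a required field after a defaulted one; return the error text."""
    try:

        @dataclass
        class Child(Base):
            finish_reasons: int  # required field follows a defaulted field

    except TypeError as exc:
        return str(exc)
    return None


@dataclass
class ChildWithDefault(Base):
    # Workaround sketch: give the field a default and validate it explicitly.
    finish_reasons: Optional[int] = None

    def __post_init__(self) -> None:
        if self.finish_reasons is None:
            raise ValueError("finish_reasons is required")


error = define_positional_child()
ok = ChildWithDefault(new_tokens=1, finish_reasons=0)
```

The same ordering constraint would apply to SampleStateTensorsHostTorch and SampleStateTensorsHostTRTLLM, which add required tensor fields on top of a base class that already has defaulted ones.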


62-62: Replace PEP 604 union for Python 3.8 compatibility.

Apply this diff:

-    def get_cache_indirection(self) -> torch.Tensor | None:
+    def get_cache_indirection(self) -> Optional[torch.Tensor]:

283-292: Replace match/case statement for Python 3.8 compatibility.

The match/case syntax was introduced in Python 3.10. Project targets Python 3.8+.

Apply this diff:

-    match strategy:
-        case ("top_k", top_k):
-            return top_k_sampling_batch(logits, top_k, generator)
-        case ("top_p", top_p, temperature):
-            return top_p_sampling_batch(logits, top_p, temperature, generator)
-        case ("top_k_top_p", top_k, top_p, temperature):
-            return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
-                                              generator)
-        case ("greedy", None):
-            return greedy_search_sampling_batch(logits)
+    if strategy[0] == "top_k":
+        return top_k_sampling_batch(logits, strategy[1], generator)
+    elif strategy[0] == "top_p":
+        return top_p_sampling_batch(logits, strategy[1], strategy[2], generator)
+    elif strategy[0] == "top_k_top_p":
+        return top_k_top_p_sampling_batch(logits, strategy[1], strategy[2], strategy[3], generator)
+    elif strategy[0] == "greedy":
+        return greedy_search_sampling_batch(logits)
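
A slightly more defensive variant of the if/elif rewrite, sketched below, unpacks the strategy tuple by name and raises on unknown tags instead of silently returning None. The sampling functions here are hypothetical stubs that only echo their arguments; the real ones take logits and a generator.

```python
from typing import Any, Tuple


def top_k_stub(top_k: int) -> str:
    return f"top_k(k={top_k})"


def top_p_stub(top_p: float, temperature: float) -> str:
    return f"top_p(p={top_p}, t={temperature})"


def greedy_stub() -> str:
    return "greedy()"


def dispatch(strategy: Tuple[Any, ...]) -> str:
    """Python 3.8-compatible replacement for a match/case over tagged tuples."""
    tag, args = strategy[0], strategy[1:]
    if tag == "top_k":
        (top_k,) = args  # raises ValueError if the arity is wrong
        return top_k_stub(top_k)
    if tag == "top_p":
        top_p, temperature = args
        return top_p_stub(top_p, temperature)
    if tag == "greedy":
        return greedy_stub()
    raise ValueError(f"unknown sampling strategy: {tag!r}")
```

Raising on an unrecognized tag keeps a malformed strategy from propagating as a None result further down the sampling path.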

249-254: Replace PEP 604 unions for Python 3.8 compatibility.

Apply this diff:

+from typing import Union, Tuple
 
-TopK = tuple[Literal["top_k"], int]
-TopP = tuple[Literal["top_p"], float, float]
-TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
-Greedy = tuple[Literal["greedy"], None]
+TopK = Tuple[Literal["top_k"], int]
+TopP = Tuple[Literal["top_p"], float, float]
+TopKTopP = Tuple[Literal["top_k_top_p"], int, float, float]
+Greedy = Tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy
+Strategy = Union[TopK, TopP, TopKTopP, Greedy]

276-277: Replace list[...] generic syntax for Python 3.8 compatibility.

Built-in generic syntax like list[LlmRequest] requires Python 3.9+. Project targets Python 3.8+.

Apply this diff (showing sample locations):

+from typing import List, Dict
 
-def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]:
+def sampling_strategies(requests: Iterable[LlmRequest]) -> List[Strategy]:
 
-def _write_finish_reasons(self, requests: list[LlmRequest], *,
+def _write_finish_reasons(self, requests: List[LlmRequest], *,
 
-def _are_end_id(self, requests: list[LlmRequest],
+def _are_end_id(self, requests: List[LlmRequest],
 
-def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
+def _are_max_length(self, requests: List[LlmRequest]) -> torch.Tensor:
 
-def _are_stop_words(self, requests: list[LlmRequest], tokens: torch.Tensor,
+def _are_stop_words(self, requests: List[LlmRequest], tokens: torch.Tensor,
 
-def request_stop_words(request: LlmRequest,
-                       new_tokens: torch.Tensor) -> list[bool]:
-    per_step = [False] * self.max_tokens
+def request_stop_words(request: LlmRequest,
+                       new_tokens: torch.Tensor) -> List[bool]:
+    per_step: List[bool] = [False] * self.max_tokens
 
-def _process_requests(self,
-                      requests: list[LlmRequest],
-                      model_outputs: dict[str, torch.Tensor],
+def _process_requests(self,
+                      requests: List[LlmRequest],
+                      model_outputs: Dict[str, torch.Tensor],

Also applies to: 634-637, 666-674, 675-686, 688-709, 711-719


738-738: Replace PEP 604 union for Python 3.8 compatibility.

Apply this diff:

-        batched_strategy: Strategy | None = GREEDY
+        batched_strategy: Optional[Strategy] = GREEDY

955-955: Replace remaining PEP 604 unions for Python 3.8 compatibility.

Apply this diff:

-    def get_cache_indirection(self) -> torch.Tensor | None:
+    def get_cache_indirection(self) -> Optional[torch.Tensor]:
 
 @dataclass(kw_only=True)
 class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finished_sum: torch.Tensor
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
-    cum_log_probs: torch.Tensor | None = None
-    gathered_ids: torch.Tensor | None = None
+    cum_log_probs: Optional[torch.Tensor] = None
+    gathered_ids: Optional[torch.Tensor] = None
 
 @dataclass(kw_only=True)
 class SampleStateTRTLLM(SampleState):
-    finalize_events: dict[str, CudaEvent] | None = None
+    finalize_events: Optional[Dict[str, CudaEvent]] = None

Also applies to: 799-812


513-514: Replace dict[...] generic syntax for Python 3.8 compatibility.

Apply this diff:

     def sample_async(self, scheduled_requests: ScheduledRequests,
-                     model_outputs: dict[str, torch.Tensor]) -> SampleState:
+                     model_outputs: Dict[str, torch.Tensor]) -> SampleState:

563-576: Replace list[...] and list[int] for Python 3.8 compatibility.

Apply this diff:

     def _apply_embedding_bias(
             logits: torch.Tensor,
-            requests: list[LlmRequest],
-            steps_per_request: list[int] = None) -> torch.Tensor:
+            requests: List[LlmRequest],
+            steps_per_request: List[int] = None) -> torch.Tensor:

799-812: Remove kw_only from remaining dataclasses for Python 3.8 compatibility.

Apply this diff:

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finished_sum: torch.Tensor
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
     cum_log_probs: Optional[torch.Tensor] = None
     gathered_ids: Optional[torch.Tensor] = None

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTRTLLM(SampleState):
     finalize_events: Optional[Dict[str, CudaEvent]] = None
     """..."""
     host: SampleStateTensorsHostTRTLLM

335-342: Remove kw_only from the Args dataclass for Python 3.8 compatibility.

The change from frozen=True, kw_only=True to just kw_only=True is acceptable but note that kw_only=True also requires Python 3.10+. Since we need to support Python 3.8+, remove the kw_only parameter entirely.

Apply this diff:

-    @dataclass(frozen=True, kw_only=True)
+    @dataclass(frozen=True)
     class Args:
         max_seq_len: int
         max_draft_len: int
         max_num_sequences: int
         max_beam_width: int
         enable_mixed_sampler: bool
♻️ Duplicate comments (5)
tests/unittest/_torch/test_torch_sampler.py (3)

1-10: Add NVIDIA copyright header (2025) at file top.

Project guidelines require the NVIDIA copyright header on all source files.

Apply this diff:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import random
 
 import torch

43-45: Fix __repr__: use py_* fields to prevent AttributeError.

LlmRequest exposes py_* attributes. Accessing non-existent ones in __repr__ will break diagnostics.

Apply this diff:

     def __repr__(self):
-        return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})"
+        return (
+            "RequestCase("
+            f"prompt={self.prompt}, "
+            f"new_tokens={self.new_tokens}, "
+            f"finish_reasons={self.finish_reasons}, "
+            f"py_max_new_tokens={self.request.py_max_new_tokens}, "
+            f"py_end_id={self.request.py_end_id}, "
+            f"py_stop_words_list={self.request.py_stop_words_list}"
+            ")"
+        )

146-152: Replace zip(..., strict=True): not available on Python 3.8.

The strict=True parameter was added in Python 3.10. Target Python 3.8+ requires compatibility fix.

Apply this diff:

-    for actual, request in zip(actual_finish_reasons, requests, strict=True):
+    assert len(actual_finish_reasons) == len(requests)
+    for actual, request in zip(actual_finish_reasons, requests):
         expected = request.finish_reasons
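
If the length check is needed in more than one place, a small helper can stand in for zip(..., strict=True) on Python 3.8. The sketch below is a hypothetical zip_strict, not part of the test file; it uses a sentinel with itertools.zip_longest so that it also works for iterators that have no len().

```python
from itertools import zip_longest
from typing import Iterable, Iterator, Tuple, TypeVar

A = TypeVar("A")
B = TypeVar("B")

_MISSING = object()  # sentinel marking an exhausted iterable


def zip_strict(left: Iterable[A], right: Iterable[B]) -> Iterator[Tuple[A, B]]:
    """Yield pairs like zip(), but raise ValueError on a length mismatch."""
    for a, b in zip_longest(left, right, fillvalue=_MISSING):
        if a is _MISSING or b is _MISSING:
            raise ValueError("zip_strict() arguments have different lengths")
        yield a, b
```

Unlike the up-front assert, this variant reports the mismatch only once the shorter input runs out, so pairs before that point are still yielded.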
tensorrt_llm/_torch/speculative/mtp.py (1)

235-238: Persist max_seq_len for stop criteria.

MTPSampler uses self.max_seq_len in its stop checks, so the value must be stored during initialization.

The implementation now stores self.max_seq_len = args.max_seq_len on line 238, which resolves this.

tensorrt_llm/_torch/pyexecutor/sampler.py (1)

718-719: Replace PEP 604 union for Python 3.8 compatibility.

Apply this diff:

-                          gen_logits_host: torch.Tensor | None = None,
-                          log_probs_host: torch.Tensor | None = None):
+                          gen_logits_host: Optional[torch.Tensor] = None,
+                          log_probs_host: Optional[torch.Tensor] = None):
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b02ec58 and d9d1e61.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler_utils.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (4 hunks)
  • tests/unittest/_torch/test_torch_sampler.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Files:

  • tests/unittest/_torch/test_torch_sampler.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/sampler_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/test_torch_sampler.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/sampler_utils.py
🧠 Learnings (3)
📚 Learning: 2025-08-25T22:42:47.587Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-25T22:42:47.587Z
Learning: Applies to **/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py} : Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Applied to files:

  • tests/unittest/_torch/test_torch_sampler.py
📚 Learning: 2025-08-25T22:42:47.587Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-25T22:42:47.587Z
Learning: Applies to **/*.py : Code must target Python 3.8+

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-13T16:20:37.987Z
Learnt from: dcampora
PR: NVIDIA/TensorRT-LLM#6867
File: tensorrt_llm/_torch/pyexecutor/sampler.py:67-72
Timestamp: 2025-08-13T16:20:37.987Z
Learning: In TensorRT-LLM sampler code, performance is prioritized over additional validation checks. The beam_width helper method intentionally returns the first request's beam_width without validating consistency across all requests to avoid performance overhead from iterating through the entire batch.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/sampler_utils.py
🧬 Code graph analysis (4)
tests/unittest/_torch/test_torch_sampler.py (2)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • convert_wordlist (414-448)
  • LlmRequest (271-411)
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
  • sample (280-292)
  • _write_finish_reasons (634-664)
tensorrt_llm/_torch/speculative/mtp.py (2)
tensorrt_llm/_torch/pyexecutor/sampler.py (7)
  • Sampler (55-72)
  • SampleState (46-52)
  • SampleStateTensors (36-42)
  • TorchSampler (333-786)
  • TorchStore (311-320)
  • add_token (295-304)
  • int_tensor (307-308)
tensorrt_llm/_torch/pyexecutor/sampler_utils.py (1)
  • handle_stop_1_beam (33-50)
tensorrt_llm/_torch/pyexecutor/sampler.py (4)
tensorrt_llm/_torch/pyexecutor/handle_logits.py (1)
  • HandleLogits (10-66)
tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py (1)
  • MakeDecodingBatchInputOutput (11-70)
tensorrt_llm/_torch/pyexecutor/sampler_utils.py (2)
  • handle_stop_1_beam (33-50)
  • stop_token_criteria (15-30)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • LlmRequest (271-411)
  • get_draft_token_length (545-556)
tensorrt_llm/_torch/pyexecutor/sampler_utils.py (2)
cpp/include/tensorrt_llm/executor/types.h (1)
  • FinishReason (502-597)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • LlmRequest (271-411)
  • finish_by (382-385)
🪛 Ruff (0.12.2)
tests/unittest/_torch/test_torch_sampler.py

44-44: Line too long (178 > 120)

(E501)


52-52: Line too long (164 > 120)

(E501)


65-65: Line too long (121 > 120)

(E501)


122-122: Line too long (122 > 120)

(E501)

tensorrt_llm/_torch/pyexecutor/sampler.py

724-724: Line too long (124 > 120)

(E501)

🔇 Additional comments (3)
tensorrt_llm/_torch/speculative/mtp.py (1)

261-261: Good adaptation to the new stop-handling pattern.

The migration to handle_stop_1_beam maintains consistency across the codebase.

tensorrt_llm/_torch/pyexecutor/sampler.py (2)

311-321: Clean design with TorchStore abstraction.

The refactoring to use TorchStore provides cleaner separation of concerns and better encapsulation of storage buffers.


627-665: Well-structured finish-reason tracking implementation.

The finish-reason tracking logic is cleanly organized with proper precedence handling (STOP_WORDS -> LENGTH -> END_ID) where later checks override earlier ones, ensuring correct priority.
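
The override-by-write-order idea can be sketched in plain Python, with lists standing in for tensors. The enum values and the helper name below are assumptions for illustration; the actual implementation operates on tensors in the sampler's store.

```python
from enum import IntEnum
from typing import List


class FinishReason(IntEnum):
    # Illustrative values only; see cpp/include/tensorrt_llm/executor/types.h
    NOT_FINISHED = 0
    END_ID = 1
    STOP_WORDS = 2
    LENGTH = 3


def write_finish_reasons(stop_words: List[bool], max_length: List[bool],
                         end_id: List[bool]) -> List[FinishReason]:
    """Apply masks in increasing priority; a later write overrides an earlier one."""
    reasons = [FinishReason.NOT_FINISHED] * len(end_id)
    for mask, reason in (
        (stop_words, FinishReason.STOP_WORDS),
        (max_length, FinishReason.LENGTH),
        (end_id, FinishReason.END_ID),
    ):
        for slot, hit in enumerate(mask):
            if hit:
                reasons[slot] = reason
    return reasons


reasons = write_finish_reasons(
    stop_words=[True, True, False, False],
    max_length=[False, True, True, False],
    end_id=[False, True, False, False],
)
```

In this example the second slot matches all three criteria and ends up as END_ID, because that mask is written last.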

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/pyexecutor/sampler.py (2)

1-1: Add NVIDIA copyright header (2025) at file top.

Required by coding guidelines.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

433-466: Use consistent seq slot attribute; fix potential AttributeError.

Elsewhere the code uses request.py_seq_slot. request.seq_slot may not exist.

-            new_tokens[i, request.seq_slot, BEAM_0] = new_token
+            new_tokens[i, request.py_seq_slot, BEAM_0] = new_token
...
-            new_tokens[num_accepted, request.seq_slot, BEAM_0] = new_token
+            new_tokens[num_accepted, request.py_seq_slot, BEAM_0] = new_token
♻️ Duplicate comments (5)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)

35-41: Python 3.8 compatibility: remove kw_only and PEP 604 unions.

Switch to Optional[...] and plain @dataclass() per guidelines.

-@dataclass(kw_only=True)
+@dataclass()
 class SampleStateTensors:
     new_tokens: torch.Tensor
-    logits: torch.Tensor | None = None
-    log_probs: torch.Tensor | None = None
+    logits: Optional[torch.Tensor] = None
+    log_probs: Optional[torch.Tensor] = None

311-321: TorchStore: persist sizing fields and fix stray attribute docstring.

Keep sizing for downstream users and convert the no-op string to a comment.

 class TorchStore:
 
     def __init__(self, *, max_draft_len: int, max_num_sequences: int,
                  max_beam_width: int):
-        self.max_tokens = max_draft_len + 1
+        # Persist sizing for downstream stores
+        self.max_draft_len = max_draft_len
+        self.max_num_sequences = max_num_sequences
+        self.max_beam_width = max_beam_width
+        self.max_tokens = max_draft_len + 1
         assert max_beam_width == SINGLE_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
         self.new_tokens = int_tensor(
             (self.max_tokens, max_num_sequences, max_beam_width))
-        """Shape: See cpp DecoderState.getAllNewTokens()"""
+        # Shape: see cpp DecoderState.getAllNewTokens()
         self.finish_reasons = int_tensor(self.new_tokens.shape)

35-41: Repo-wide: audit for remaining 3.9+/3.10+ features.

There are other occurrences in this file (e.g., list[...] and match/case) that violate the 3.8 target.

#!/usr/bin/env bash
# Find 3.9+/3.10+ features that break Python 3.8
rg -nP '\|\s*None|\b(list|tuple|dict|set)\s*\[|@dataclass\(\s*kw_only\s*=' tensorrt_llm/_torch/pyexecutor/sampler.py
rg -nP '^\s*match\b' tensorrt_llm/_torch/pyexecutor/sampler.py

Also applies to: 323-331, 727-731


323-331: Python 3.8 compatibility: remove kw_only on dataclasses.

Same rationale as above.

-@dataclass(kw_only=True)
+@dataclass()
 class SampleStateTensorsHostTorch(SampleStateTensors):
     finish_reasons: torch.Tensor
 
-@dataclass(kw_only=True)
+@dataclass()
 class SampleStateTorch(SampleState):
     host: SampleStateTensorsHostTorch

727-731: Python 3.8 compatibility: replace PEP 604 unions in signature.

Also consider tightening types for requests/model_outputs if desired.

-                          gen_logits_host: torch.Tensor | None = None,
-                          log_probs_host: torch.Tensor | None = None):
+                          gen_logits_host: Optional[torch.Tensor] = None,
+                          log_probs_host: Optional[torch.Tensor] = None):
🧹 Nitpick comments (4)
tensorrt_llm/_torch/pyexecutor/sampler.py (4)

6-6: Avoid NumPy for a simple diff/max; keep dependencies minimal.

This use can be expressed in pure Python; drop the import.

-import numpy as np

(See replacement in _longest_stop_word_len below.)


623-637: Rewrite longest stop-word length without NumPy.

Removes extra dependency and implicit CPU work.

-    def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
-        max_stop_word_len = 0
-        for req in requests:
-            if req.py_stop_words_list is None:
-                continue
-            _, cumsum = req.py_stop_words_list
-            if -1 in cumsum:
-                cumsum = cumsum[:cumsum.index(-1)]
-            request_max_stop_word_len = np.max(np.diff(cumsum, prepend=0),
-                                               initial=0)
-            max_stop_word_len = max(max_stop_word_len,
-                                    request_max_stop_word_len)
-        return max_stop_word_len
+    def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
+        max_len = 0
+        for req in requests:
+            if req.py_stop_words_list is None:
+                continue
+            _, cumsum = req.py_stop_words_list
+            if -1 in cumsum:
+                cumsum = cumsum[:cumsum.index(-1)]
+            prev = 0
+            for end in cumsum:
+                max_len = max(max_len, end - prev)
+                prev = end
+        return max_len
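
As a worked example of the offset arithmetic, assume py_stop_words_list holds a flat token list plus cumulative end offsets padded with -1, as the code above implies. The helper name and the sample data below are hypothetical.

```python
from typing import List, Sequence


def stop_word_lengths(cumsum: Sequence[int]) -> List[int]:
    """Convert cumulative end offsets (padded with -1) into per-word lengths."""
    offsets = list(cumsum)
    if -1 in offsets:
        offsets = offsets[:offsets.index(-1)]  # drop the padding
    lengths = []
    prev = 0
    for end in offsets:
        lengths.append(end - prev)
        prev = end
    return lengths


# Three stop words of 2, 1 and 3 tokens, flattened into one token list.
tokens = [11, 12, 21, 31, 32, 33]
cumsum = [2, 3, 6, -1, -1, -1]
lengths = stop_word_lengths(cumsum)
longest = max(lengths, default=0)
```

Using max(..., default=0) covers a request whose offsets are all padding, which mirrors the initial=0 argument in the NumPy version.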

699-721: Avoid implicit GPU→CPU sync via .tolist() on CUDA tensors.

Explicitly copy the tensor to CPU before tolist() to document the transfer. Keep the copy blocking: tolist() reads the data immediately, so a non_blocking copy would first need an explicit synchronization.

-            new_tokens_list = new_tokens.tolist()
+            # Explicit, blocking GPU->CPU copy; tolist() reads the data right away
+            new_tokens_list = new_tokens.to(device="cpu").tolist()

784-797: Minor: keep linter happy; shorten assert message.

Long string can trigger E501; condense wording.

-                assert BEAM_0 == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
+                assert BEAM_0 == 0, "Assumes beam_width == 1 (unsqueeze path)"
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d9d1e61 and e59385d.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (2)
📚 Learning: 2025-08-25T22:42:47.587Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-25T22:42:47.587Z
Learning: Applies to **/*.py : Code must target Python 3.8+

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-13T16:20:37.987Z
Learnt from: dcampora
PR: NVIDIA/TensorRT-LLM#6867
File: tensorrt_llm/_torch/pyexecutor/sampler.py:67-72
Timestamp: 2025-08-13T16:20:37.987Z
Learning: In TensorRT-LLM sampler code, performance is prioritized over additional validation checks. The beam_width helper method intentionally returns the first request's beam_width without validating consistency across all requests to avoid performance overhead from iterating through the entire batch.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)
tensorrt_llm/_torch/pyexecutor/handle_logits.py (1)
  • HandleLogits (10-66)
tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py (1)
  • MakeDecodingBatchInputOutput (11-70)
tensorrt_llm/_torch/pyexecutor/sampler_utils.py (2)
  • handle_stop_1_beam (33-50)
  • stop_token_criteria (15-30)
cpp/include/tensorrt_llm/executor/types.h (1)
  • FinishReason (502-597)
tensorrt_llm/_torch/pyexecutor/llm_request.py (3)
  • LlmRequest (271-411)
  • finish_by (382-385)
  • get_draft_token_length (545-556)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py

735-735: Line too long (124 > 120)

(E501)

🔇 Additional comments (14)
tensorrt_llm/_torch/pyexecutor/sampler.py (14)

12-15: Sampler utils import consolidation looks good.


351-356: Store initialization path is clear.

Using inference_mode(False) to allow in-place writes is appropriate.


394-406: Finish-reason plumbing: solid and minimal.

The enum set and the classmethod wrapper are fine.


479-487: Update path for context requests looks correct.

Token add, finish check, and logits handling are in the right order.


494-505: Generation update integrates drafting paths cleanly.

Counting accepted tokens and emitting logits per processed step is good.


511-514: Log-probs host buffer shape matches SINGLE_BEAM_WIDTH.

Pinned CPU allocation is appropriate.


518-522: Generation logits host buffer shape is consistent.


532-537: Index dtype fix is correct.

seq_slots_host as int64 avoids index_fill_/index_copy_ dtype errors.


545-548: Finish-reason threading to host is correct; event recorded.

Return of SampleStateTorch with populated host/device is coherent.

Also applies to: 550-563


639-644: CUDA writes for finish_reasons look correct.

Indexing via nonzero(as_tuple=True) is efficient.


646-676: Finish-reason precedence and NOT_FINISHED reset are sensible.

Reverse precedence order is clear; index_fill_ avoids stale slots.
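
The reverse-precedence idea can be sketched without tensors: reset every slot to a not-finished value, then write the lowest-priority reason first so that later writes overwrite it. The enum values, the precedence order, and the function name below are assumptions for illustration, not taken from the PR.

```python
from enum import Enum
from typing import Dict, List


class Reason(Enum):  # hypothetical stand-in for FinishReason
    NOT_FINISHED = 0
    END_ID = 1
    STOP_WORDS = 2
    LENGTH = 3


def write_finish_reasons(num_slots: int,
                         masks: Dict[Reason, List[bool]],
                         precedence: List[Reason]) -> List[Reason]:
    """`precedence` lists reasons from highest to lowest priority."""
    # Reset first so a slot never keeps a reason from an earlier step.
    reasons = [Reason.NOT_FINISHED] * num_slots
    # Apply in reverse so the highest-priority reason is written last and wins.
    for reason in reversed(precedence):
        for slot, hit in enumerate(masks.get(reason, [])):
            if hit:
                reasons[slot] = reason
    return reasons


masks = {
    Reason.LENGTH: [True, True, False],
    Reason.END_ID: [True, False, False],
}
result = write_finish_reasons(
    3, masks, [Reason.END_ID, Reason.STOP_WORDS, Reason.LENGTH])
```

Slot 0 matches both masks and ends up with the higher-priority `END_ID`; slot 2 matches nothing and stays `NOT_FINISHED`.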


677-685: End-ID comparison tensor creation LGTM.

Shape and dtype align with tokens; pinned-to-cuda copy is fine.


686-698: Max-length mask creation LGTM.

Computation matches spec: generated_len >= min(max_new_tokens, max_seq_len).


744-745: Fast-path index_copy_ usage is correct.

The view to (1, batch, 1) and scatter by seq_slots is efficient.

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
…itera-to-sample-async

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (9)
tensorrt_llm/_torch/pyexecutor/sampler.py (9)

35-40: Replace PEP 604 unions and remove kw_only for Python 3.8 compatibility

The dataclass uses features not available in Python 3.8:

  • kw_only=True requires Python 3.10+
  • X | None syntax requires Python 3.10+
-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTensors:
     new_tokens: torch.Tensor
-    logits: torch.Tensor | None = None
-    log_probs: torch.Tensor | None = None
+    logits: Optional[torch.Tensor] = None
+    log_probs: Optional[torch.Tensor] = None

45-46: Remove kw_only for Python 3.8 compatibility

-@dataclass(kw_only=True)
+@dataclass
 class SampleState:

62-63: Replace PEP 604 union syntax

-    def get_cache_indirection(self) -> torch.Tensor | None:
+    def get_cache_indirection(self) -> Optional[torch.Tensor]:
         return None

283-293: Replace match statement with if/elif for Python 3.8 compatibility

match statements (PEP 634) require Python 3.10+. Replace with an if/elif chain.

 def sample(strategy: Strategy,
            logits: torch.Tensor,
            generator: Optional[torch.Generator] = None):
-    match strategy:
-        case ("top_k", top_k):
-            return top_k_sampling_batch(logits, top_k, generator)
-        case ("top_p", top_p, temperature):
-            return top_p_sampling_batch(logits, top_p, temperature, generator)
-        case ("top_k_top_p", top_k, top_p, temperature):
-            return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
-                                              generator)
-        case ("greedy", None):
-            return greedy_search_sampling_batch(logits)
+    if strategy[0] == "top_k":
+        return top_k_sampling_batch(logits, strategy[1], generator)
+    elif strategy[0] == "top_p":
+        return top_p_sampling_batch(logits, strategy[1], strategy[2], generator)
+    elif strategy[0] == "top_k_top_p":
+        return top_k_top_p_sampling_batch(logits, strategy[1], strategy[2], strategy[3],
+                                          generator)
+    elif strategy[0] == "greedy":
+        return greedy_search_sampling_batch(logits)
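
As an alternative to indexing into `strategy`, tuple unpacking keeps the names that the `match` arms bound and can fail loudly on an unknown strategy. This is a sketch under assumptions: the stub functions below are placeholders that only echo their arguments, not the real `*_sampling_batch` implementations, and raising `ValueError` is a design choice that differs from the original `match`, which silently returns `None` when no arm matches.

```python
from typing import Any, Tuple


# Placeholder stubs; the real functions take logits tensors and a generator.
def top_k_stub(logits: Any, top_k: int) -> Tuple:
    return ("top_k", top_k)


def top_p_stub(logits: Any, top_p: float, temperature: float) -> Tuple:
    return ("top_p", top_p, temperature)


def greedy_stub(logits: Any) -> Tuple:
    return ("greedy",)


def sample(strategy: Tuple, logits: Any) -> Tuple:
    # Split the tag from its parameters, then unpack per strategy.
    name, *params = strategy
    if name == "top_k":
        (top_k,) = params
        return top_k_stub(logits, top_k)
    if name == "top_p":
        top_p, temperature = params
        return top_p_stub(logits, top_p, temperature)
    if name == "greedy":
        return greedy_stub(logits)
    raise ValueError(f"Unknown sampling strategy: {name!r}")
```

Unpacking also raises if a strategy tuple carries the wrong number of parameters, which plain indexing would not detect.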

524-525: Replace dict type hint

     def sample_async(self, scheduled_requests: ScheduledRequests,
-                     model_outputs: dict[str, torch.Tensor]) -> SampleState:
+                     model_outputs: Dict[str, torch.Tensor]) -> SampleState:

573-575: Replace list type hints

     def _apply_embedding_bias(
             logits: torch.Tensor,
-            requests: list[LlmRequest],
-            steps_per_request: list[int] = None) -> torch.Tensor:
+            requests: List[LlmRequest],
+            steps_per_request: Optional[List[int]] = None) -> torch.Tensor:
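
One alternative worth noting, offered as a hedged sketch and not something the PR does: annotations written as strings are never evaluated at definition time, so built-in generics and `X | None` inside them do not raise on Python 3.8. Adding `from __future__ import annotations` (PEP 563) at the top of a module has the same effect for every annotation in it. Neither approach helps with `dataclass(kw_only=True)` or `match`, and tools that resolve annotations at runtime, such as `typing.get_type_hints`, can still fail on older interpreters. The function name below is hypothetical.

```python
def apply_bias(requests: "list[int]",
               steps: "list[int] | None" = None) -> "dict[str, int]":
    # The quoted annotations above are stored as plain strings, not evaluated.
    steps = steps if steps is not None else [1] * len(requests)
    return {"total_steps": sum(steps)}


stored = apply_bias.__annotations__
result = apply_bias([10, 20, 30])
```

`stored["steps"]` is the literal string `"list[int] | None"`, which is why older interpreters never attempt to subscript `list`.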

810-817: Remove kw_only and replace PEP 604 unions

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finished_sum: torch.Tensor
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
-    cum_log_probs: torch.Tensor | None = None
-    gathered_ids: torch.Tensor | None = None
+    cum_log_probs: Optional[torch.Tensor] = None
+    gathered_ids: Optional[torch.Tensor] = None
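
A caveat worth checking before applying this suggestion: without `kw_only=True`, a dataclass subclass cannot add a field without a default after inherited fields that have defaults. The sketch below reproduces this with hypothetical stand-in classes (plain lists instead of tensors) and shows one possible workaround, giving the new field a default. Whether that is acceptable for the real classes is an assumption left to the authors.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TensorsSketch:  # stand-in for SampleStateTensors
    new_tokens: List[int]
    logits: Optional[List[float]] = None


def define_child_without_default():
    @dataclass
    class Child(TensorsSketch):
        finish_reasons: List[int]  # no default after an inherited default
    return Child


try:
    define_child_without_default()
    definition_error = None
except TypeError as exc:
    # dataclass rejects the field order when the class is created.
    definition_error = exc


@dataclass
class ChildWithDefault(TensorsSketch):
    # Workaround sketch: a default keeps the field order valid on Python 3.8.
    finish_reasons: Optional[List[int]] = None


state = ChildWithDefault(new_tokens=[1, 2], finish_reasons=[0, 0])
```

The error is raised at class-definition time, so it would surface on import, not on first use.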

819-823: Remove kw_only and replace PEP 604 unions

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTRTLLM(SampleState):
-    finalize_events: dict[str, CudaEvent] | None = None
+    finalize_events: Optional[Dict[str, CudaEvent]] = None

966-967: Replace PEP 604 union

-    def get_cache_indirection(self) -> torch.Tensor | None:
+    def get_cache_indirection(self) -> Optional[torch.Tensor]:
         return self.store["decoder_state"].cache_indirection_output
♻️ Duplicate comments (7)
tensorrt_llm/_torch/pyexecutor/sampler.py (7)

4-4: Use Optional for Python 3.8 compatibility

The project targets Python 3.8+ according to the coding guidelines, but PEP 604 union syntax (X | None) requires Python 3.10+. Replace with Optional[X] for compatibility.

-from typing import Literal, Optional
+from typing import Literal, Optional, List, Dict

Apply this import change and update all occurrences of PEP 604 unions throughout the file.


645-676: Replace list type hint

Use List from the typing module for Python 3.8 compatibility; subscripting the built-in list requires Python 3.9+. The walrus operator in the if statement reads well and is available from Python 3.8.

-    def _write_finish_reasons(self, requests: list[LlmRequest], *,
+    def _write_finish_reasons(self, requests: List[LlmRequest], *,
                               finish_reasons: torch.Tensor,
                               seq_slots: torch.Tensor,
                               new_tokens: torch.Tensor) -> None:

323-326: Remove kw_only for Python 3.8 compatibility

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTensorsHostTorch(SampleStateTensors):
     finish_reasons: torch.Tensor

677-685: Replace list type hint

-    def _are_end_id(self, requests: list[LlmRequest],
+    def _are_end_id(self, requests: List[LlmRequest],
                     tokens: torch.Tensor) -> torch.Tensor:

686-698: Replace list type hint

-    def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
+    def _are_max_length(self, requests: List[LlmRequest]) -> torch.Tensor:

722-731: Replace type hints for Python 3.8 compatibility

     def _process_requests(self,
-                          requests: list[LlmRequest],
-                          model_outputs: dict[str, torch.Tensor],
+                          requests: List[LlmRequest],
+                          model_outputs: Dict[str, torch.Tensor],
                           new_tokens: torch.Tensor,
                           *,
                           seq_slots: torch.Tensor,
                           seq_slots_host: torch.Tensor,
-                          gen_logits_host: torch.Tensor | None = None,
-                          log_probs_host: torch.Tensor | None = None):
+                          gen_logits_host: Optional[torch.Tensor] = None,
+                          log_probs_host: Optional[torch.Tensor] = None):

328-331: Remove kw_only for Python 3.8 compatibility

-@dataclass(kw_only=True)
+@dataclass
 class SampleStateTorch(SampleState):
     host: SampleStateTensorsHostTorch
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)

433-434: Improve clarity of docstring

The docstring explains an implementation constraint but could be clearer.

-        """We cannot use finish_if_reason in _process_draft_tokens_rejection_sampling because it *writes to new_tokens*,
-        rendering the finish reason calculation in sample_async stale (incorrect) for this batch"""
+        """Process draft tokens using rejection sampling.
+
+        Note: finish_if_reason cannot be used here because
+        _process_draft_tokens_rejection_sampling writes to new_tokens, which makes
+        the finish reasons computed in sample_async stale for this batch."""

639-644: Add validation for tensor device consistency

The assertion could be more explicit about the requirement.

     @staticmethod
     def _write_reason(finish_reasons: torch.Tensor, reason: FinishReason, *,
                       where: torch.Tensor, seq_slots: torch.Tensor) -> None:
-        assert all([seq_slots.is_cuda, where.is_cuda])
+        assert seq_slots.is_cuda and where.is_cuda, "Tensors must be on CUDA device"
         r, c = torch.nonzero(where, as_tuple=True)
         finish_reasons[r, seq_slots[c], BEAM_0] = reason.value

735-735: Consider shortening line length

Line exceeds 120 character limit, though readability is acceptable.

-        fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
+        fast_path = (not self.enable_mixed_sampler and no_draft_tokens and 
+                     gen_logits_host is None and log_probs_host is None)
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between e59385d and 925cf3b.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (9 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/pyexecutor/sampler_utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (4)
tensorrt_llm/_torch/pyexecutor/handle_logits.py (1)
  • HandleLogits (10-66)
tensorrt_llm/_torch/pyexecutor/sampler_utils.py (2)
  • handle_stop_1_beam (48-65)
  • stop_token_criteria (30-45)
cpp/include/tensorrt_llm/executor/types.h (1)
  • FinishReason (502-597)
tensorrt_llm/_torch/pyexecutor/llm_request.py (4)
  • LlmRequest (271-411)
  • finish_by (382-385)
  • get_draft_token_length (545-556)
  • log_probs (220-221)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py

735-735: Line too long (124 > 120)

(E501)

🔇 Additional comments (3)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)

532-536: Specify dtype=torch.long for index operations

The seq_slots tensor needs an int64 dtype for use in index_fill_ and other indexing operations. dtype=torch.int64 already satisfies this; torch.long is an alias for the same dtype and is the more common spelling for index tensors:

         seq_slots_host = torch.tensor(
             [r.py_seq_slot for r in requests],
-            dtype=torch.int64,  # for index_fill_
+            dtype=torch.long,
             pin_memory=True)

407-414: Well-structured finish reason handling

Good implementation of the finish_if_reason method with proper step tracking for greedy decoding.


544-563: Clean integration of finish reasons into the sampling pipeline

Excellent addition of _write_finish_reasons to centralize the finish reason computation and proper propagation through the SampleStateTorch return value.

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
@netanel-haber netanel-haber marked this pull request as ready for review August 27, 2025 13:49
@netanel-haber netanel-haber requested review from a team as code owners August 27, 2025 13:49
@netanel-haber
Collaborator Author

/bot run

@netanel-haber netanel-haber marked this pull request as draft August 27, 2025 13:50
@netanel-haber netanel-haber self-assigned this Aug 27, 2025
@tensorrt-cicd
Collaborator

PR_Github #16695 [ run ] triggered by Bot

@netanel-haber netanel-haber requested a review from dcampora August 27, 2025 13:56
@netanel-haber
Collaborator Author

/bot run --stage-list DGX_B200-4_GPUs-PyTorch-1

@tensorrt-cicd
Collaborator

PR_Github #17884 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17884 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13392 (Partly Tested) completed with status: 'FAILURE'

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
@netanel-haber
Collaborator Author

/bot run --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #17887 [ run ] triggered by Bot

@netanel-haber netanel-haber enabled auto-merge (squash) September 6, 2025 17:42
@tensorrt-cicd
Collaborator

PR_Github #17887 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13395 completed with status: 'FAILURE'

@netanel-haber
Collaborator Author

/bot run --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #17901 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17901 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13409 completed with status: 'FAILURE'

@netanel-haber
Collaborator Author

/bot run --only-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #17918 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17918 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13425 (Partly Tested) completed with status: 'SUCCESS'

@netanel-haber
Collaborator Author

/bot reuse-pipeline

@tensorrt-cicd
Collaborator

PR_Github #17941 [ reuse-pipeline ] triggered by Bot

@netanel-haber
Collaborator Author

Single-GPU run: success. Since then, only one commit has been added, and it adds a single line to support multi-GPU.
Multi-GPU run: success.

There are flaky timeouts in the single-GPU runs:

unittest/B200_PCIe-PyTorch-1/unittest/_torch/thop/parallel/test_moe.py::TestMoeFp4::test_autotune[RoutingRenormalize-768-1024-1]
H100_PCIe-PyTorch-1/accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]

@tensorrt-cicd
Collaborator

PR_Github #17941 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #17918 (Partly Tested) for commit 0afadcf

@netanel-haber netanel-haber merged commit 0fee8cd into NVIDIA:main Sep 7, 2025
5 checks passed
netanel-haber added a commit to netanel-haber/TensorRT-LLM that referenced this pull request Sep 17, 2025
netanel-haber added a commit to netanel-haber/TensorRT-LLM that referenced this pull request Sep 17, 2025
…IA#7041)"

This reverts commit 0fee8cd.

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
mikeiovine added a commit that referenced this pull request Sep 18, 2025
…ple_async (#7041) (#7796)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Mike Iovine <miovine@nvidia.com>
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…to sample_async (NVIDIA#7041) (NVIDIA#7796)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Mike Iovine <miovine@nvidia.com>
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
…to sample_async (NVIDIA#7041) (NVIDIA#7796)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Mike Iovine <miovine@nvidia.com>