KEMBAR78
[None][feat] Use list instead of torch tensor for new tokens in update requests by dcampora · Pull Request #7730 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@dcampora
Copy link
Collaborator

@dcampora dcampora commented Sep 15, 2025

Summary by CodeRabbit

  • New Features
    • Support adding multiple tokens per request in a single call.
    • Token updates are applied atomically per cycle via caching for more consistent behavior across drafting and rejection paths.
    • Adjusted single-beam single-step flows to align with batched token handling.
  • Performance
    • Reduced per-token overhead by batching token additions, improving throughput and latency in sampling.
  • Documentation
    • Updated usage guidance to provide tokens as per-request containers matching the order of requests.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@dcampora dcampora requested a review from a team as a code owner September 15, 2025 12:21
@dcampora dcampora requested a review from Naveassaf September 15, 2025 12:21
@dcampora
Copy link
Collaborator Author

/bot run

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 15, 2025

📝 Walkthrough

Walkthrough

Bindings in C++ (nanobind and pybind) change add_new_tokens_to_requests to accept per-request token lists. Python sampler introduces token caching to defer mutations and commits them in batches via the updated binding. Control flow in drafting/rejection paths routes token additions through the cache and commits at cycle end.

Changes

Cohort / File(s) Summary
C++ bindings: per-request token containers
cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp, cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
Change binding signature to accept std::vector<std::vector<TokenIdType>> tokens per request; validate outer size vs. requests; iterate inner vectors and push tokens with beam_idx; update docstrings/messages.
Python sampler: cached token emission and batched commit
tensorrt_llm/_torch/pyexecutor/sampler.py
add_token gains optional cache param; introduce self._cached_tokens lifecycle in update paths; route all token additions through cache in drafting and rejection flows; commit via add_new_tokens_to_requests(requests, per_request_tokens, beam); adjust single-beam single-step to nested token lists.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Sampler
  participant Cache as Per-request Token Cache
  participant Binding as add_new_tokens_to_requests
  participant LlmReq as LlmRequest

  rect rgb(245,245,255)
    note over Sampler: New flow (cached, batched commit)
    Sampler->>Cache: add_token(request, token, beam, step, cache)
    Sampler->>Cache: ... repeat per path (draft/reject)
    Sampler->>Binding: add_new_tokens_to_requests(requests, tokens_per_request, beam)
    Binding->>LlmReq: for each request: push each token
  end
Loading
sequenceDiagram
  autonumber
  participant Sampler
  participant Binding as add_new_token (old usage)
  participant LlmReq as LlmRequest

  rect rgb(255,245,245)
    note over Sampler: Previous flow (immediate mutation)
    Sampler->>Binding: add_new_token(request, token, beam)
    Binding->>LlmReq: push token
    Sampler->>Binding: ... repeated per step
  end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 37.50% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description Check ⚠️ Warning The PR description remains as the unmodified template with no actual content in the Description or Test Coverage sections and leaves all implementation rationale and test details blank, so it fails to explain what was changed, why, or how it is tested. Please replace the template placeholders with a concise summary of the feature and its motivation in the Description section, list specific test cases or coverage in the Test Coverage section, and update the checklist to reflect completed items.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title “[None][feat] Use list instead of torch tensor for new tokens in update requests” directly references a valid aspect of the changeset—the transition from tensors to per-request lists for new tokens—but it does not highlight the primary optimization goal of batching and atomic commits; however, it still accurately describes a concrete part of the implementation.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18629 [ run ] triggered by Bot

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

1-1: Add NVIDIA Apache-2.0 header.

Project guideline requires the license header at the top of all Python files.

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
🧹 Nitpick comments (4)
cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp (1)

438-454: Make the binding more efficient and robust (const&, size_t loop, release GIL).

  • Take requests by const& (you mutate the pointees, not the vector).
  • Use size_t for the index to match requests.size().
  • Release the GIL while pushing tokens to reduce contention on large batches.
-        [](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
-            std::vector<std::vector<tb::LlmRequest::TokenIdType>> const& tokens, int beam_idx)
+        [](std::vector<std::shared_ptr<tb::LlmRequest>> const& requests,
+            std::vector<std::vector<tb::LlmRequest::TokenIdType>> const& tokens, int beam_idx)
         {
             TLLM_CHECK_WITH_INFO(
                 requests.size() == tokens.size(), "Expected the same number of requests and token containers.");
 
-            for (int i = 0; i < requests.size(); ++i)
+            for (size_t i = 0; i < requests.size(); ++i)
             {
                 for (const auto& token : tokens[i])
                 {
                     requests[i]->addNewToken(token, beam_idx);
                 }
             }
         },
-        py::arg("requests"), py::arg("tokens"), py::arg("beam_idx"),
+        py::arg("requests"), py::arg("tokens"), py::arg("beam_idx"),
+        py::call_guard<py::gil_scoped_release>(),
         "Add new tokens to multiple LLM requests. The tokens vector should contain tokens for beam beam_idx of all "
         "requests in order.");
cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (1)

431-447: Apply the same const&, size_t loop, and GIL release in nanobind binding.

Keeps both bindings consistent and reduces overhead under contention.

-        [](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
-            std::vector<std::vector<tb::LlmRequest::TokenIdType>> const& tokens, int beam_idx)
+        [](std::vector<std::shared_ptr<tb::LlmRequest>> const& requests,
+            std::vector<std::vector<tb::LlmRequest::TokenIdType>> const& tokens, int beam_idx)
         {
             TLLM_CHECK_WITH_INFO(
                 requests.size() == tokens.size(), "Expected the same number of requests and token containers.");
 
-            for (int i = 0; i < requests.size(); ++i)
+            for (size_t i = 0; i < requests.size(); ++i)
             {
                 for (const auto& token : tokens[i])
                 {
                     requests[i]->addNewToken(token, beam_idx);
                 }
             }
         },
-        nb::arg("requests"), nb::arg("tokens"), nb::arg("beam_idx"),
+        nb::arg("requests"), nb::arg("tokens"), nb::arg("beam_idx"),
+        nb::call_guard<nb::gil_scoped_release>(),
         "Add new tokens to multiple LLM requests. The tokens vector should contain tokens for beam beam_idx of all "
         "requests in order.");
tensorrt_llm/_torch/pyexecutor/sampler.py (2)

350-352: Fix type hint: avoid implicit Optional (RUF013).

Annotate cache as Optional to satisfy typing and linters.

-def add_token(request: LlmRequest,
-              new_tokens: torch.Tensor,
-              *,
-              beam: int,
-              step: int = 0,
-              cache: dict[LlmRequest, List[int]] = None) -> int:
+def add_token(request: LlmRequest,
+              new_tokens: torch.Tensor,
+              *,
+              beam: int,
+              step: int = 0,
+              cache: Optional[dict[LlmRequest, List[int]]] = None) -> int:

594-597: Stabilize request/token pairing on commit.

keys() and values() preserve order in CPython, but zip(items) avoids any mismatch risk and reads clearer.

-        # Commit adding tokens
-        add_new_tokens_to_requests(list(self._cached_tokens.keys()),
-                                   list(self._cached_tokens.values()), BEAM_0)
+        # Commit adding tokens
+        items = list(self._cached_tokens.items())
+        if items:
+            reqs, toks = zip(*items)
+            add_new_tokens_to_requests(list(reqs), list(toks), BEAM_0)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 24fc1f9 and ed65c45.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (1 hunks)
  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (7 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.
📚 Learning: 2025-08-15T06:46:54.897Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.

Applied to files:

  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
🧬 Code graph analysis (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  • LlmRequest (284-426)
tensorrt_llm/_torch/pyexecutor/sampler_utils.py (1)
  • handle_stop_single_beam (44-61)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py

351-351: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18629 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13987 completed with status: 'FAILURE'

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
@dcampora dcampora force-pushed the user/dcampora/optimize_update_requests branch from ed65c45 to 7a2ebae Compare September 22, 2025 04:30
@dcampora dcampora requested a review from a team as a code owner September 22, 2025 04:30
@dcampora dcampora requested a review from nv-yilinf September 22, 2025 04:30
@dcampora dcampora changed the title [None][feat] Draft: Perform a single binding call for update requests [None][feat] Use list instead of torch tensor for new tokens in update requests Sep 22, 2025
@dcampora dcampora enabled auto-merge (squash) September 22, 2025 04:30
@dcampora
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19495 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19495 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14655 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@dcampora dcampora merged commit 9f1d9b7 into NVIDIA:main Sep 23, 2025
7 of 9 checks passed
@dcampora dcampora deleted the user/dcampora/optimize_update_requests branch September 24, 2025 07:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants