KEMBAR78
[None][fix] using arrival time in llmapi when creating LlmRequest in pytorch workflow by zhengd-nv · Pull Request #7553 · NVIDIA/TensorRT-LLM · GitHub
Skip to content

Conversation

@zhengd-nv
Copy link
Collaborator

@zhengd-nv zhengd-nv commented Sep 5, 2025

Summary by CodeRabbit

  • New Features

    • Added optional arrival_time across Python APIs and bindings; propagated to requests and used in performance metrics (defaults to current time when metrics are enabled).
    • Introduced steady_clock_now() utility to fetch a monotonic timestamp from Python.
  • Chores

    • Removed TRTLLM_KVCACHE_TIME_OUTPUT_PATH-based auto-enablement of performance metrics in the OpenAI-compatible chat path.

Description

Currently, the arrivalTime in RequestPerfMetrics is set when creating LlmRequest class. In PyTorch workflow, the creation is after queuing and cannot reflect the real request arrival time. Because of this, the perf metrics cannot capture the queuing time in PyTorch workflow. In this PR, the arrival time is recorded in llmapi and is passed as an argument for LlmRequest creation.

This PR also add a steady_clock_now() binding to python, allowing consistent timing between C++ code and Python code.

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@zhengd-nv
Copy link
Collaborator Author

/bot run

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 5, 2025

📝 Walkthrough

Walkthrough

Adds optional arrival_time propagation across Python and C++ layers for LlmRequest, including type alias TimePoint, constructor parameter additions, and bindings updates (nanobind/pybind). Provides Python-accessible steady_clock_now helpers. Wires arrival_time from API through executor and worker to backend. Removes env-var-based perf-metrics toggle in OpenAI server path.

Changes

Cohort / File(s) Summary
Core request types
cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Introduces TimePoint alias and appends optional arrivalTime to GenericLlmRequest and LlmRequest constructors; sets perf metrics arrival time using provided value or steady_clock::now() when enabled; updates constructor forwarding and comments.
Nanobind batch_manager
cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp, cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp, cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
Extends bindings and wrapper types to accept/forward optional arrival_time/arrivalTime; includes nanobind/stl/chrono.h; updates parameter counts and constructor calls.
Nanobind module utils
cpp/tensorrt_llm/nanobind/bindings.cpp
Adds nanobind/stl/chrono.h and exposes steady_clock_now() returning std::chrono::steady_clock::time_point.
Pybind batch_manager
cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp, cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp, cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
Adds optional arrival_time/arrivalTime to Python bindings and forwards to C++; updates constructor signatures, calls, and parameter count comments.
Pybind module utils
cpp/tensorrt_llm/pybind/bindings.cpp
Includes pybind11/chrono.h and exposes steady_clock_now() returning steady_clock::time_point.
Executor pipeline
tensorrt_llm/executor/executor.py, tensorrt_llm/executor/request.py, tensorrt_llm/executor/worker.py
Adds optional arrival_time to generate_async and GenerationRequest; stores and forwards via worker to backend request attribute py_arrival_time.
Torch pyexecutor glue
tensorrt_llm/_torch/pyexecutor/llm_request.py
Forwards executor_request.py_arrival_time to LlmRequest(arrival_time=...) during conversion.
High-level LLM API
tensorrt_llm/llmapi/llm.py
Imports steady_clock_now; computes arrival_time when return_perf_metrics is enabled and passes it into executor.
OpenAI server
tensorrt_llm/serve/openai_server.py
Removes env-var toggle (TRTLLM_KVCACHE_TIME_OUTPUT_PATH) that previously forced return_perf_metrics=True in chat path.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Client
  participant LLMAPI as LLM API (Python)
  participant Exec as GenerationExecutor
  participant Worker
  participant Backend as Backend LlmRequest (C++)

  Client->>LLMAPI: generate_async(prompt, sampling_params)
  alt return_perf_metrics enabled
    note over LLMAPI: arrival_time = steady_clock_now()
  else
    note over LLMAPI: arrival_time = None
  end
  LLMAPI->>Exec: generate_async(..., arrival_time)
  Exec->>Exec: GenerationRequest(..., arrival_time)
  Exec->>Worker: _enqueue_request(request)
  Worker->>Worker: executor_request.py_arrival_time = request.arrival_time
  Worker->>Backend: Build LlmRequest(arrivalTime=py_arrival_time)
  Backend->>Backend: If returnPerfMetrics: set timing.arrivalTime (use provided or now)
  Backend-->>Client: Results (+perf metrics if requested)
Loading
sequenceDiagram
  autonumber
  participant Py as Python Bindings
  participant NB as Nanobind C++
  participant PB as Pybind C++
  participant Core as Core C++ (GenericLlmRequest)

  Py->>NB: LlmRequest(..., arrival_time)
  NB->>Core: LlmRequest(..., arrivalTime)
  Py->>PB: LlmRequest(..., arrival_time)
  PB->>Core: LlmRequest(..., arrivalTime)
  Note over Core: Constructor stores arrivalTime or steady_clock::now()
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

Community want to contribute

Suggested reviewers

  • syuoni
  • Superjomn
  • chzblych
  • tomeras91
  • pcastonguay
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore or @coderabbit ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp (1)

79-131: Do not hardcode returnPerfMetrics=false when converting to tb::LlmRequest.

This defeats perf-metrics collection even when the Python-side request asked for it. It also undermines the usefulness of the newly propagated arrivalTime.

Apply this diff:

-        false,                                                     // returnPerfMetrics
+        mReturnPerfMetrics,                                        // returnPerfMetrics
cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp (1)

99-99: py::classh appears to be a typo (will not compile).

pybind11 uses py::class_, not py::classh. This looks like a mechanical typo and will break the build.

-    py::classh<GenLlmReq>(m, "GenericLlmRequest")
+    py::class_<GenLlmReq>(m, "GenericLlmRequest")
 ...
-    py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
+    py::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
 ...
-    py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
+    py::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager")
 ...
-    py::classh<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")
+    py::class_<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")

Also applies to: 261-261, 391-391, 399-399

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (1)

54-56: Expose steady_clock_now binding
Python code in tensorrt_llm/llmapi/llm.py (line 375) calls steady_clock_now, but no such nanobind export exists. Add this to cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp inside initBindings:

m.def("steady_clock_now", []() {
    return std::chrono::steady_clock::now();
}, "Return a std::chrono::steady_clock::time_point for use as arrival_time.");
🧹 Nitpick comments (16)
cpp/tensorrt_llm/nanobind/bindings.cpp (1)

516-517: Expose monotonic timestamp helpers to make arrival_time IPC-safe and cross-binding friendly.

Returning a steady_clock::time_point is fine when everything stays in-process and within the same binding tech, but it is fragile across process boundaries (pickling) and between pybind/nanobind. Please also expose ns helpers so we can serialize an int64 and reconstruct the time_point at the consumer.

Apply this diff to add helpers alongside steady_clock_now:

     m.def("ipc_nvls_supported", &tr::ipcNvlsSupported);
 
-    m.def("steady_clock_now", []() { return std::chrono::steady_clock::now(); });
+    // Monotonic time_point (same as C++)
+    m.def("steady_clock_now", []() { return std::chrono::steady_clock::now(); });
+    // Monotonic timestamp in nanoseconds since steady_clock epoch (IPC/pickle friendly)
+    m.def("steady_clock_now_ns",
+          []() {
+              return std::chrono::duration_cast<std::chrono::nanoseconds>(
+                         std::chrono::steady_clock::now().time_since_epoch())
+                  .count();
+          });
+    // Reconstruct a steady_clock::time_point from ns
+    m.def("steady_clock_from_ns",
+          [](long long ns) {
+              return std::chrono::time_point<std::chrono::steady_clock>{
+                  std::chrono::nanoseconds{ns}};
+          });
cpp/tensorrt_llm/pybind/bindings.cpp (1)

503-504: Mirror nanobind helpers to keep APIs symmetrical and enable ns-based transport.

Expose ns helpers here too so the Python layer can depend on a stable API regardless of binding backend.

Apply this diff:

     m.def("ipc_nvls_supported", &tr::ipcNvlsSupported);
 
-    m.def("steady_clock_now", []() { return std::chrono::steady_clock::now(); });
+    // Monotonic time_point (same as C++)
+    m.def("steady_clock_now", []() { return std::chrono::steady_clock::now(); });
+    // Monotonic timestamp in nanoseconds since steady_clock epoch (IPC/pickle friendly)
+    m.def("steady_clock_now_ns",
+          []() {
+              return std::chrono::duration_cast<std::chrono::nanoseconds>(
+                         std::chrono::steady_clock::now().time_since_epoch())
+                  .count();
+          });
+    // Reconstruct a steady_clock::time_point from ns
+    m.def("steady_clock_from_ns",
+          [](long long ns) {
+              return std::chrono::time_point<std::chrono::steady_clock>{
+                  std::chrono::nanoseconds{ns}};
+          });
cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h (2)

54-54: Replace brittle “50 parameters” comment with a stable note

The literal count drifts easily and adds maintenance burden. Prefer a descriptive note tied to the Base constructor.

-    // 50 parameters
+    // Parameters: keep in sync with Base constructor order

88-90: Avoid const-ref to std::optional defaulting to std::nullopt

Binding a const reference parameter to a temporary optional (std::nullopt) is safe here but fragile if constructor internals change. Passing by value avoids lifetime pitfalls and copies at most a small optional wrapper.

-        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+        std::optional<executor::ContextPhaseParams> contextPhaseParams = std::nullopt,
         std::optional<TimePoint> arrivalTime = std::nullopt)

Also consider a short Doxygen note for the new arrivalTime parameter (steady_clock time point).

tensorrt_llm/executor/request.py (1)

100-101: Clarify arrival_time units in docstring/comments

Add a brief note that arrival_time is from steady_clock (monotonic) and specify units (e.g., seconds as float) to prevent ambiguity.

cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp (1)

78-78: Make the constructor comment future-proof

Same rationale: the numeric parameter count is brittle.

-    // 50 parameters
+    // Parameters: keep in sync with tb::LlmRequest constructor order
tensorrt_llm/executor/executor.py (1)

127-128: Update generate_async docstring to include arrival_time

Add a short param line for arrival_time (steady_clock, float seconds, optional) to keep API docs accurate.

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp (2)

339-339: Update the parameter-count comment.

Comment still says “49 parameters” after adding arrival_time; make it “50” to avoid drift.

-                    // 49 parameters
+                    // 50 parameters

369-369: Nit: fix Python kwarg name typo.

max_endocer_input_len → max_encoder_input_len (public API kwarg). If backward-compat is a concern, consider supporting both temporarily.

-            py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_endocer_input_len") = std::nullopt,
+            py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_encoder_input_len") = std::nullopt,
cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h (3)

87-89: Be explicit about the TimePoint alias for clarity.

TimePoint comes from the Base; add a using to mirror other aliases and avoid lookup surprises.

 public:
     using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
+    using TimePoint = Base::TimePoint;

1-16: Copyright header year.

Files should reflect the current year (2025) per guidelines.

- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

18-18: Include guards preferred over pragma once (repo guideline).

Consider adding a proper include guard per TRTLLM naming (optional if repo intentionally standardized on pragma once).

-#pragma once
+#ifndef TRTLLM_LLMREQUEST_H
+#define TRTLLM_LLMREQUEST_H
 ...
+#endif // TRTLLM_LLMREQUEST_H
cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (1)

321-321: Update the parameter-count comment.

Reflect the new total.

-                // 49 parameters
+                // 50 parameters
cpp/include/tensorrt_llm/batch_manager/llmRequest.h (3)

311-315: Consider allowing executor::Request to carry arrival time.

This constructor defaults arrivalTime to now(). If some flows create GenericLlmRequest from executor::Request after queuing, you’ll still misreport arrival. If feasible, add an optional arrivalTime to executor::Request and propagate it here.

Would you like a follow-up patch sketch for executor::Request to carry an optional steady_clock::time_point?


1-16: License header year.

Update to include 2025 per repo guidelines.

- * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION.  All rights reserved.

17-17: Include guards vs pragma once.

Headers under cpp/include should prefer TRTLLM_* include guards.

-#pragma once
+#ifndef TRTLLM_BATCH_MANAGER_LLMREQUEST_H
+#define TRTLLM_BATCH_MANAGER_LLMREQUEST_H
 ...
+#endif // TRTLLM_BATCH_MANAGER_LLMREQUEST_H
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between ff37048 and 8cdcf8a.

📒 Files selected for processing (15)
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h (9 hunks)
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (4 hunks)
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp (2 hunks)
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h (3 hunks)
  • cpp/tensorrt_llm/nanobind/bindings.cpp (2 hunks)
  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp (3 hunks)
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp (2 hunks)
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h (3 hunks)
  • cpp/tensorrt_llm/pybind/bindings.cpp (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/llm_request.py (1 hunks)
  • tensorrt_llm/executor/executor.py (2 hunks)
  • tensorrt_llm/executor/request.py (2 hunks)
  • tensorrt_llm/executor/worker.py (1 hunks)
  • tensorrt_llm/llmapi/llm.py (3 hunks)
  • tensorrt_llm/serve/openai_server.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tensorrt_llm/serve/openai_server.py
🧰 Additional context used
📓 Path-based instructions (8)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/executor/worker.py
  • cpp/tensorrt_llm/nanobind/bindings.cpp
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • tensorrt_llm/llmapi/llm.py
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp
  • tensorrt_llm/executor/request.py
  • tensorrt_llm/executor/executor.py
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/executor/worker.py
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/llmapi/llm.py
  • tensorrt_llm/executor/request.py
  • tensorrt_llm/executor/executor.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/executor/worker.py
  • cpp/tensorrt_llm/nanobind/bindings.cpp
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • tensorrt_llm/llmapi/llm.py
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp
  • tensorrt_llm/executor/request.py
  • tensorrt_llm/executor/executor.py
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/nanobind/bindings.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/nanobind/bindings.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/nanobind/bindings.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
  • cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.

Files:

  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h
**/*.{h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).

Files:

  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
  • cpp/include/tensorrt_llm/batch_manager/llmRequest.h
🧠 Learnings (1)
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.

Applied to files:

  • cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h
  • cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (19)
cpp/tensorrt_llm/nanobind/bindings.cpp (1)

22-22: Chrono support for nanobind — looks good.

Required to pass std::chrono::time_point across the binding.

tensorrt_llm/executor/worker.py (1)

595-597: Remove IPC/pickling concern: arrival_time is a Python float and already serializes correctly
The arrival_time field on GenerationRequest and ExecutorRequest is defined as Optional[float] (seconds) and is copied to py_arrival_time as a native Python float—no std::chrono::time_point object is involved, so it’s pickle-friendly and handled numerically end-to-end.

Likely an incorrect or invalid review comment.

cpp/tensorrt_llm/pybind/bindings.cpp (1)

19-19: Chrono support for pybind — good alignment with nanobind.

Needed for time_point bindings.

cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h (1)

150-151: Arrival time propagation into Base looks correct

Parameter is appended at the end and forwarded without reordering. Matches the intended design.

tensorrt_llm/executor/request.py (1)

126-126: LGTM: request stores arrival_time

The field is optional and doesn’t affect existing flows.

cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp (1)

128-130: Arrival time forwarded to tb::LlmRequest — OK

Argument order and types look consistent; forwarding mPerfMetrics.timingMetrics.arrivalTime is correct.

tensorrt_llm/llmapi/llm.py (3)

352-354: Compute arrival_time via steady_clock when metrics enabled — good

Captures entry time in the API layer with a monotonic clock.


451-452: Propagate arrival_time to executor — good

Keeps the change localized and backward compatible.


21-21: Confirm steady_clock_now is exported in both binding variants
Both the nanobind and pybind11 bindings register the symbol via m.def("steady_clock_now", …) in their respective module definitions (NB_MODULE in cpp/tensorrt_llm/nanobind/bindings.cpp and PYBIND11_MODULE in cpp/tensorrt_llm/pybind/bindings.cpp), ensuring it’s available across all build variants.

tensorrt_llm/executor/executor.py (1)

151-153: LGTM: forward arrival_time into GenerationRequest

Placement and defaulting preserve existing behavior.

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp (1)

296-298: Remove unnecessary steady_clock helper suggestion
A steady_clock_now factory is already exposed in the root pybind module (cpp/tensorrt_llm/pybind/bindings.cpp:503), so callers can use tensorrt_llm.steady_clock_now() for arrival_time.

Likely an incorrect or invalid review comment.

cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h (2)

53-53: Keep the parameter-count comment in sync.

Now 50 parameters; comment updated correctly here. LGTM.


89-151: Constructor forwarding for arrivalTime looks correct.

Arrival time is appended after contextPhaseParams and forwarded to Base. LGTM.

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (2)

35-35: Good: chrono support for TimePoint.

Including nanobind/stl/chrono.h is required for TimePoint bindings. LGTM.


291-334: Arrival time threading through nanobind constructor looks correct.

Parameter added and forwarded to tb::LlmRequest; default exposed in kwargs. LGTM.

Also applies to: 358-358

cpp/include/tensorrt_llm/batch_manager/llmRequest.h (4)

103-104: TimePoint alias looks right.

Alias uses steady_clock; consistent with timeout and perf metrics. LGTM.


105-105: Keep parameter-count comments accurate.

Comment updated to 50; good.


141-143: API surface: new optional arrivalTime param.

Placement after contextPhaseParams is sensible and preserves ABI for existing call sites using defaults. LGTM.


200-204: Only set arrivalTime when metrics requested — OK.

Guarding by mReturnPerfMetrics prevents unnecessary work. LGTM.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17739 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17739 [ run ] completed with state DISABLED
L0 testing is limited to prioritized users. User zhengd-nv is not in the prioritized list. L0 testing cannot be triggered.

Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
@zhengd-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17984 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #17984 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13479 completed with status: 'FAILURE'

@zhengd-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18012 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18012 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13499 completed with status: 'FAILURE'

@zhengd-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18167 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18167 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13613 completed with status: 'FAILURE'

@pcastonguay
Copy link
Collaborator

/bot run --disable-fail-fast

@zhengd-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18276 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18276 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13700 completed with status: 'FAILURE'

@zhengd-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18299 [ run ] triggered by Bot

Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
@zhengd-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18541 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18541 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13918 completed with status: 'SUCCESS'

Copy link
Collaborator

@nv-guomingz nv-guomingz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for LLM API part.

Copy link
Collaborator

@Superjomn Superjomn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pcastonguay pcastonguay merged commit 24fc1f9 into NVIDIA:main Sep 15, 2025
5 checks passed
@zhengd-nv zhengd-nv deleted the arrival-time branch September 18, 2025 08:47
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…pytorch workflow (NVIDIA#7553)

Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
…pytorch workflow (NVIDIA#7553)

Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants