[#4593][feat] AutoDeploy: Linear Attention Support (SSM + causal_conv + Bamba + Nemotron-H) #8068
Conversation
* [None][auto_deploy] Bamba
  Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
* debugging export accuracy diff for bamba
  Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
---------
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* correct enable_block_reuse arg in serve
  Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* multi_modal_data handling in ADInputProcessor
  Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* correct rescale
  Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* correct handling of kvcache config in trtllm-serve
  Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
---------
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* Fix the bamba unit test
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* none: Add triton backend for ssm_transform and cuda backend for conv
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fully Use the TRT LLM kernels
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Add fake version for ssm transform op
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the datatype error in fake op
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the conv test error
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the triton ssm error
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
---------
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…es with better reset/sizing (#140) Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* Fix the bamba unit test
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* none: Add triton backend for ssm_transform and cuda backend for conv
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fully Use the TRT LLM kernels
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Add fake version for ssm transform op
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the datatype error in fake op
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the conv test error
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the triton ssm error
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the DemoLLM sampler mismatch
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Update the implementation for triton/cuda kernels
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Fix the d2d memcpy for decode
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
* Revert the generator and remove the redundant code
  Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
---------
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
* [None][feat] Add patches for NemotronH
  Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
* [None][test] unittest for nemotron_h
  Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
* nemotron-h support finished
  Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* added anticapted path for new models on llm_models trt-llm CI
  Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
---------
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
📝 Walkthrough

Adds slot_idx propagation throughout attention metadata and cached ops; introduces cached SSM (Mamba) and causal conv backends (Torch, Triton, CUDA); extends transform registry/config to initialize new caches; updates model patches (Bamba, Nemotron H), input processor for multi-modal data, logging/seed tweaks, and broad test coverage. Dynamic custom_ops discovery added.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant PrepareMetadata as PrepareMetadata (backend)
    participant CachedOp as Cached Op (SSM/Conv)
    participant Cache as Slot-State Cache
    rect rgba(230,245,255,0.5)
        note right of Caller: Invocation (flattened or generate-only)
        Caller->>PrepareMetadata: (input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size)
        PrepareMetadata-->>Caller: (seq_len_s, seq_start, slot_idx_s)
    end
    alt Generate-only
        Caller->>Cache: Gather states by slot_idx_s
        Caller->>CachedOp: (..., seq_len_s, seq_start, slot_idx_s, Cache, constants)
        CachedOp->>Cache: Update states in-place per slot
        CachedOp-->>Caller: y
    else Flattened context
        loop Per sequence segment
            Caller->>CachedOp: Prefill on segment
            CachedOp->>Cache: Writeback final state per slot
        end
        CachedOp-->>Caller: y (assembled)
    end
```
```mermaid
sequenceDiagram
    autonumber
    participant Graph as Exported Graph Node
    participant Descriptor as AttentionDescriptor
    participant Init as Cache Initializers
    participant Prepare as PrepareMetadata
    participant Exec as Cached Execution
    participant Buffers as Global Buffers
    Graph->>Descriptor: get_source_attention_op()
    Graph->>Descriptor: get_cached_attention_op()
    Graph->>Descriptor: get_prepare_metadata_op()
    Graph->>Descriptor: get_cache_initializers()
    Descriptor-->>Init: initializers (e.g., ssm_state_cache / conv_state_cache)
    Graph->>Init: allocate caches
    Graph->>Descriptor: get_constants()
    Descriptor-->>Buffers: constants/global buffers
    Graph->>Prepare: prepare(..., slot_idx, ...)
    Prepare-->>Graph: metadata
    Graph->>Exec: call cached op with metadata, caches, constants
    Exec-->>Graph: outputs
```
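To make the flow above concrete, here is a small, self-contained PyTorch sketch of the two steps the diagrams describe: metadata sanitization followed by a slot-indexed, in-place cache update. The function names and the toy state update are illustrative only, not the actual custom-op signatures introduced in this PR.

```python
import torch

def prepare_metadata_sketch(seq_len: torch.Tensor, slot_idx: torch.Tensor):
    """Keep only active sequences (seq_len > 0) and derive flattened-token offsets."""
    num_seq = int((seq_len > 0).sum().item())
    seq_len_s = seq_len[:num_seq]
    seq_start = torch.zeros_like(seq_len_s)
    seq_start[1:] = torch.cumsum(seq_len_s[:-1], dim=0)  # start offset of each sequence
    return seq_len_s, seq_start, slot_idx[:num_seq]

def cached_decode_sketch(x: torch.Tensor, state_cache: torch.Tensor, slot_idx_s: torch.Tensor):
    """Generate-only path: gather per-slot state, apply a placeholder update, write back in place."""
    state = state_cache[slot_idx_s]        # [num_seq, ...] gathered by slot
    new_state = 0.5 * state + 1.0          # stand-in for the real SSM / causal-conv state update
    state_cache[slot_idx_s] = new_state    # in-place writeback keyed by slot_idx
    return x                               # the real op also returns the layer output y

# Example: 2 active sequences out of a max batch of 4, mapped to cache slots 3 and 1.
seq_len = torch.tensor([5, 2, 0, 0])
slot_idx = torch.tensor([3, 1, 0, 0])
seq_len_s, seq_start, slot_idx_s = prepare_metadata_sketch(seq_len, slot_idx)
state_cache = torch.zeros(8, 16)           # 8 slots, toy per-slot state of size 16
_ = cached_decode_sketch(torch.randn(2, 16), state_cache, slot_idx_s)
```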
Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120+ minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip

👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example:

```yaml
reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).
```
Actionable comments posted: 30
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (22)
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (1)
1-1: Add NVIDIA Apache-2.0 header (2025) at top of file.Required by repo guidelines; place above the module docstring.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Testing build_and_run_ad end2end. """tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (3)
1-1: Add mandatory NVIDIA Apache-2.0 header (2025).File lacks the required license header per coding guidelines.
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Graph transformation to automatically add kv cache into fused MHA op."""As per coding guidelines.
271-274: Fix units in log message.“Current num pages (MB)” is a count, not MB.
- self._log_info( - f"Current cache size (MB): {current_cache_size // 1024 // 1024}, " - f"Current num pages (MB): {current_num_pages}" - ) + self._log_info( + f"Current cache size (MB): {current_cache_size // 1024 // 1024}, " + f"Current num pages: {current_num_pages}" + )
298-300: Guard against divide-by-zero and use integer math in resize_kv_cache- new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size - new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages)) + if current_num_pages <= 0 or current_cache_size <= 0: + self._log_info( + "Skipping cache resize: cache not initialized (pages or size is zero). " + "Ensure 'initialize_cache' runs before 'resize_kv_cache'." + ) + return gm, TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) + + new_free_bytes = int(free_mem_post * 1024 * 1024 * free_mem_ratio) + new_cache_size = new_free_bytes + current_cache_size + bytes_per_page = current_cache_size // current_num_pages + new_num_pages = new_cache_size // max(1, bytes_per_page)tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (3)
1-1: Add required NVIDIA Apache-2.0 header.This file is missing the mandatory copyright header.
As per coding guidelines.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
67-68: Logging format bug will raise at runtime. Passing args without a format specifier triggers logging's percent-format error.

```diff
- ad_logger.info("Using fake cache manager with head_dim=0 and num pages:", self.num_blocks)
+ ad_logger.info("Using fake cache manager with head_dim=0 and num pages: %s", self.num_blocks)
```
238-244: Add slot_idx to all nest_sequences calls
- shim/demollm.py (lines 113, 145)
- shim/ad_executor.py already updated
- custom_ops/attention_interface.py (around line 534)
- tests/unit/singlegpu/shim/test_engine.py (lines 78, 112)
- tests/unit/singlegpu/transformations/library/test_kv_cache.py (line 188)
- tests/_utils_test/_graph_test_helpers.py: ensure super().nest_sequences(..., *args, **kwargs) will receive slot_idx when overridden.
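Illustrative only: a minimal sketch of the call-site shape being requested, assuming nest_sequences accepts slot_idx as a keyword argument (as in the updated executor); the helper name and the 1:1 slot assignment are placeholders, not the exact code in the files listed above.

```python
from typing import Any, List, Sequence

def nest_with_slot_idx(
    sequence_info: Any,
    input_ids: Sequence[Sequence[int]],
    input_pos: List[int],
    page_assignments: List[List[int]],
) -> None:
    """Pass one cache slot per sequence alongside the existing arguments."""
    slot_idx = list(range(len(input_ids)))  # hypothetical 1:1 slot assignment
    sequence_info.nest_sequences(
        input_ids,
        input_pos=input_pos,
        page_assignments=page_assignments,
        slot_idx=slot_idx,
    )
```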
tensorrt_llm/_torch/auto_deploy/shim/demollm.py (3)

77-83: Guard against exhausting free_pages and make allocation deterministic. Current set.pop() can raise KeyError when not enough pages remain; set iteration also makes assignments nondeterministic.
Apply:
- free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages} - updated_assignments = [] - for t_l, pages in zip(total_lens, page_assignments): - extra_tokens = t_l - len(pages) * si.page_size - num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0) - updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)]) + allocated = {i for pages in page_assignments for i in pages} + free_pages = sorted(i for i in range(si.num_pages) if i not in allocated) + updated_assignments = [] + for t_l, pages in zip(total_lens, page_assignments): + extra_tokens = t_l - len(pages) * si.page_size + num_extra_pages = (extra_tokens + si.page_size - 1) // si.page_size if extra_tokens > 0 else 0 + if num_extra_pages > len(free_pages): + raise RuntimeError( + f"Insufficient free pages: need {num_extra_pages}, have {len(free_pages)}; " + f"consider increasing num_pages or reducing batch/sequence length." + ) + new_pages = [free_pages.pop() for _ in range(num_extra_pages)] # deterministic (from end) + updated_assignments.append(pages + new_pages)
1-3: Add NVIDIA Apache-2.0 header (2025).File is missing the required license header.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.As per coding guidelines.
112-119: Persist slot_idx across nest_sequences and add Apache-2.0 license header
- Add the SPDX Apache-2.0 header at the top of tensorrt_llm/_torch/auto_deploy/shim/demollm.py.
- Before the first
sequence_info.nest_sequences, computeand replace both calls toslot_idx = list(range(len(input_ids)))and the second callsequence_info.nest_sequences(..., slot_idx=list(range(len(input_ids))), ...)withsequence_info.nest_sequences(token_ids, input_pos=..., page_assignments=...)sequence_info.nest_sequences(..., slot_idx=slot_idx, ...) sequence_info.nest_sequences(token_ids, input_pos=..., page_assignments=..., slot_idx=slot_idx)- If downstream backends expect a tensor, switch to
slot_idx = torch.arange(len(input_ids), dtype=torch.int32, device=...)tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)
118-126: Critical: stub must register parent packages for import to succeedRegistering only mamba_ssm.ops.triton.layernorm_gated is insufficient; Python first imports parent packages. Create the full module chain and mark packages with path.
_mamba_ssm_module = "mamba_ssm" _mamba_ssm_submodule = f"{_mamba_ssm_module}.ops.triton.layernorm_gated" if importlib.util.find_spec(_mamba_ssm_module) is None: - stub_mod = types.ModuleType(_mamba_ssm_submodule) - stub_mod.rmsnorm_fn = _rms_norm_ref - sys.modules[_mamba_ssm_submodule] = stub_mod + # Build package chain: mamba_ssm, mamba_ssm.ops, mamba_ssm.ops.triton, and the target module. + pkg = types.ModuleType("mamba_ssm"); pkg.__path__ = [] # type: ignore[attr-defined] + ops = types.ModuleType("mamba_ssm.ops"); ops.__path__ = [] # type: ignore[attr-defined] + triton = types.ModuleType("mamba_ssm.ops.triton"); triton.__path__ = [] # type: ignore[attr-defined] + layernorm = types.ModuleType(_mamba_ssm_submodule) + layernorm.rmsnorm_fn = _rms_norm_ref + sys.modules["mamba_ssm"] = pkg + sys.modules["mamba_ssm.ops"] = ops + sys.modules["mamba_ssm.ops.triton"] = triton + sys.modules[_mamba_ssm_submodule] = layernormtests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py (4)
60-61: Tests unconditionally select CUDA device — add skip-if-no-CUDA.These are “singlegpu” tests; hard-coding CUDA without a skip will fail on CPU-only runners.
@@ -from typing import Type, Union +from typing import Type, Union @@ -import pytest +import pytest @@ +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") @@ - device = torch.device("cuda") + device = torch.device("cuda") @@ - device = torch.device("cuda") + device = torch.device("cuda")Also applies to: 96-97
37-46: Reference-vs-engine logits compare uses different model instances — make them the same.You create two separately initialized models without reseeding, so RNG state diverges and logits may mismatch. Reuse a single model via a factory closure and fix the assertion message.
@@ -def get_inference_model(cache_seq_interface): - vocab_size = 128 - embed_dim = 32 - hidden_dim = 64 - device = "cuda" - - model = TransformerLikeModelwithFakeCachePool(vocab_size, embed_dim, hidden_dim) - model.eval().to(device) - return model +def make_get_inference_model(device: torch.device): + vocab_size = 128 + embed_dim = 32 + hidden_dim = 64 + model = TransformerLikeModelwithFakeCachePool(vocab_size, embed_dim, hidden_dim).eval().to(device) + def _get(_cache_seq_interface): + return model + return _get @@ - engine = engine_cls(get_inference_model, sequence_info, device) + get_model = make_get_inference_model(device) + engine = engine_cls(get_model, sequence_info, device) @@ - mock_input = None - original_logits = get_inference_model(mock_input)(input_ids[0].unsqueeze(0))[0] - assert torch.allclose(logits, original_logits, atol=1e-5), "Generated Token ID mismatch" + original_logits = get_model(None)(input_ids[0].unsqueeze(0))[0] + assert torch.allclose(logits, original_logits, atol=1e-5), "Logits mismatch with reference model"Also applies to: 71-72, 83-86
1-6: Add NVIDIA Apache-2.0 header (2025).Required by repo guidelines.
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
48-51: Wire throughattn_backendintest_engineinstantiation
Intests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py, the callengine = engine_cls(get_inference_model, sequence_info, device)must include
attn_backend=attn_backend(or otherwise configure the backend per test) so each parametrized backend actually exercises a distinct code path.tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py (2)
118-127: Fix broken calls to Triton helpers: missing scale argument (runtime bug).triton_attention._generate_mha/_flattened_context_mha now require scale. Current calls pass y in the scale slot.
Apply:
+import math @@ - _generate_mha( - query_states.contiguous(), - key_states.contiguous(), - value_states.contiguous(), - k_cache, - v_cache, - cache_loc, - input_pos, - y, - ) + qk_head_dim_total = qk_nope_head_dim + qk_rope_head_dim + scale = (1.0 / math.sqrt(qk_head_dim_total)) if softmax_scale is None else softmax_scale + _generate_mha( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + k_cache, + v_cache, + cache_loc, + input_pos, + scale, + y, + ) @@ - _flattened_context_mha( - query_states.contiguous(), - key_states.contiguous(), - value_states.contiguous(), - input_pos, - cache_loc, - k_cache, - v_cache, - seq_len, - seq_start, - y, - ) + qk_head_dim_total = qk_nope_head_dim + qk_rope_head_dim + scale = (1.0 / math.sqrt(qk_head_dim_total)) if softmax_scale is None else softmax_scale + _flattened_context_mha( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + input_pos, + cache_loc, + k_cache, + v_cache, + seq_len, + seq_start, + scale, + y, + )This restores argument order and correct scaling.
Also applies to: 131-142
198-201: register_fake signature mismatch: slot_idx missing (breaks fake registration).Fake variant must mirror the real op signature.
Apply:
-@prepare_fused_mla_metadata.register_fake -def prepare_fused_mla_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size -): +@prepare_fused_mla_metadata.register_fake +def prepare_fused_mla_metadata_fake( + input_ids, + position_ids, + seq_len, + input_pos, + cache_loc, + pages_per_seq, + slot_idx, + page_size, +): - return ( - torch.empty_like(seq_len), - torch.empty_like(input_pos), - torch.empty_like(cache_loc), - torch.empty_like(seq_len), - ) + num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + return ( + torch.empty_like(seq_len[:num_seq]), + torch.empty_like(input_pos[:num_seq]), + torch.empty_like(cache_loc[:num_seq]), + torch.empty_like(seq_len[:num_seq]), + )Aligns with triton_attention fake and avoids shape inconsistencies.
tensorrt_llm/commands/serve.py (1)
1-1: Missing NVIDIA Apache-2.0 headerAdd the 2025 NVIDIA Apache-2.0 header per repo guidelines.
Apply at file top:
+# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py (1)
1-1: Missing NVIDIA Apache-2.0 headerAdd required header.
+# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1)
1-1: Missing NVIDIA Apache-2.0 headerAdd required header.
+# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1)
1-1: Missing NVIDIA Apache-2.0 header (2025).Per repo guidelines, prepend the standard NVIDIA Apache-2.0 copyright header to all .py files.
Apply at file top.
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
1-10: Add NVIDIA Apache-2.0 header (2025).Same header block as noted for other files, inserted above the module docstring. As per coding guidelines.
🧹 Nitpick comments (72)
tensorrt_llm/_torch/auto_deploy/llm_args.py (2)
197-203: Align field description with enforced behavior.Clarify defaulting vs. requirement for triton/torch backends.
Apply:
- description="Page size for attention (tokens_per_block). For triton and torch " - "backends, this should equal max_seq_len. Temporary field until tokens_per_block gets " - "properly passed through.", + description=( + "Page size for attention (tokens_per_block). For triton and torch backends, " + "this must equal max_seq_len. If unset, it defaults to max_seq_len. Temporary " + "field until tokens_per_block is plumbed end-to-end." + ),
206-213: Gate override ofattn_page_sizeand error on mismatchOnly force-set
attn_page_sizetomax_seq_lenfor"triton"/"torch"when not explicitly provided; otherwise raise if it doesn’t match:tensorrt_llm/_torch/auto_deploy/llm_args.py @model_validator(mode="after") # TODO: discuss what to do with this once we fully transition to the new inference optimizer def update_attn_page_size(self): - # NOTE force attn_page_size to equal max_seq_len for triton backend - # TODO: maybe don't do this and rely on slot_idx instead?? - if self.attn_backend == "triton" or self.attn_backend == "torch": - self.attn_page_size = self.max_seq_len + # Keep attn_page_size == max_seq_len for backends that require fixed tokens_per_block, + # but do not silently override explicit user input. + if self.attn_backend in ("triton", "torch"): + if "attn_page_size" not in self.model_fields_set: + self.attn_page_size = self.max_seq_len + elif self.attn_page_size != self.max_seq_len: + raise ValueError( + "attn_page_size must equal max_seq_len for attn_backend in {'triton','torch'}." + ) return selfAll existing tests (including default flashinfer/triton logic) remain green under this change.
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (3)
52-53: free_mem_ratio: clarify intent and use 1e-4 literal.Tiny ratio is fine for CI; add a brief comment and prefer scientific literal for readability.
compile_backend="torch-opt", - free_mem_ratio=0.0001, + # Keep memory footprint minimal for CI to avoid OOM on single-GPU runners. + free_mem_ratio=1e-4,
98-102: Guard Nemotron-Nano (flashinfer) with a skip-if when flashinfer is unavailable.Prevents spurious CI failures on environments without flashinfer.
get_small_model_config_pytest_param( "nvidia/NVIDIA-Nemotron-Nano-12B-v2", + pytest_param_kwargs={ + "marks": pytest.mark.skipif( + not HAS_FLASHINFER, reason="flashinfer not available" + ) + }, attn_backend="flashinfer", compile_backend="torch-simple", ),Add the following helper near the imports (once per file):
# at top-level, after imports try: import flashinfer # noqa: F401 HAS_FLASHINFER = True except Exception: HAS_FLASHINFER = FalseConsider applying the same skip-if to other flashinfer-backed params in this file for consistency.
Based on learnings
106-112: Clarify operator precedence and skip semantics (transformers-only vs always).Current expression relies on
and>orand yields:
- Always skip DeepSeek-V3 and Phi-3 (both modes).
- Skip Nemotron-Nano only in transformers mode.
If that’s intended, make it explicit and fix the reason string; else, applymode == "transformers"to all three.Recommended (mode-gated for all three):
- if ( - "DeepSeek-V3" in experiment_config["args"]["model"] - or "Phi-3-mini-4k-instruct" in experiment_config["args"]["model"] - or "NVIDIA-Nemotron-Nano-12B-v2" in experiment_config["args"]["model"] - and mode == "transformers" - ): + model_id = experiment_config["args"]["model"] + if mode == "transformers" and any( + name in model_id + for name in ("DeepSeek-V3", "Phi-3-mini-4k-instruct", "NVIDIA-Nemotron-Nano-12B-v2") + ): pytest.skip(f"{experiment_config['args']['model']} is not supported in transformers mode")Alternative (keep current semantics but avoid precedence pitfalls and fix reason when mode != "transformers"):
- if ( - "DeepSeek-V3" in experiment_config["args"]["model"] - or "Phi-3-mini-4k-instruct" in experiment_config["args"]["model"] - or "NVIDIA-Nemotron-Nano-12B-v2" in experiment_config["args"]["model"] - and mode == "transformers" - ): - pytest.skip(f"{experiment_config['args']['model']} is not supported in transformers mode") + model_id = experiment_config["args"]["model"] + if ( + "DeepSeek-V3" in model_id + or "Phi-3-mini-4k-instruct" in model_id + or ("NVIDIA-Nemotron-Nano-12B-v2" in model_id and mode == "transformers") + ): + reason = "transformers mode" if mode == "transformers" else "this test configuration" + pytest.skip(f"{model_id} is not supported in {reason}")Please confirm the intended behavior.
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (2)
286-294: Run forward in inference_mode and sync to stabilize memory readings.Prevents grad allocations and ensures mem_get_info sees settled allocations.
- self._run_forward(gm, cm) + with torch.inference_mode(): + self._run_forward(gm, cm) + torch.cuda.synchronize()
10-12: Good switch to instance-scoped logging; consider logging final size only on rank 0.Prevents N-way duplicated logs in multi-GPU runs.
-from ...distributed.common import all_gather_object, get_world_size +from ...distributed.common import all_gather_object, get_world_size, get_rank @@ - self._log_info( - f"Final KV cache size after resize: {final_cache_size_gb:.2f} GiB ({new_num_pages} pages)" - ) + if get_rank() == 0: + self._log_info( + f"Final KV cache size after resize: {final_cache_size_gb:.2f} GiB ({new_num_pages} pages)" + )Also applies to: 312-314
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (4)
161-161: Avoid fixed global seed; make per-rank and seed CUDA too.A fixed global seed can produce identical sampling across ranks and ignores CUDA RNG.
- torch.manual_seed(42) + # Per-rank seed; also seed CUDA RNGs + _seed = 42 + mpi_rank() + torch.manual_seed(_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(_seed)Would you prefer making the base seed configurable via LlmArgs/AutoDeployConfig?
185-187: slot_idx collection looks good; add a quick invariant check before nesting.Keeps inputs aligned and catches scheduler mismatches early.
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list) @@ # update the sequence info object now + # Sanity: ensure alignment of per-sequence lists + assert len(slot_idx) == len(input_ids) == len(input_pos) == len(page_assignments), ( + f"len mismatch: slot_idx={len(slot_idx)}, input_ids={len(input_ids)}, " + f"input_pos={len(input_pos)}, pages={len(page_assignments)}" + )
204-206: Guard seq_slot presence for context requests.Make the contract explicit; avoids None/invalid types propagating into nesting.
- slot_idx.append(request.seq_slot) + assert request.seq_slot is not None, "seq_slot must be set for context requests" + slot_idx.append(int(request.seq_slot))
226-229: Clarify overlap-scheduler semantics for gen path; assert seq_slot.Given the TODO, add a minimal guard now; verify that slot_idx should reflect the current (post-schedule) slot.
- # TODO: double-check if this is correct for the overlap scheduler - slot_idx.append(request.seq_slot) + # TODO: confirm correctness for the overlap scheduler; slot_idx should map to current slot + assert request.seq_slot is not None, "seq_slot must be set for generation requests" + slot_idx.append(int(request.seq_slot))Would you like a focused unit test covering overlap scheduling with rescatter?
tensorrt_llm/_torch/auto_deploy/shim/demollm.py (1)
145-149: Pass slot_idx on subsequent nest_sequences calls during decode. Maintain the same slot mapping across steps to match cache/state expectations.

Apply:

```diff
-        sequence_info.nest_sequences(
+        sequence_info.nest_sequences(
             token_ids,
             input_pos=input_pos_next,
-            page_assignments=self._assign_pages(total_lens_next),
+            page_assignments=self._assign_pages(total_lens_next),
+            slot_idx=slot_idx,
         )
```

tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (6)
7-10: Avoid hard dependency on einops (optional)_rms_norm_ref only needs simple reshapes; consider replacing einops with view/reshape to drop an extra dependency or lazy-import it inside the fallback path.
21-45: RMSNorm: add shape checks and prefer rsqrt; optional: drop einops
- Guard group_size divisibility and optional z shape.
- Use torch.rsqrt for better numerics.
def _rms_norm_ref( x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True ): - dtype = x.dtype - # N = x.shape[-1] + dtype = x.dtype + hidden = x.shape[-1] weight = weight.float() bias = bias.float() if bias is not None else None if upcast: x = x.float() z = z.float() if z is not None else z if z is not None and not norm_before_gate: x = x * F.silu(z) if group_size is None: - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + rstd = torch.rsqrt((x.square()).mean(dim=-1, keepdim=True) + eps) out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) else: - x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) - rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) - out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + if hidden % group_size != 0: + raise ValueError(f"group_size ({group_size}) must divide last dim ({hidden})") + # If einops stays: keep rearrange; otherwise use reshape: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = torch.rsqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight if bias is not None: out = out + bias if z is not None and norm_before_gate: out *= F.silu(z) return out.to(dtype)
49-56: Silence unused-arg warnings on stubbed mask updatersThese are intentionally no-ops; add noqa or consume args to appease linters.
-def _nemotron_h_model_update_mamba_mask(self, attention_mask, cache_position): +def _nemotron_h_model_update_mamba_mask(self, attention_mask, cache_position): # noqa: ARG001 return None @@ -def _nemotron_h_model_update_causal_mask(self, attention_mask, input_tensor, cache_position): +def _nemotron_h_model_update_causal_mask(self, attention_mask, input_tensor, cache_position): # noqa: ARG001 # Force attention to use causal mode without explicit masks return None
103-111: Minor: avoid .keys() and micro‑cleanupUse membership directly; also consider guarding repeat patching.
- for _, module in model.named_modules(): - if (module_name := type(module).__name__) in CUSTOM_MODULE_PATCHES.keys(): + for _, module in model.named_modules(): + if (module_name := type(module).__name__) in CUSTOM_MODULE_PATCHES: patches = CUSTOM_MODULE_PATCHES[module_name] for method_name, method_patch in patches: - setattr(module, method_name, types.MethodType(method_patch, module)) + setattr(module, method_name, types.MethodType(method_patch, module))
115-117: Global monkey‑patch risk: confine scopeOverriding transformers.AutoModelForCausalLM.from_config globally can affect unrelated code in‑process and tests. Gate behind an env flag or provide a scoped context manager.
-AutoModelForCausalLM.from_config = get_model_from_config_patched +import os +if os.getenv("TLLM_ENABLE_NEMOTRON_PATCH", "1") == "1": + AutoModelForCausalLM.from_config = get_model_from_config_patched
80-85: Optional: use a dedicated exception typeTRY003 flagged the ValueError message; if this is a hot path, consider a small custom exception with a fixed message to satisfy the rule. Otherwise, ignore.
tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py (3)
55-59: Seed change is fine; centralize seeding for full reproducibility.Use a small helper to seed PyTorch (and CUDA) consistently; call it in both tests to avoid drift.
@@ - seed = 42 # Set random seed for model param init - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + seed = 42 # Set random seed for model param init + set_seed(seed)Add once near imports:
@@ +def set_seed(seed: int) -> None: + import random + random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed)
14-25: Nit: class name casing.Consider TransformerLikeModelWithFakeCachePool for readability.
88-95: Minor: unify seeding and keep sampling deterministic checks minimal.Use set_seed(seed) for this test too; current assertions are shape/range only, so no need to force global determinism beyond that.
@@ - seed = 0 - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + set_seed(0)Also applies to: 116-137
tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py (4)
22-28: Silence intentional unused-arg warnings without changing signatures.These functions must keep their signatures for compatibility, but several parameters are intentionally unused.
Apply inline noqa markers on the def lines:
-def _bamba_mixer_torch_forward( +def _bamba_mixer_torch_forward( # noqa: ARG001 self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -def _bamba_model_update_mamba_mask(self, attention_mask, cache_position): +def _bamba_model_update_mamba_mask(self, attention_mask, cache_position): # noqa: ARG001 return None @@ -def _bamba_model_update_causal_mask( +def _bamba_model_update_causal_mask( # noqa: ARG001 self, attention_mask, input_tensor, cache_position, past_key_values, output_attentions, ): @@ -def _cache_bool(self) -> bool: +def _cache_bool(self) -> bool: # noqa: ARG001 return TrueBased on coding guidelines.
Also applies to: 149-151, 153-163, 179-187
46-51: Ensure metadata dtypes/shapes match custom op contracts; cast explicitly.seq_len/seq_start are created with torch.int (int32) while slot_idx is int64. Please confirm the op signatures; mis‑typed metadata will silently upcast on CPU but can fail AOT/export or CUDA kernels.
Proposed explicit casts (adjust if your op expects different types):
- seq_len_t = torch.full((batch_size,), seq_len, device=input_states.device, dtype=torch.int) - seq_start_t = torch.arange( - 0, batch_size * seq_len, seq_len, device=input_states.device, dtype=torch.int - ) - slot_idx_t = torch.arange(batch_size, device=input_states.device, dtype=torch.long) + seq_len_t = torch.full( + (batch_size,), seq_len, device=input_states.device, dtype=torch.int32 + ) + seq_start_t = torch.arange( + 0, batch_size * seq_len, seq_len, device=input_states.device, dtype=torch.int32 + ) + slot_idx_t = torch.arange( + batch_size, device=input_states.device, dtype=torch.int64 + )Also, if time_step_limit can be None, guard its conversion to list:
- time_step_limit=list(self.time_step_limit), + time_step_limit=(list(self.time_step_limit) if self.time_step_limit is not None else []),Please verify against the C++/Triton op signatures and adjust accordingly.
Also applies to: 52-72, 101-121
135-136: Prefer reshape over view after custom ops.y may be non‑contiguous; view can error. reshape is safer.
- y = y.view(batch_size, seq_len, -1) + y = y.reshape(batch_size, seq_len, -1)
7-15: Optional: avoid importing concrete classes directly from deep paths.To reduce fragility against upstream refactors, consider importing the module and referencing names off it.
Example:
-from transformers.models.bamba.modeling_bamba import ( - BambaMixer, - BambaModel, - BambaPreTrainedModel, - HybridMambaAttentionDynamicCache, - apply_mask_to_padding_states, -) +from transformers.models.bamba import modeling_bamba as bamba_mod +from transformers.models.bamba.modeling_bamba import apply_mask_to_padding_states +# Usage: bamba_mod.BambaMixer, bamba_mod.BambaModel, ...As per coding guidelines.
tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py (2)
1-1: Add NVIDIA Apache-2.0 header (2025) at top of file.Repository guideline requires header on all .py files.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.As per coding guidelines.
184-186: Silence unused slot_idx parameter (Ruff ARG001).Parameter is intentionally unused; rename to underscore to avoid lint noise.
- slot_idx: torch.Tensor, + _slot_idx: torch.Tensor,tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (2)
1-1: Add NVIDIA Apache-2.0 header.+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); see LICENSE file for details.As per coding guidelines.
8-10: Rename unused loop variable to underscore.Small lint cleanup; functional no-op.
-for _, module_name, is_pkg in pkgutil.iter_modules(__path__): +for _, module_name, _is_pkg in pkgutil.iter_modules(__path__):tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py (2)
1-1: Add NVIDIA Apache-2.0 header.+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); see LICENSE file for details.As per coding guidelines.
293-295: Silence unused metadata params (Ruff ARG001).slot_idx/page_size/position_ids/pages_per_seq are intentionally unused here.
-def prepare_fused_mha_metadata( +def prepare_fused_mha_metadata( @@ - slot_idx: torch.Tensor, - page_size: int, + _slot_idx: torch.Tensor, + _page_size: int, @@ -def prepare_fused_mha_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size +def prepare_fused_mha_metadata_fake( + input_ids, _position_ids, seq_len, input_pos, cache_loc, _pages_per_seq, _slot_idx, _page_size ):Also applies to: 312-313
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py (3)
1-1: Add NVIDIA Apache-2.0 header.+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); see LICENSE file for details.As per coding guidelines.
12-12: Remove unused noqa; keep side-effect import comment.-import tensorrt_llm._torch.auto_deploy # noqa: F401 +import tensorrt_llm._torch.auto_deploy # side-effect import to register custom ops
88-88: Drop unused E501 noquas.Ruff flagged non-enabled noqa directives.
- tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_decode( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_decode( # type: ignore @@ - tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_prefill( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_prefill( # type: ignoreAlso applies to: 151-151
tensorrt_llm/commands/serve.py (1)
37-71: Signal handler performs non–async-signal-safe work (logging, waits) and exits the parent. Logging and multi-step process control inside a signal handler can be unsafe; also the unconditional sys.exit may tear down the parent in contexts where cleanup should happen in finally.
Minimize the handler to set a flag and send a signal to child; move waits/logging to normal control flow (e.g., check flag around blocking sections). At minimum, replace logger calls with os.write to stderr or defer logs.
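A minimal sketch of that shape (set a flag, forward the signal, defer waits and logging to normal control flow); _child_pid and the message text are placeholders rather than the actual variables in serve.py.

```python
import os
import signal
import sys
from typing import Optional

_shutdown_requested = False
_child_pid: Optional[int] = None  # set after the child process is spawned

def _handle_sigterm(signum, frame) -> None:
    """Do as little as possible here: flag, forward, and return."""
    global _shutdown_requested
    _shutdown_requested = True
    if _child_pid is not None:
        os.kill(_child_pid, signal.SIGTERM)  # forward to the child; no waiting here
    os.write(sys.stderr.fileno(), b"SIGTERM received; shutting down...\n")  # async-signal-safe write

signal.signal(signal.SIGTERM, _handle_sigterm)

# The main loop periodically checks _shutdown_requested and performs the
# waitpid / cleanup / logging that was moved out of the handler.
```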
tensorrt_llm/_torch/auto_deploy/llm.py (1)
46-48: Over‑strict assert on “multi_modal_data” with “messages”Asserting absence can break valid callers; prefer ignoring or merging multimodal payload.
Replace assert with a guard/warning or merge behavior. Example:
- assert "multi_modal_data" not in inputs, f"unexpected multi_modal_data key in {inputs=}" + # If present, multi_modal_data will be handled by processor; ignore here.tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (2)
365-372: slot_idx accepted but unused — silence lints and documentRuff flags unused args. Keep signature but explicitly discard to avoid warnings.
def torch_backend_prepare_metadata( @@ - slot_idx: torch.Tensor, + slot_idx: torch.Tensor, page_size: int, ) -> List[torch.Tensor]: """Prepare metadata for torch backend attention (similar to triton backend).""" + # Note: slot_idx and page_size are reserved for future use in torch backend. + # Keep signature stable; explicitly mark as unused. + _ = (slot_idx, page_size)
380-390: Fake variant: discard unused parameters to appease lintersdef torch_backend_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): + _ = (position_ids, pages_per_seq, slot_idx, page_size) num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (4)
156-166: slot_idx unused in prepare_flashinfer_metadata; keep signature but silence linters.The backend doesn't need slot_idx (uniform signature across ops), but Ruff flags it.
Use a throwaway name to document intent:
- pages_per_seq: torch.Tensor, - slot_idx: torch.Tensor, + pages_per_seq: torch.Tensor, + _slot_idx: torch.Tensor,
215-228: Fake variant: multiple unused args; rename to underscore to avoid ARG001.No functional change; keeps signatures aligned.
-@prepare_flashinfer_metadata.register_fake -def prepare_flashinfer_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size -): +@prepare_flashinfer_metadata.register_fake +def prepare_flashinfer_metadata_fake( + input_ids, _position_ids, seq_len, _input_pos, cache_loc, _pages_per_seq, _slot_idx, _page_size +):
418-421: Scale type guard is too strict; accept ints/real numbers.Some callers may pass an int or NumPy scalar. Treat any real number as valid.
-from typing import ... +import numbers ... - if not (isinstance(scale, float) or scale is None): + if not (isinstance(scale, numbers.Real) or scale is None): ad_logger.warning(f"Provided {scale=}, is not a float. Using default scale instead.") scale = None
278-283: Shadowing function args k_scale/v_scale; use locals to avoid confusion.You immediately overwrite k_scale/v_scale, then pass them to wrapper.run. Use local names to make intent explicit and prevent accidental reliance on call-site constants.
- # Assuming k_scale = v_scale = 1.0, we just have to cast k and v to fp8 before appending to kv cache - k_scale, v_scale = 1.0, 1.0 + # Assuming fixed scales of 1.0 for FP8 cache append + k_scale_local, v_scale_local = 1.0, 1.0 if k_cache.dtype == torch.float8_e4m3fn: - k = (k / k_scale).to(torch.float8_e4m3fn) - v = (v / v_scale).to(torch.float8_e4m3fn) + k = (k / k_scale_local).to(torch.float8_e4m3fn) + v = (v / v_scale_local).to(torch.float8_e4m3fn) ... - y = wrapper.run(q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale) + y = wrapper.run(q, (k_cache, v_cache), k_scale=k_scale_local, v_scale=v_scale_local)If dynamic scaling becomes needed, consider plumbing per‑tensor scales and removing these constants from get_constants.
Also applies to: 303-305
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py (2)
12-12: Remove unused noqa directive.Ruff flags RUF100 since F401 isn’t enabled. Keep the import (side-effect registration), but drop the noqa.
-import tensorrt_llm._torch.auto_deploy # noqa: F401 +import tensorrt_llm._torch.auto_deployOptionally add a clarifying comment: “import for side-effects: operator registration.”
153-156: Remove unused noqa E501.Ruff indicates the directive is unused; line fits current config.
- tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_prefill( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_prefill( # type: ignoretests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py (2)
4-4: Remove unused noqa directive.Keep the import for side-effects; drop
# noqa: F401per RUF100.-import tensorrt_llm._torch.auto_deploy # noqa: F401 +import tensorrt_llm._torch.auto_deploy
166-170: Rename unused loop variable to underscore.Silences Ruff B007 and clarifies intent.
- for i, ln in enumerate(lens): + for i, _ln in enumerate(lens):tensorrt_llm/_torch/auto_deploy/custom_ops/torch_causal_conv.py (2)
20-21: Use explicit error instead of assert forpadding_mode.Asserts can be compiled out; prefer a runtime check.
- assert padding_mode == "zeros", "padding_mode must be zeros" + if padding_mode != "zeros": + raise ValueError("padding_mode must be 'zeros'")
22-22: Silence unused variable to satisfy linters.- batch_size, seq_len, _ = input.shape + _, seq_len, _ = input.shapetests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py (2)
1-1: Remove unnecessary# noqadirectives.They hide useful linting and are not needed here.
-import torch # noqa +import torch -from torch.export import Dim # noqa +from torch.export import Dim -from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device # noqa +from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_deviceAlso applies to: 3-3, 8-8
66-67: Hard-coded CUDA usage can fail on CPU envs; use detected device.Tests live in singlegpu, but device plumbing is trivial to make robust.
- inputs = tokenizer(message, return_tensors="pt", return_token_type_ids=False).to("cuda") + dev = "cuda" if torch.cuda.is_available() else "cpu" + inputs = tokenizer(message, return_tensors="pt", return_token_type_ids=False).to(dev)tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py (2)
9-13: Skip module if CUDA unavailable.Avoid failures on CPU CI.
import pytest import torch import tensorrt_llm._torch.auto_deploy # noqa: F401 + +# Skip all tests here if CUDA is unavailable +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
12-12: Remove redundant# noqamarkers.They refer to non-enabled rules and aren’t needed.
-import tensorrt_llm._torch.auto_deploy # noqa: F401 +import tensorrt_llm._torch.auto_deploy - tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_mamba._torch_cached_ssm_transform_decode( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_mamba._torch_cached_ssm_transform_decode( # type: ignore - tensorrt_llm._torch.auto_deploy.custom_ops.torch_mamba._torch_ssm_transform_prefill( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.torch_mamba._torch_ssm_transform_prefill( # type: ignoreAlso applies to: 98-98, 171-171
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py (1)
165-183: Op wiring and meta path look consistent with prefill returning fp32.No functional concerns; consider brief Google-style docstrings on public custom ops for clarity (args/dtypes/shapes).
Also applies to: 185-197
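If docstrings are added, a generic Google-style template along these lines would do; the op name and the described fields are placeholders, not the actual signatures.

```python
def cached_ssm_transform_example(x, ssm_state_cache, slot_idx):
    """One-line summary of what the op computes.

    Args:
        x: Input activations; document the expected shape and dtype here.
        ssm_state_cache: Per-slot SSM state cache, updated in place.
        slot_idx: Cache slot index for each sequence in the batch.

    Returns:
        torch.Tensor: Output with the same leading shape as ``x``; call out the
            dtype explicitly (e.g., fp32 on the prefill path).
    """
```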
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py (2)
125-156: Clamp/softplus path is sound; minor perf nit.dt_bias_zero is constructed but unused by kernel (dt_softplus=False). Safe, but you can avoid allocating it by passing None if selective_state_update permits optional dt_bias when dt_softplus=False.
224-249: Cache initializer: simplify ssm_state_size detection.The else branch max(1, B_fake.shape[-1]) is redundant; shape[-1] already ≥1. Also guard for dtype/device consistency from cache_config.
- if B_fake.ndim >= 4: - ssm_state_size = B_fake.shape[-1] - else: - ssm_state_size = max(1, B_fake.shape[-1]) + ssm_state_size = B_fake.shape[-1]tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (5)
337-348: Pattern 6 passes enable_gqa=True in search; may miss graphs without this kwarg.If upstream sdpa.default omits enable_gqa (defaults), the pattern may not match. Consider registering a variant without enable_gqa or ignore this kwarg during matching.
- register_ad_pattern( + # Variant that ignores the 'enable_gqa' kwarg to improve match robustness + register_ad_pattern( search_fn=_grouped_attn_pattern_6, replace_fn=_grouped_attn_replacement_6, patterns=patterns, dummy_args=dummy_args_2, - scalar_workaround={"scale": scale, "dropout_p": dropout}, + scalar_workaround={"scale": scale, "dropout_p": dropout}, + op_ignore_types={torch.ops.auto_deploy.torch_attention_sdpa.default: (bool,)}, )
356-367: Same as above for Pattern 7 (causal).Apply the same op_ignore_types mapping as in Pattern 6 to avoid brittle matches.
375-392: Patterns 8/9 omit attn_mask and add enable_gqa; add robust variant or ignore bools.Mirror the op_ignore_types addition here too to match callsites lacking the kwarg.
413-435: Replacement 10 ignores n_rep (benign) — suppress lint or document.n_rep is unused in replacement_10 by design (grouped op handles repetition). Add a comment or underscore the param to avoid ARG001 noise.
-def _grouped_attn_replacement_10(q, k, v, n_rep, dropout_p): +def _grouped_attn_replacement_10(q, k, v, _n_rep, dropout_p): return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
668-672: MatchAttentionLayout only targets grouped_sdpa.If graphs still contain torch_attention_sdpa at this stage, they won't be transposed for bsnd. Consider including sdpa as well or document the pass ordering guarantee.
tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py (3)
139-145: Weight shape handling: assert depthwise contract or document supported cases.You assume weight.ndim==3 with shape[-2]==1 or 2D [C,K]. If other layouts sneak in, assert early with a clear error.
- if weight.ndim == 3: + if weight.ndim == 3: assert weight.shape[-2] == 1 w2d = weight.squeeze(-2) - else: + elif weight.ndim == 2: w2d = weight + else: + raise ValueError(f"Unsupported weight shape {tuple(weight.shape)}; expected [C,K] or [C,1,K].")
288-293: Constants are ignored by CUDA path; validate defaults to avoid semantic drift.If source op uses non-default stride/padding/dilation/groups/padding_mode, cached CUDA path will not honor them. Either plumb them or guard with a check that they are default values.
- stride, padding, dilation, groups, padding_mode = extract_op_args( + stride, padding, dilation, groups, padding_mode = extract_op_args( source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" ) + # Optional: enforce supported defaults + if (stride != 1) or (padding != 0) or (dilation != 1) or (groups != 1) or (padding_mode not in (None, "zeros")): + ad_logger.warning("cuda_cached_causal_conv1d does not honor non-default conv params; results may differ.") return [stride, padding, dilation, groups, padding_mode]
36-50: Unused helper _build_conv_state_from_sequence.If not needed, drop it to reduce surface area.
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py (5)
98-106: Decode path constraints not enforced by constants; document/guard in metadata.Decode asserts stride==1 and dilation==1. If source op is configured differently, cached path will raise at runtime.
- Consider validating these constants once in prepare/transform pass and fail early.
- Alternatively, specialize kernel/constants so prefill uses provided stride/dilation but decode enforces 1.
If you want, I can add a compile-time check in the pass that inserts cached ops.
Also applies to: 101-103
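A sketch of the kind of early check being suggested, assuming the transform pass has the conv constants in hand; the function name and call site are hypothetical.

```python
def validate_decode_conv_constants(stride: int, dilation: int) -> None:
    """Fail at transform time instead of inside the cached decode kernel at runtime."""
    if stride != 1 or dilation != 1:
        raise ValueError(
            "Cached causal-conv decode requires stride=1 and dilation=1, "
            f"got stride={stride}, dilation={dilation}."
        )
```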
65-76: Unused locals (batch_size, seq_len).Minor lint from Ruff (RUF059). Replace with underscores to avoid warnings.
- batch_size, seq_len, _ = input.shape + _batch_size, _seq_len, _ = input.shape
104-116: Unused local (batch_size).Same nit in decode.
- batch_size, seq_len, _ = input.shape + _batch_size, _seq_len, _ = input.shape
112-121: Cache update via roll is O(K). Consider deque-like assignment.Rolling shifts the entire [C_in, K] window every token. For large K this is avoidable.
Use a circular index tracked per slot or maintain a write pointer buffer; if that’s out of scope now, keep as-is and revisit for perf.
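For illustration, a pointer-based alternative could look roughly like this, assuming a per-slot conv state cache of shape [num_slots, C_in, K] plus an integer write-pointer buffer; the names and shapes are assumptions, not the PR's actual layout.

```python
import torch

def append_token_state(
    conv_state_cache: torch.Tensor,  # [num_slots, C_in, K] circular window per slot
    write_ptr: torch.Tensor,         # [num_slots] long, next write position per slot
    slot: int,
    token_features: torch.Tensor,    # [C_in] features of the newest token
) -> None:
    """O(1) per-token update: overwrite one column instead of rolling the whole window."""
    K = conv_state_cache.shape[-1]
    pos = int(write_ptr[slot])
    conv_state_cache[slot, :, pos] = token_features
    write_ptr[slot] = (pos + 1) % K
    # Readers can reconstruct the time-ordered window with
    # torch.roll(conv_state_cache[slot], -int(write_ptr[slot]), dims=-1)
    # only when it is actually needed (e.g., at prefill/writeback boundaries).
```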
300-310: get_num_qkv_args=3 with optional bias; ensure callsites pass None safely.If bias is None, confirm the replacement pass supplies an Optional[Tensor] acceptable by custom_op schema; otherwise wrap a registered None-like tensor.
I can add a small adapter in the transform to coerce None to a zero bias tensor of correct dtype/device.
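For illustration, such an adapter could be as small as the following; the function name and placement are hypothetical.

```python
from typing import Optional

import torch

def coerce_conv_bias(bias: Optional[torch.Tensor], weight: torch.Tensor) -> torch.Tensor:
    """Return a real tensor for the custom-op schema: zeros matching the output channels."""
    if bias is not None:
        return bias
    out_channels = weight.shape[0]
    return torch.zeros(out_channels, dtype=weight.dtype, device=weight.device)
```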
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (2)
663-666: Zero out stale tail of slot_idx for fewer-than-max batches._seq_len is reset, but slot_idx isn’t. While sanitized consumers slice to num_seq, zeroing avoids accidental misuse elsewhere.
-        if slot_idx is not None:
-            self._store_arg("slot_idx", slot_idx)
+        if slot_idx is not None:
+            self._store_arg("slot_idx", slot_idx, reset=True)
427-466: Return types in sanitization helpers: ensure Python ints for slicing. _get_sanitized_num_sequences returns a Tensor when s > 1; later code uses it in slicing. Convert it to an int to avoid brittle behavior.
-        if s > 1:
-            num_seq = torch.sum(seq_len > 0)
+        if s > 1:
+            num_seq = int(torch.sum(seq_len > 0).item())
Do the same wherever num_seq is produced and used as a slice bound. Based on learnings.
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py (2)
31-43: The decode helper's chunk_size is unused. Lint noise and possible confusion.
-    chunk_size: int,
+    _chunk_size: int,
Alternatively, remove the parameter if it is not needed.
105-107: Unused helper _update_ssm_state_cache. It is not referenced; remove it to reduce surface area.
-def _update_ssm_state_cache(ssm_cache: torch.Tensor, ssm_state: torch.Tensor) -> None:
-    ssm_cache.copy_(ssm_state)
+# Removed unused helper _update_ssm_state_cache
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
/bot run --disable-fail-fast
PR_Github #20284 [ run ] triggered by Bot
PR_Github #20284 [ run ] completed with state
/bot run
PR_Github #20306 [ run ] triggered by Bot
PR_Github #20306 [ run ] completed with state
…l_conv + Bamba + Nemotron-H) (NVIDIA#8068) Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Co-authored-by: William Zhang <133824995+2ez4bz@users.noreply.github.com> Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Co-authored-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Signed-off-by: Faradawn Yang <faradawny@gmail.com>
Summary by CodeRabbit
New Features
Bug Fixes
Refactor/Chores
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL): Disable fail-fast on build/tests/infra failures.
--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL): Force-run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline plus the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill: Kill all running builds associated with the pull request.
skip
skip --comment COMMENT: Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.