[None][chore] Wrap the swiglu into custom op to avoid redundant device copy. #7021
Conversation
📝 Walkthrough

Adds a new Torch custom op, trtllm::silu_and_mul, that wraps the existing SwiGLU Triton kernel.

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
participant SwiGLU as SwiGLU.forward
participant TorchOp as torch.ops.trtllm.silu_and_mul
participant Triton as silu_and_mul_kernel (Triton)
SwiGLU->>TorchOp: silu_and_mul(x[, scale, dtype])
alt scale provided
Note right of TorchOp #DDEBF7: HAS_O_SCALE = True
else no scale
Note right of TorchOp #F7F0DD: HAS_O_SCALE = False
end
TorchOp->>Triton: launch grid (B, triton.cdiv(D, BLOCK_SIZE)) with pointers/strides and flags
Triton-->>TorchOp: returns output tensor o (B, N/2)
TorchOp-->>SwiGLU: o
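To make the flow above concrete, here is a condensed sketch of the wrapper, pieced together from the diffs quoted later in this review rather than the exact PR source; the op name `trtllm::silu_and_mul`, the kernel `silu_and_mul_kernel`, and the launch parameters come from the review, everything else is illustrative.

```python
from typing import Mapping, Optional, Tuple

import torch
import triton

# The Triton kernel already lives in the modules package per the review.
from tensorrt_llm._torch.modules.swiglu import silu_and_mul_kernel


@torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
def silu_and_mul(x: torch.Tensor,
                 scale: Optional[torch.Tensor] = None,
                 dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    b, n = x.shape
    d = n // 2
    # Allocate the output once and let the Triton kernel write into it, so the
    # compiled graph sees one opaque op instead of an extra device-to-device copy.
    o = torch.empty((b, d), dtype=dtype or x.dtype, device=x.device)

    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    silu_and_mul_kernel[grid](o_ptr=o, o_stride=o.stride(0), o_scale_ptr=scale,
                              x_ptr=x, x_stride=x.stride(0), d=d,
                              BLOCK_SIZE=1024, HAS_O_SCALE=scale is not None)
    return o


@silu_and_mul.register_fake
def _(x, scale=None, dtype=None):
    # Shape/dtype propagation only; no kernel launch under torch.compile tracing.
    b, n = x.shape
    return x.new_empty((b, n // 2), dtype=dtype or x.dtype)
```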
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
Suggested reviewers
/bot run --disable-fail-fast
Actionable comments posted: 1
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/swiglu.py (1)
54-58: All existing `swiglu` calls use torch.dtype – optional guard for Triton dtypes

All current invocations of `swiglu(x, quant_scale, quant_type)` pass a `torch.dtype` (or `None`), so no runtime dtype mismatches will occur. If you’d like to future-proof against someone accidentally passing a Triton dtype (e.g. `tl.float8e4nv` or `tl.float8e5`), you can optionally:

• File: `tensorrt_llm/_torch/modules/swiglu.py` (around lines 51–58)
• Wrap the custom-op call’s `dtype` argument in a small converter:

    @@ def swiglu(x, quant_scale: torch.Tensor = None, quant_type=None):
    -    return torch.ops.trtllm.silu_and_mul(
    -        x,
    -        scale=quant_scale,
    -        dtype=quant_type,
    -    )
    +    return torch.ops.trtllm.silu_and_mul(
    +        x,
    +        scale=quant_scale,
    +        dtype=_as_torch_dtype(quant_type),
    +    )

And add this helper at the bottom of the module:

    def _as_torch_dtype(dt):
        import triton.language as tl, torch
        if dt is None or isinstance(dt, torch.dtype):
            return dt
        # Map Triton FP8 -> PyTorch FP8
        if dt is getattr(tl, "float8e4nv", None):
            return getattr(torch, "float8_e4m3fn", torch.float16)
        if dt is getattr(tl, "float8e5", None):
            return getattr(torch, "float8_e5m2", torch.bfloat16)
        return dt

No changes are required for correctness today—this is purely an optional refactor.
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (3)
913-940: Handle non-contiguous inputs or enforce contiguity explicitly.

The Triton kernel indexes the last dimension as if stride(1) == 1. If x is a non-contiguous view (e.g., sliced/transposed), the loads will read wrong elements. Either assert contiguity or make a safe fallback.
Apply this guard to avoid silent misreads (rare path does a copy; common path stays zero-copy):
     def silu_and_mul(x: torch.Tensor,
                      scale: Optional[torch.Tensor] = None,
                      dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         b, n = x.shape
         assert n % 2 == 0
         d = n // 2
    +    # Kernel assumes contiguous last-dimension indexing.
    +    if x.stride(1) != 1:
    +        x = x.contiguous()
    +
         o_dtype = dtype or x.dtype
         o = torch.empty((b, d), dtype=o_dtype, device=x.device)

If you prefer to avoid the copy entirely, make it a hard check:

    if x.stride(1) != 1:
        raise ValueError("trtllm::silu_and_mul expects x to be contiguous along the last dimension")
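For illustration only (not from the PR), a sliced view is enough to break the stride(1) == 1 assumption; strides behave the same on CPU, so no GPU is needed to see it:

```python
import torch

x = torch.randn(4, 16)          # contiguous, strides (16, 1)
view = x[:, ::2]                # shape (4, 8), but stride(1) == 2
print(view.stride(), view.is_contiguous())  # (16, 2) False
```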
919-921: Prefer a ValueError over assert for shape checks.

Using assert for input validation can be stripped with Python optimizations, leading to hard-to-debug failures downstream. Raise an explicit ValueError instead.
    -    assert n % 2 == 0
    +    if (n % 2) != 0:
    +        raise ValueError(f"trtllm::silu_and_mul expects even last dim; got n={n}")
913-955: File header missing: add NVIDIA copyright header.

Our guidelines require the NVIDIA copyright header (current year) on all source files.
Add at the top of the file:
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
- tensorrt_llm/_torch/modules/swiglu.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/modules/swiglu.py
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/modules/swiglu.pytensorrt_llm/_torch/custom_ops/torch_custom_ops.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/modules/swiglu.py (1)
60-60: LGTM: delegating to the custom op in the non-quantized path.

The zero-arg variant cleanly calls the fused op without scale/dtype and avoids redundant device copies.
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
13-13: Kernel import location looks good.

Importing silu_and_mul_kernel from modules.swiglu maintains a single source of truth for the kernel while exposing a clean custom op API here.
PR_Github #15725 [ run ] triggered by Bot
PR_Github #15725 [ run ] completed with state

eb27b6d to ffb3f9e (Compare)

/bot run --disable-fail-fast

PR_Github #15839 [ run ] triggered by Bot
PR_Github #15839 [ run ] completed with state

ffb3f9e to 0d0c6b4 (Compare)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
704-714: Bug: variable name mismatch in w4a8_mxfp4_fp8_gemm.register_fake (act_fp8 vs act_fp4).

The function parameter is `act_fp4` but the return uses `act_fp8`, which will raise a NameError.

     @w4a8_mxfp4_fp8_gemm.register_fake
     def _(
    -    act_fp4: torch.Tensor,
    +    act_fp8: torch.Tensor,
         weight: torch.Tensor,
         act_sf: torch.Tensor,
         weight_scale: torch.Tensor,
         alpha: torch.Tensor,
         output_dtype: torch.dtype,
         to_userbuffers: bool = False,
     ) -> torch.Tensor:
         return act_fp8.new_empty((act_fp8.size(0), weight.size(0)),
                                  dtype=output_dtype)

If the intent was `act_fp4`, then consistently rename usages instead:

    -    return act_fp8.new_empty((act_fp8.size(0), weight.size(0)),
    +    return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
                                  dtype=output_dtype)
♻️ Duplicate comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
925-927: Python 3.8 compatibility: use typing.Tuple instead of built-in tuple generics.

This is a SyntaxError on 3.8.

    -    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
    +    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
🧹 Nitpick comments (4)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4)
1-1: Add NVIDIA copyright header (2025) per repo guidelines.

The file is missing the required NVIDIA copyright header.
Apply at the very top:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
5-5: Optional: drop top-level Triton import by avoiding triton.cdiv in grid.

This keeps import-time dependencies lighter (helpful for fake/meta paths), and the integer expression is equivalent.

    -import triton  # type: ignore[import]
    @@
    -    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
    -        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
    +    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
    +        # ceil_div(d, BLOCK_SIZE)
    +        return (b, (d + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"])

Also applies to: 925-927
13-13: Prefer module-namespace import to avoid symbol-level coupling.

Keeps imports consistent with the guideline to maintain module namespaces, and reduces risk of future name collisions.

    -from ..modules.swiglu import silu_and_mul_kernel
    +from ..modules import swiglu as _swiglu
    @@
    -    silu_and_mul_kernel[grid](
    +    _swiglu.silu_and_mul_kernel[grid](
             o_ptr=o,
             o_stride=o.stride(0),
             o_scale_ptr=scale,
             x_ptr=x,
             x_stride=x.stride(0),
             d=d,
             BLOCK_SIZE=1024,
             HAS_O_SCALE=scale is not None,
         )

Also applies to: 928-938
942-955: Meta (fake) impl: align with real op’s validation and docstring.

Mirror minimal validation and add a short docstring. This helps shape errors surface consistently under torch.compile.

     @silu_and_mul.register_fake
     def _(
         x: torch.Tensor,
         scale: Optional[torch.Tensor] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> torch.Tensor:
    -    b, n = x.shape
    -
    -    assert n % 2 == 0
    -    d = n // 2
    +    """Meta for silu_and_mul: returns an empty tensor with correct shape/dtype."""
    +    if x.dim() != 2:
    +        raise ValueError(f"silu_and_mul (fake): expected 2D input [B, N], got {tuple(x.shape)}")
    +    b, n = x.shape
    +    if n % 2 != 0:
    +        raise ValueError(f"silu_and_mul (fake): expected even N, got N={n}")
    +    d = n // 2
         o_dtype = dtype or x.dtype
         return x.new_empty((b, d), dtype=o_dtype)
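A quick way to exercise just this fake path is under FakeTensorMode, which is what torch.compile uses for shape/dtype propagation. This is a rough sketch, assuming the trtllm ops are registered (e.g., by importing the custom-ops module) and a CUDA-enabled PyTorch build:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Fake tensors carry only metadata, so this never launches the Triton kernel
    # and dispatches to the register_fake implementation instead.
    x = torch.empty(8, 1024, device="cuda", dtype=torch.bfloat16)
    out = torch.ops.trtllm.silu_and_mul(x)
    print(out.shape, out.dtype)  # torch.Size([8, 512]) torch.bfloat16
```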
🔇 Additional comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
2-2: Typing imports look good.Importing Mapping/Optional/Tuple prepares for Python 3.8 compatibility fixes below.
913-955: The script will list the imports and Python implementation of `flashinfer_silu_and_mul`. This will clarify whether it’s still calling an older, non-ops-based `silu_and_mul` or correctly dispatching to the custom op. Once we see that, we can determine if the review comment is resolved or if further changes are needed.
0d0c6b4 to e5e2c62 (Compare)

/bot run --disable-fail-fast

PR_Github #15987 [ run ] triggered by Bot
PR_Github #15987 [ run ] completed with state

e5e2c62 to 5b4468d (Compare)

/bot run --disable-fail-fast

PR_Github #16121 [ run ] triggered by Bot
PR_Github #16121 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #16159 [ run ] triggered by Bot

/bot run --disable-fail-fast

PR_Github #16235 [ run ] triggered by Bot
PR_Github #16235 [ run ] completed with state

/bot run --disable-fail-fast

5b4468d to 43ae97e (Compare)

/bot run --disable-fail-fast

PR_Github #16342 [ run ] triggered by Bot
PR_Github #16344 [ run ] triggered by Bot
PR_Github #16342 [ run ] completed with state
Actionable comments posted: 1
♻️ Duplicate comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
907-909: Python 3.8 compatibility: use typing.Tuple instead of built-in tuple generics.
`tuple[int, int]` requires Python 3.9+. We target 3.8+, so switch to `Tuple[int, int]` (already imported).

    -    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
    +    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
             return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
🔇 Additional comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
895-896: Confirmed single registration of trtllm::silu_and_mul — no duplicates found.

The search shows exactly one `@torch.library.custom_op("trtllm::silu_and_mul", …)` and one corresponding `@silu_and_mul.register_fake`, so there are no conflicting definitions.
PR_Github #16344 [ run ] completed with state

…e copy. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

43ae97e to f3dbc15 (Compare)

/bot run --disable-fail-fast

PR_Github #16411 [ run ] triggered by Bot
Actionable comments posted: 0
♻️ Duplicate comments (4)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4)
2-2: Minor: You import Tuple but still use built-in tuple generics later.

After fixing the grid annotation (Lines 997-999), Tuple becomes “used”; otherwise many linters will flag it as unused.
985-1011: Harden silu_and_mul: replace asserts with explicit validation, enforce CUDA + last-dim contiguity, require scalar scale, fix Python 3.8 tuple generics, and add a concise docstring.
- Asserts can be stripped under -O and give vague errors; raise clear exceptions instead.
- The kernel reads rows with pointer arithmetic and assumes x.stride(1) == 1; validate to avoid silent misreads or implicit copies.
- Kernel loads a single o_scale via tl.load(o_scale_ptr), so scale must be scalar (0-dim or 1-element).
- Python 3.8 target: change tuple[int, int] to Tuple[int, int] (Tuple already imported).
- Add a short Google-style docstring per guidelines.
Apply:
     @torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
     def silu_and_mul(x: torch.Tensor,
                      scale: Optional[torch.Tensor] = None,
                      dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    -    b, n = x.shape
    -
    -    assert n % 2 == 0
    -    d = n // 2
    -
    -    o_dtype = dtype or x.dtype
    -    o = torch.empty((b, d), dtype=o_dtype, device=x.device)
    -
    -    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
    -        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
    +    """SwiGLU activation backed by a Triton kernel with optional scalar scaling.
    +
    +    Args:
    +        x: 2D CUDA tensor [B, N]. N must be even. Last dimension must be contiguous.
    +        scale: Optional scalar CUDA tensor (0-dim or 1-element).
    +        dtype: Optional output dtype. Defaults to x.dtype.
    +
    +    Returns:
    +        Tensor [B, D] where D = N // 2, dtype = dtype or x.dtype.
    +
    +    Raises:
    +        ValueError: on non-2D input, odd N, or non-contiguous last dim.
    +        RuntimeError: on non-CUDA input or device mismatch for scale.
    +        TypeError: if scale is provided but not scalar.
    +    """
    +    if x.dim() != 2:
    +        raise ValueError(f"silu_and_mul: expected 2D input [B, N], got {tuple(x.shape)}")
    +    if x.device.type != "cuda":
    +        raise RuntimeError("silu_and_mul: expected CUDA tensor input.")
    +    if x.stride(1) != 1:
    +        raise ValueError("silu_and_mul: expected contiguous last dimension (x.stride(1) == 1).")
    +    b, n = x.shape
    +    if n % 2 != 0:
    +        raise ValueError(f"silu_and_mul: expected even N, got N={n}")
    +    d = n // 2
    +    if scale is not None:
    +        if scale.device != x.device:
    +            raise RuntimeError("silu_and_mul: `scale` must be on the same device as `x`.")
    +        if scale.numel() != 1:
    +            raise TypeError("silu_and_mul: `scale` must be a scalar (0-dim or 1-element) tensor.")
    +
    +    o_dtype = dtype or x.dtype
    +    o = torch.empty((b, d), dtype=o_dtype, device=x.device)
    +
    +    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
    +        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
         silu_and_mul_kernel[grid](
             o_ptr=o,
             o_stride=o.stride(0),
             o_scale_ptr=scale,
             x_ptr=x,
             x_stride=x.stride(0),
             d=d,
             BLOCK_SIZE=1024,
             HAS_O_SCALE=scale is not None,
         )
         return o

To quickly confirm there are no other Python 3.9+ generic annotations that could break 3.8:

    #!/bin/bash
    # Find Python 3.9+ built-in generic annotations that would break 3.8
    rg -nP --type=py '\b(?:list|dict|set|tuple)\s*\[' -g '!**/third_party/**' -C2
997-999: Python 3.8 compatibility: change tuple[int, int] to typing.Tuple[int, int].

This file imports Tuple on Line 2 but still uses the 3.9+ form here. This will break 3.8 and may explain recurring pipeline failures. Apply:

    -    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
    +    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
1014-1026: Mirror validation in the fake path; avoid bare asserts and provide a clear error.

Fake kernels participate in shape/dtype propagation during torch.compile, and early, cheap validation helps catch misuse without launching the real kernel.
Apply:
     @silu_and_mul.register_fake
     def _(
         x: torch.Tensor,
         scale: Optional[torch.Tensor] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> torch.Tensor:
    -    b, n = x.shape
    -
    -    assert n % 2 == 0
    -    d = n // 2
    +    if x.dim() != 2:
    +        raise ValueError(f"silu_and_mul(fake): expected 2D input [B, N], got {tuple(x.shape)}")
    +    b, n = x.shape
    +    if n % 2 != 0:
    +        raise ValueError(f"silu_and_mul(fake): expected even N, got N={n}")
    +    d = n // 2
         o_dtype = dtype or x.dtype
         return x.new_empty((b, d), dtype=o_dtype)
🧹 Nitpick comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
15-15: Reduce cross-module coupling: consider relocating the kernel into a dedicated kernels submodule.

Importing the kernel from ..modules.swiglu creates a subtle layering dependency between “custom_ops” and “modules”. A kernels subpackage (e.g., tensorrt_llm/_torch/kernels/swiglu.py) would make ownership clearer and avoid future import tangles.
🔇 Additional comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
5-5: Importing Triton at module scope looks fine for this file’s usage.

Given the op always dispatches a Triton kernel, importing at module scope is acceptable. The type: ignore is also appropriate for environments without stubs.
985-1026: Add minimal unit tests and manually verify torch.compile graph stability.

The custom op implementation and its `register_fake` fallback still lack coverage for shape/device preconditions and compile-time graph correctness. Since we can’t execute the CI script here (no Torch), please verify locally that `torch.compile` around this op does not introduce unintended `aten.copy_` (or other D2D copies). In addition, add lightweight unit tests to guard against regressions (a rough sketch follows this list):

• Precondition checks (should raise errors):
  - Odd-length last dimension (N % 2 ≠ 0)
  - Non-2D inputs
  - Input on CPU when CUDA is required
  - Non-contiguous last-dim tensors
  - Non-scalar `scale` argument

• Correctness test:
  - Compare `torch.ops.trtllm.silu_and_mul(x, scale, dtype)` against a reference implementation (`F.silu(x[:, :D]) * x[:, D:]`) on a small CUDA tensor for various dtypes.

• Compile-path stability:
  - Wrap the op in an `nn.Module`, run `torch.compile(fullgraph=True)`, and inspect the FX IR (e.g., via `print(torch._dynamo.export(...).graph)` or similar) to assert no `aten.copy_` nodes appear in the custom-op subgraph.

Let me know if you need a draft of these tests or guidance on the graph-inspection commands.
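The following is a rough sketch of such tests, not code from the PR; it assumes the op computes silu(x[:, :d]) * x[:, d:] (as the name suggests) and that the trtllm custom ops have already been registered by importing the relevant tensorrt_llm module.

```python
import pytest
import torch
import torch.nn.functional as F

# Assumption: importing the custom-ops module registers torch.ops.trtllm.*;
# the exact import path may differ in the repository.


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_silu_and_mul_matches_reference(dtype):
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA")
    x = torch.randn(8, 256, device="cuda", dtype=dtype)
    d = x.shape[1] // 2
    ref = F.silu(x[:, :d].float()) * x[:, d:].float()
    out = torch.ops.trtllm.silu_and_mul(x)
    # Compare in float32 with loose tolerances for half-precision inputs.
    torch.testing.assert_close(out.float(), ref, rtol=2e-2, atol=2e-2)


def test_silu_and_mul_rejects_odd_last_dim():
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA")
    x = torch.randn(4, 7, device="cuda")
    with pytest.raises(Exception):
        torch.ops.trtllm.silu_and_mul(x)
```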
PR_Github #16411 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #16526 [ run ] triggered by Bot
PR_Github #16526 [ run ] completed with state
LGTM
A redundant D2D copy is observed when enabling torch.compile for the Llama model, caused by the swiglu Triton kernel, which introduces performance overhead. Wrapping the swiglu op in a custom op avoids this overhead.
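For readers skimming the thread, a plain-PyTorch equivalent of what the fused op computes is sketched below; this is an illustration based on the op’s name and the SwiGLU module, not the PR’s code.

```python
import torch
import torch.nn.functional as F


def swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: the gate half goes through SiLU and then
    # multiplies the other half -- the fused Triton kernel produces this in one launch.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up
```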
Summary by CodeRabbit
New Features
Refactor