[None][chore] Wrap the swiglu into custom op to avoid redundant device copy. by hyukn · Pull Request #7021 · NVIDIA/TensorRT-LLM · GitHub

Conversation

@hyukn
Copy link
Collaborator

@hyukn hyukn commented Aug 19, 2025

A redundant D2D copy is observed when enabling torch.compile for the Llama model, caused by the swiglu Triton kernel, and it introduces performance overhead. Wrap the swiglu op in a custom op to avoid this overhead.
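In outline, the change registers the Triton-backed activation as a Torch custom op, so the compiled graph sees a single opaque functional node rather than the raw Triton launch that triggered the extra copy. A minimal sketch of the pattern (illustrative only: it uses a placeholder example:: namespace instead of the PR's actual trtllm::silu_and_mul registration, and a plain PyTorch body stands in for the Triton kernel launch):

import torch
import torch.nn.functional as F


# Sketch of the wrapping pattern; the PR's real op launches silu_and_mul_kernel
# and additionally accepts an optional output scale and dtype.
@torch.library.custom_op("example::silu_and_mul", mutates_args=())
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    # SiLU of the first half of the last dim, multiplied elementwise by the second half.
    return F.silu(x[..., :d]) * x[..., d:]


# Fake (meta) path: supplies the output shape/dtype during tracing, no kernel launch.
@silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return x.new_empty((*x.shape[:-1], x.shape[-1] // 2))

Because torch.compile only sees the registered op, the redundant copy observed around the unwrapped Triton launch no longer appears in the compiled graph.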

Summary by CodeRabbit

  • New Features

    • Adds a public SiLU-and-multiply operation with optional scaling for direct use in models and pipelines; input width must be even.
    • Registers a fake (meta) implementation that returns a correctly shaped placeholder, so the op can be traced (e.g. under torch.compile) without launching the kernel.
  • Refactor

    • SwiGLU now delegates to the new operation for consistent behavior across environments, simplifying the Python-level implementation.

@hyukn hyukn requested review from a team as code owners August 19, 2025 05:57
@hyukn hyukn requested a review from HuiGao-NV August 19, 2025 05:57
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Aug 19, 2025

📝 Walkthrough

Walkthrough

Adds a new Torch custom op trtllm::silu_and_mul that launches the Triton kernel, together with a fake-path (meta) variant for shape propagation; updates SwiGLU to call torch.ops.trtllm.silu_and_mul and removes the Python-level Triton wrapper while leaving the kernel definition in swiglu.py.

Changes

Cohort / File(s) Summary
Custom op implementation
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
Adds @torch.library.custom_op("trtllm::silu_and_mul") implementation and a .register_fake fake-path. Op asserts input width N is even, computes output shape (B, D) with D = N/2, supports optional scale, builds a grid hook using Mapping & triton.cdiv, and launches silu_and_mul_kernel with pointers/strides and a HAS_O_SCALE flag.
SwiGLU wrapper update
tensorrt_llm/_torch/modules/swiglu.py
Removes the Python-level silu_and_mul wrapper and instead calls torch.ops.trtllm.silu_and_mul (forwards scale/dtype if provided). Cleans up unused imports. The Triton kernel definition remains in the module but is not invoked by Python.
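For reference, a sketch of the Python-level dispatch described in the SwiGLU row above (the signature and the scale/dtype forwarding mirror excerpts quoted in the review comments below; the exact branching in the module may differ, and the call assumes the trtllm custom op from this PR has been registered):

import torch


def swiglu(x: torch.Tensor,
           quant_scale: torch.Tensor = None,
           quant_type: torch.dtype = None) -> torch.Tensor:
    """Dispatch SwiGLU to the fused custom op, forwarding scale/dtype when quantizing."""
    if quant_scale is not None:
        # Quantized path: hand the output scale and dtype to the fused op.
        return torch.ops.trtllm.silu_and_mul(x, scale=quant_scale, dtype=quant_type)
    # Non-quantized path: zero-extra-arg call into the same op.
    return torch.ops.trtllm.silu_and_mul(x)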

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant SwiGLU as SwiGLU.forward
  participant TorchOp as torch.ops.trtllm.silu_and_mul
  participant Triton as silu_and_mul_kernel (Triton)

  SwiGLU->>TorchOp: silu_and_mul(x[, scale, dtype])
  alt scale provided
    Note right of TorchOp #DDEBF7: HAS_O_SCALE = True
  else no scale
    Note right of TorchOp #F7F0DD: HAS_O_SCALE = False
  end
  TorchOp->>Triton: launch grid (B, triton.cdiv(D, BLOCK_SIZE)) with pointers/strides and flags
  Triton-->>TorchOp: returns output tensor o (B, N/2)
  TorchOp-->>SwiGLU: o

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • litaotju


@hyukn
Copy link
Collaborator Author

hyukn commented Aug 19, 2025

/bot run --disable-fail-fast

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/swiglu.py (1)

54-58: All existing swiglu calls use torch.dtype – optional guard for Triton dtypes
All current invocations of swiglu(x, quant_scale, quant_type) pass a torch.dtype (or None), so no runtime dtype mismatches will occur. If you’d like to future-proof against someone accidentally passing a Triton dtype (e.g. tl.float8e4nv or tl.float8e5), you can optionally add:

• File: tensorrt_llm/_torch/modules/swiglu.py (around lines 51–58)
• Wrap the custom-op call’s dtype argument in a small converter:

@@ def swiglu(x, quant_scale: torch.Tensor = None, quant_type=None):
-        return torch.ops.trtllm.silu_and_mul(
-            x,
-            scale=quant_scale,
-            dtype=quant_type,
-        )
+        return torch.ops.trtllm.silu_and_mul(
+            x,
+            scale=quant_scale,
+            dtype=_as_torch_dtype(quant_type),
+        )

And add this helper at the bottom of the module:

def _as_torch_dtype(dt):
    import triton.language as tl, torch
    if dt is None or isinstance(dt, torch.dtype):
        return dt
    # Map Triton FP8→PyTorch FP8
    if dt is getattr(tl, "float8e4nv", None):
        return getattr(torch, "float8_e4m3fn", torch.float16)
    if dt is getattr(tl, "float8e5", None):
        return getattr(torch, "float8_e5m2", torch.bfloat16)
    return dt

No changes are required for correctness today—this is purely an optional refactor.

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (3)

913-940: Handle non-contiguous inputs or enforce contiguity explicitly.

The Triton kernel indexes the last dimension as if stride(1) == 1. If x is a non-contiguous view (e.g., sliced/transposed), the loads will read wrong elements. Either assert contiguity or make a safe fallback.

Apply this guard to avoid silent misreads (rare path does a copy; common path stays zero-copy):

 def silu_and_mul(x: torch.Tensor,
                  scale: Optional[torch.Tensor] = None,
                  dtype: Optional[torch.dtype] = None) -> torch.Tensor:
     b, n = x.shape

     assert n % 2 == 0
     d = n // 2

+    # Kernel assumes contiguous last-dimension indexing.
+    if x.stride(1) != 1:
+        x = x.contiguous()
+
     o_dtype = dtype or x.dtype
     o = torch.empty((b, d), dtype=o_dtype, device=x.device)

If you prefer to avoid the copy entirely, make it a hard check:

if x.stride(1) != 1:
    raise ValueError("trtllm::silu_and_mul expects x to be contiguous along the last dimension")

919-921: Prefer a ValueError over assert for shape checks.

Using assert for input validation can be stripped with Python optimizations, leading to hard-to-debug failures downstream. Raise an explicit ValueError instead.

-    assert n % 2 == 0
+    if (n % 2) != 0:
+        raise ValueError(f"trtllm::silu_and_mul expects even last dim; got n={n}")

913-955: File header missing: add NVIDIA copyright header.

Our guidelines require the NVIDIA copyright header (current year) on all source files.

Add at the top of the file:

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between daa2a65 and eb27b6d.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
  • tensorrt_llm/_torch/modules/swiglu.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/modules/swiglu.py
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/modules/swiglu.py
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/modules/swiglu.py (1)

60-60: LGTM: delegating to the custom op in the non-quantized path.

The zero-arg variant cleanly calls the fused op without scale/dtype and avoids redundant device copies.

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)

13-13: Kernel import location looks good.

Importing silu_and_mul_kernel from modules.swiglu maintains a single source of truth for the kernel while exposing a clean custom op API here.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15725 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15725 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11819 completed with status: 'FAILURE'

@hyukn hyukn force-pushed the chore/torch_compile_swiglu branch from eb27b6d to ffb3f9e on August 20, 2025 01:39
@hyukn
Copy link
Collaborator Author

hyukn commented Aug 20, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15839 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15839 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11905 completed with status: 'FAILURE'

@hyukn hyukn force-pushed the chore/torch_compile_swiglu branch from ffb3f9e to 0d0c6b4 on August 21, 2025 01:37
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)

704-714: Bug: variable name mismatch in w4a8_mxfp4_fp8_gemm.register_fake (act_fp8 vs act_fp4).

The function parameters are act_fp4 but the return uses act_fp8, which will raise NameError.

 @w4a8_mxfp4_fp8_gemm.register_fake
 def _(
-    act_fp4: torch.Tensor,
+    act_fp8: torch.Tensor,
     weight: torch.Tensor,
     act_sf: torch.Tensor,
     weight_scale: torch.Tensor,
     alpha: torch.Tensor,
     output_dtype: torch.dtype,
     to_userbuffers: bool = False,
 ) -> torch.Tensor:
     return act_fp8.new_empty((act_fp8.size(0), weight.size(0)),
                              dtype=output_dtype)

If the intent was act_fp4, then consistently rename usages instead:

-    return act_fp8.new_empty((act_fp8.size(0), weight.size(0)),
+    return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
                              dtype=output_dtype)
♻️ Duplicate comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)

925-927: Python 3.8 compatibility: use typing.Tuple instead of built-in tuple generics.

This is a SyntaxError on 3.8.

-    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
+    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
🧹 Nitpick comments (4)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4)

1-1: Add NVIDIA copyright header (2025) per repo guidelines.

The file is missing the required NVIDIA copyright header.

Apply at the very top:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

5-5: Optional: drop top-level Triton import by avoiding triton.cdiv in grid.

This keeps import-time dependencies lighter (helpful for fake/meta paths), and the integer expression is equivalent.

-import triton  # type: ignore[import]
@@
-    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
-        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
+    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
+        # ceil_div(d, BLOCK_SIZE)
+        return (b, (d + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"])

Also applies to: 925-927


13-13: Prefer module-namespace import to avoid symbol-level coupling.

Keeps imports consistent with the guideline to maintain module namespaces, and reduces risk of future name collisions.

-from ..modules.swiglu import silu_and_mul_kernel
+from ..modules import swiglu as _swiglu
@@
-    silu_and_mul_kernel[grid](
+    _swiglu.silu_and_mul_kernel[grid](
         o_ptr=o,
         o_stride=o.stride(0),
         o_scale_ptr=scale,
         x_ptr=x,
         x_stride=x.stride(0),
         d=d,
         BLOCK_SIZE=1024,
         HAS_O_SCALE=scale is not None,
     )

Also applies to: 928-938


942-955: Meta (fake) impl: align with real op’s validation and docstring.

Mirror minimal validation and add a short docstring. This helps shape errors surface consistently under torch.compile.

 @silu_and_mul.register_fake
 def _(
     x: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    b, n = x.shape
-
-    assert n % 2 == 0
-    d = n // 2
+    """Meta for silu_and_mul: returns an empty tensor with correct shape/dtype."""
+    if x.dim() != 2:
+        raise ValueError(f"silu_and_mul (fake): expected 2D input [B, N], got {tuple(x.shape)}")
+    b, n = x.shape
+    if n % 2 != 0:
+        raise ValueError(f"silu_and_mul (fake): expected even N, got N={n}")
+    d = n // 2
     o_dtype = dtype or x.dtype
     return x.new_empty((b, d), dtype=o_dtype)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between ffb3f9e and 0d0c6b4.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
  • tensorrt_llm/_torch/modules/swiglu.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/modules/swiglu.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

2-2: Typing imports look good.

Importing Mapping/Optional/Tuple prepares for Python 3.8 compatibility fixes below.


913-955: A verification script will list the imports and the Python implementation of flashinfer_silu_and_mul. This will clarify whether it still calls an older, non-ops-based silu_and_mul or correctly dispatches to the custom op. Once that is confirmed, we can determine whether this review comment is resolved or whether further changes are needed.

@hyukn hyukn force-pushed the chore/torch_compile_swiglu branch from 0d0c6b4 to e5e2c62 on August 21, 2025 03:39
@hyukn
Copy link
Collaborator Author

hyukn commented Aug 21, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15987 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15987 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12016 completed with status: 'FAILURE'

@hyukn hyukn force-pushed the chore/torch_compile_swiglu branch from e5e2c62 to 5b4468d on August 22, 2025 04:37
@hyukn
Copy link
Collaborator Author

hyukn commented Aug 22, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16121 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16121 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12126 completed with status: 'FAILURE'

@hyukn
Copy link
Collaborator Author

hyukn commented Aug 22, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16159 [ run ] triggered by Bot

@hyukn
Copy link
Collaborator Author

hyukn commented Aug 23, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16235 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16235 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12204 completed with status: 'FAILURE'

@hyukn
Copy link
Collaborator Author

hyukn commented Aug 25, 2025

/bot run --disable-fail-fast

@hyukn hyukn force-pushed the chore/torch_compile_swiglu branch from 5b4468d to 43ae97e on August 25, 2025 01:47
@hyukn
Copy link
Collaborator Author

hyukn commented Aug 25, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16342 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16344 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16342 [ run ] completed with state ABORTED

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)

907-909: Python 3.8 compatibility: use typing.Tuple instead of built-in tuple generics.

tuple[int, int] requires Python 3.9+. We target 3.8+, so switch to Tuple[int, int] (already imported).

-    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
+    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
         return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 5b4468d and 43ae97e.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
  • tensorrt_llm/_torch/modules/swiglu.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/modules/swiglu.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
tensorrt_llm/_torch/modules/swiglu.py (2)
  • swiglu (51-60)
  • silu_and_mul_kernel (27-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)

895-896: Confirmed single registration of trtllm::silu_and_mul—no duplicates found.
The search shows exactly one @torch.library.custom_op("trtllm::silu_and_mul", …) and one corresponding @silu_and_mul.register_fake, so there are no conflicting definitions.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16344 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12285 completed with status: 'FAILURE'

[None][chore] Wrap the swiglu into custom op to avoid redundant device copy.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
@hyukn hyukn force-pushed the chore/torch_compile_swiglu branch from 43ae97e to f3dbc15 on August 25, 2025 07:32
@hyukn
Copy link
Collaborator Author

hyukn commented Aug 25, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16411 [ run ] triggered by Bot

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (4)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4)

2-2: Minor: You import Tuple but still use built-in tuple generics later.

After fixing the grid annotation (Lines 997-999), Tuple becomes “used”; otherwise many linters will flag it as unused.


985-1011: Harden silu_and_mul: replace asserts with explicit validation, enforce CUDA + last-dim contiguity, require scalar scale, fix Python 3.8 tuple generics, and add a concise docstring.

  • Asserts can be stripped under -O and give vague errors; raise clear exceptions instead.
  • The kernel reads rows with pointer arithmetic and assumes x.stride(1) == 1; validate to avoid silent misreads or implicit copies.
  • Kernel loads a single o_scale via tl.load(o_scale_ptr), so scale must be scalar (0-dim or 1-element).
  • Python 3.8 target: change tuple[int, int] to Tuple[int, int] (Tuple already imported).
  • Add a short Google-style docstring per guidelines.

Apply:

 @torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
 def silu_and_mul(x: torch.Tensor,
                  scale: Optional[torch.Tensor] = None,
                  dtype: Optional[torch.dtype] = None) -> torch.Tensor:
-    b, n = x.shape
-
-    assert n % 2 == 0
-    d = n // 2
-
-    o_dtype = dtype or x.dtype
-    o = torch.empty((b, d), dtype=o_dtype, device=x.device)
-
-    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
-        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
+    """SwiGLU activation backed by a Triton kernel with optional scalar scaling.
+
+    Args:
+        x: 2D CUDA tensor [B, N]. N must be even. Last dimension must be contiguous.
+        scale: Optional scalar CUDA tensor (0-dim or 1-element).
+        dtype: Optional output dtype. Defaults to x.dtype.
+
+    Returns:
+        Tensor [B, D] where D = N // 2, dtype = dtype or x.dtype.
+
+    Raises:
+        ValueError: on non-2D input, odd N, or non-contiguous last dim.
+        RuntimeError: on non-CUDA input or device mismatch for scale.
+        TypeError: if scale is provided but not scalar.
+    """
+    if x.dim() != 2:
+        raise ValueError(f"silu_and_mul: expected 2D input [B, N], got {tuple(x.shape)}")
+    if x.device.type != "cuda":
+        raise RuntimeError("silu_and_mul: expected CUDA tensor input.")
+    if x.stride(1) != 1:
+        raise ValueError("silu_and_mul: expected contiguous last dimension (x.stride(1) == 1).")
+    b, n = x.shape
+    if n % 2 != 0:
+        raise ValueError(f"silu_and_mul: expected even N, got N={n}")
+    d = n // 2
+    if scale is not None:
+        if scale.device != x.device:
+            raise RuntimeError("silu_and_mul: `scale` must be on the same device as `x`.")
+        if scale.numel() != 1:
+            raise TypeError("silu_and_mul: `scale` must be a scalar (0-dim or 1-element) tensor.")
+
+    o_dtype = dtype or x.dtype
+    o = torch.empty((b, d), dtype=o_dtype, device=x.device)
+
+    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:
+        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
 
     silu_and_mul_kernel[grid](
         o_ptr=o,
         o_stride=o.stride(0),
         o_scale_ptr=scale,
         x_ptr=x,
         x_stride=x.stride(0),
         d=d,
         BLOCK_SIZE=1024,
         HAS_O_SCALE=scale is not None,
     )
 
     return o

To quickly confirm there are no other Python 3.9+ generic annotations that could break 3.8:

#!/bin/bash
# Find Python 3.9+ built-in generic annotations that would break 3.8
rg -nP --type=py '\b(?:list|dict|set|tuple)\s*\[' -g '!**/third_party/**' -C2

997-999: Python 3.8 compatibility: change tuple[int, int] to typing.Tuple[int, int].

This file imports Tuple on Line 2 but still uses the 3.9+ form here. This will break 3.8 and may explain recurring pipeline failures. Apply:

-    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
+    def grid(meta: Mapping[str, int]) -> Tuple[int, int]:

1014-1026: Mirror validation in the fake path; avoid bare asserts and provide a clear error.

Fake kernels participate in shape/dtype propagation during torch.compile, and early, cheap validation helps catch misuse without launching the real kernel.

Apply:

 @silu_and_mul.register_fake
 def _(
     x: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    b, n = x.shape
-
-    assert n % 2 == 0
-    d = n // 2
+    if x.dim() != 2:
+        raise ValueError(f"silu_and_mul(fake): expected 2D input [B, N], got {tuple(x.shape)}")
+    b, n = x.shape
+    if n % 2 != 0:
+        raise ValueError(f"silu_and_mul(fake): expected even N, got N={n}")
+    d = n // 2
 
     o_dtype = dtype or x.dtype
     return x.new_empty((b, d), dtype=o_dtype)
🧹 Nitpick comments (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)

15-15: Reduce cross-module coupling: consider relocating the kernel into a dedicated kernels submodule.

Importing the kernel from ..modules.swiglu creates a subtle layering dependency between “custom_ops” and “modules”. A kernels subpackage (e.g., tensorrt_llm/_torch/kernels/swiglu.py) would make ownership clearer and avoid future import tangles.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 43ae97e and f3dbc15.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (3 hunks)
  • tensorrt_llm/_torch/modules/swiglu.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/modules/swiglu.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
tensorrt_llm/_torch/modules/swiglu.py (2)
  • swiglu (51-60)
  • silu_and_mul_kernel (27-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

5-5: Importing Triton at module scope looks fine for this file’s usage.

Given the op always dispatches a Triton kernel, importing at module scope is acceptable. The type: ignore is also appropriate for environments without stubs.


985-1026: Add minimal unit tests and manually verify torch.compile graph stability

The custom op implementation and its register_fake fallback still lack coverage for shape/device preconditions and compile‐time graph correctness. Since we can’t execute the CI script here (no Torch), please verify locally that torch.compile around this op does not introduce unintended aten.copy_ (or other D2D copies). In addition, add lightweight unit tests to guard against regressions:

• Precondition checks (should raise errors):

  • Odd‐length last dimension (N % 2 ≠ 0)
  • Non-2D inputs
  • Input on CPU when CUDA is required
  • Non-contiguous last-dim tensors
  • Non-scalar scale argument

• Correctness test:

  • Compare torch.ops.trtllm.silu_and_mul(x, scale, dtype) against a reference implementation (torch.nn.functional.silu(x[:, :D]) * x[:, D:]) on a small CUDA tensor for various dtypes (a rough sketch follows after this list).

• Compile‐path stability:

  • Wrap the op in an nn.Module, run torch.compile(fullgraph=True), and inspect the FX IR (e.g., via print(torch._dynamo.export(...).graph) or similar) to assert no aten.copy_ nodes appear in the custom‐op subgraph.

Let me know if you need a draft of these tests or guidance on the graph‐inspection commands.
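As a starting point, a rough pytest sketch (illustrative only: it assumes torch.ops.trtllm.silu_and_mul is registered, takes SiLU of the first half of the last dimension as the reference, uses loose tolerances, and the error-type check assumes the stricter validation proposed above):

import pytest
import torch
import torch.nn.functional as F

requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")


@requires_cuda
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_silu_and_mul_matches_reference(dtype):
    x = torch.randn(8, 256, device="cuda", dtype=dtype)
    d = x.shape[1] // 2
    ref = F.silu(x[:, :d].float()) * x[:, d:].float()
    out = torch.ops.trtllm.silu_and_mul(x)
    torch.testing.assert_close(out.float(), ref, rtol=2e-2, atol=2e-2)


@requires_cuda
def test_silu_and_mul_rejects_odd_last_dim():
    x = torch.randn(4, 7, device="cuda", dtype=torch.float16)
    with pytest.raises((ValueError, RuntimeError, AssertionError)):
        torch.ops.trtllm.silu_and_mul(x)


@requires_cuda
def test_silu_and_mul_under_torch_compile():
    # Smoke test: the custom op should survive torch.compile(fullgraph=True) unchanged.
    fn = torch.compile(lambda t: torch.ops.trtllm.silu_and_mul(t), fullgraph=True)
    x = torch.randn(8, 256, device="cuda", dtype=torch.float16)
    d = x.shape[1] // 2
    ref = F.silu(x[:, :d].float()) * x[:, d:].float()
    torch.testing.assert_close(fn(x).float(), ref, rtol=2e-2, atol=2e-2)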

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16411 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12335 completed with status: 'FAILURE'

@hyukn
Copy link
Collaborator Author

hyukn commented Aug 26, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16526 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16526 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12410 completed with status: 'SUCCESS'

Copy link
Collaborator

@HuiGao-NV HuiGao-NV left a comment


LGTM

@hyukn hyukn merged commit bed5bc9 into NVIDIA:main Aug 27, 2025
5 checks passed