[None][feat] Apply AutoTuner to fp8_block_scale_deep_gemm to trigger JIT ahead of time. #7113
Conversation
📝 Walkthrough
Adds a tunable Torch custom op trtllm::fp8_swap_ab_gemm (runner, tuning buckets, and fake registration), integrates it into FP8BlockScalesLinearMethod in place of the previous per-token quant + deep_gemm path, and updates a unit test to use the new op within an autotune context. Duplicate op definitions are present.
Changes
Sequence Diagram(s)
```mermaid
sequenceDiagram
    autonumber
    participant Caller as Linear/Test
    participant AT as AutoTuner
    participant OP as trtllm::fp8_swap_ab_gemm
    participant Q as fp8_utils
    participant G as deep_gemm.fp8_gemm_nt
    Caller->>OP: call fp8_swap_ab_gemm(input, weight, weight_scale, opts)
    OP->>AT: gen tuning buckets / request tactic
    AT-->>OP: selected tactic (e.g., 0)
    rect rgba(220,235,255,0.25)
        OP->>Q: per-token quantize input -> A_fp8, A_sf
        Q-->>OP: A_fp8, A_sf
    end
    rect rgba(220,255,220,0.25)
        OP->>G: deep_gemm.fp8_gemm_nt(A_fp8, weight, weight_scale, ...)
        G-->>OP: output (respected output_dtype)
    end
    OP-->>Caller: return output
    alt Graph/tracing path
        note right of OP: register_fake returns placeholder tensor with synthetic shape
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
/bot run --disable-fail-fast
Actionable comments posted: 3
🧹 Nitpick comments (5)
tensorrt_llm/_torch/modules/linear.py (1)
593-598: Switch to fused fp8_swap_ab_gemm looks good; consider surfacing output dtype and verify accuracy parity
- The fused call will help pre-warm kernels via AutoTuner. Two nits:
- Consistency: elsewhere we honor module.dtype or input.dtype for output. Here we rely on the op’s default (bfloat16). If module.dtype is set (e.g., fp16 scenarios), consider passing it through explicitly.
- Semantics: This replaces 1x128 activation quantization with per-token quantization. Please confirm expected accuracy parity on SM100 across the swept shapes.
Suggested change to preserve dtype behavior:
```diff
- output = torch.ops.trtllm.fp8_swap_ab_gemm(
-     input,
-     module.weight,
-     module.weight_scale,
-     disable_ue8m0_cast=True,
- )
+ output = torch.ops.trtllm.fp8_swap_ab_gemm(
+     input,
+     module.weight,
+     module.weight_scale,
+     output_dtype=(module.dtype or input.dtype),
+     disable_ue8m0_cast=True,
+ )
```
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2)
25-26: Good: autotune context added to pre-trigger JIT
Wrapping the call in autotune() enables ahead-of-time tactic selection and compilation; this aligns with the PR goal of reducing inference-time JIT. Consider setting TRTLLM_AUTOTUNE_LOG=1 in CI for visibility of tuned buckets.
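A tiny sketch of the suggested CI switch (the environment variable name comes from the comment above and should be treated as an assumption):

```python
# conftest.py (hypothetical): surface AutoTuner bucket selection in CI logs.
import os

os.environ.setdefault("TRTLLM_AUTOTUNE_LOG", "1")
```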
54-60: Test now exercises the new fused op; add a dtype assertion and broaden shapes if you want stronger coverage
The check is correct. Two small suggestions:
- Assert output.dtype to guard against accidental default changes.
- Optionally parametrize disable_ue8m0_cast to exercise both branches.
Example adjustments:
```diff
 with autotune():
     output = torch.ops.trtllm.fp8_swap_ab_gemm(
         a,
         act_b_fp8,
         act_b_sf,
     )
+assert output.dtype == torch.bfloat16
```
Optionally:
```diff
- with autotune():
-     output = torch.ops.trtllm.fp8_swap_ab_gemm(a, act_b_fp8, act_b_sf)
+ with autotune():
+     output = torch.ops.trtllm.fp8_swap_ab_gemm(a, act_b_fp8, act_b_sf)
+ with autotune():
+     output2 = torch.ops.trtllm.fp8_swap_ab_gemm(a, act_b_fp8, act_b_sf, disable_ue8m0_cast=True)
+ torch.testing.assert_close(output, output2, atol=2e-2, rtol=2e-2)
```
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
920-926: Class naming is inconsistent with the rest; add runner caching and round rule to tuning_config
- Naming: other runners use PascalCase (FP8RowwiseGemmRunner, FP4GemmRunner, …). Please rename to FP8SwapABGemmRunner for consistency.
- Caching: other runners maintain a runner_dict; add the same to avoid re-instantiation overhead.
- Tuning config: supply the round rule introduced above.
Proposed update:
```diff
-class fp8SwapABGemmRunner(TunableRunner):
-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
-        tune_max_num_tokens=4096,
-    )
+class FP8SwapABGemmRunner(TunableRunner):
+    runner_dict = dict()
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, fp8_swap_ab_gen_tuning_buckets, _fp8_swap_ab_round_rule), ),
+        tune_max_num_tokens=4096,
+    )

-    def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
-        self.output_dtype = output_dtype
-        self.disable_ue8m0_cast = disable_ue8m0_cast
+    def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
+        self.output_dtype = output_dtype
+        self.disable_ue8m0_cast = disable_ue8m0_cast
+        key = (output_dtype, disable_ue8m0_cast)
+        if key not in FP8SwapABGemmRunner.runner_dict:
+            FP8SwapABGemmRunner.runner_dict[key] = self
```
Note: subsequent references to fp8SwapABGemmRunner need to be updated to FP8SwapABGemmRunner.
962-987: Update to renamed runner and propagate tune_max_num_tokens cleanly
After renaming the class and adding the round rule, update the references:
```diff
-    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
+    fp8_swap_ab_gemm_runner = FP8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
     )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+    FP8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
```
Also consider guarding tune_max_num_tokens to be ≥8 and multiple of 8 for SM100-friendly bucketing.
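A minimal sketch of such a guard (illustrative only; the helper name is hypothetical, the variable name is taken from the diff above):

```python
def _clamp_tune_max_num_tokens(tune_max_num_tokens: int) -> int:
    # Keep at least 8 tokens and round down to a multiple of 8,
    # as the comment above suggests for SM100-friendly bucketing.
    return max(8, (tune_max_num_tokens // 8) * 8)
```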
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
- tensorrt_llm/_torch/modules/linear.py (1 hunks)
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py
🧠 Learnings (1)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (1)
22-22: Import update is fine
Using per_block_cast_to_fp8_e8m0 for the weight path aligns with DeepGEMM’s expectations. No concerns here.
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
7-8: New imports are appropriate
Importing fp8_utils and deep_gemm is expected for the fused FP8 path. No issues spotted.
989-999: Fake kernel shape/dtype is correct
The fake registration returns [M, N] with the requested dtype, matching execution semantics. Looks good.
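For illustration, a sketch of what such a fake registration generally looks like (the decorator form, parameter names, and defaults are assumptions for this sketch, not the file’s actual code):

```python
import torch

# Sketch only: assumes the real op trtllm::fp8_swap_ab_gemm is already defined.
@torch.library.register_fake("trtllm::fp8_swap_ab_gemm")
def _fp8_swap_ab_gemm_fake(input, weight, weight_scale,
                           output_dtype=torch.bfloat16,
                           disable_ue8m0_cast=False,
                           tune_max_num_tokens=4096):
    # The fake kernel only describes metadata for tracing/compile:
    # an [M, N] tensor of the requested dtype, with M = input rows
    # and N = weight rows (NT layout).
    return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)
```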
PR_Github #16019 [ run ] triggered by Bot
PR_Github #16019 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #16122 [ run ] triggered by Bot
…JIT ahead of time.
Because deep_gemm.fp8_gemm_nt will trigger many JIT processes during the inference phase, we need to sweep these shapes ahead of time. Apply the AutoTuner framework to achieve this and retain the potential capability to tune the swap_ab flag.
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
c7d8c80 to 188e2fa (Compare)
/bot run --disable-fail-fast
Actionable comments posted: 0
♻️ Duplicate comments (4)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4)
895-900: Include upper-bound bucket and add deterministic rounding rule (pre-JIT coverage + determinism).
As previously noted, buckets exclude x itself (e.g., x=4096 → max 3968) and there’s no round rule for arbitrary M. This can under-tune at caps and cause non-deterministic selection near small M. Add an upper-bound-inclusive step and wire a floor-to-bucket round rule. Also cache results and add a precise return type.
Apply:
```diff
-def fp8_swap_ab_gen_tuning_buckets(x: int):
-    buckets = tuple(range(8, 128, 8))
-    if x >= 128:
-        buckets += tuple(range(128, x, 128))
-    return buckets
+@lru_cache(maxsize=None)
+def fp8_swap_ab_gen_tuning_buckets(x: int) -> Tuple[int, ...]:
+    # Base buckets: 8, 16, ..., 120
+    buckets = tuple(range(8, 128, 8))
+    if x >= 128:
+        # Include x when exactly divisible by 128 to avoid under-tuning at the cap
+        buckets += tuple(range(128, x + 1, 128))
+    return buckets
+
+def _fp8_swap_ab_round_rule(x: int) -> int:
+    # Floor to bucket:
+    #   - x < 128: nearest lower multiple of 8 (min 8)
+    #   - x >= 128: nearest lower multiple of 128 (min 128)
+    if x < 128:
+        return max(8, (x // 8) * 8)
+    return max(128, (x // 128) * 128)
```
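To make the revised behavior concrete, a quick sanity check one could run against the helpers as defined in the diff above (illustrative only, and assumes those exact definitions are in scope):

```python
buckets = fp8_swap_ab_gen_tuning_buckets(4096)
assert buckets[-1] == 4096                   # the cap itself is now a bucket
assert _fp8_swap_ab_round_rule(5) == 8       # clamped up to the smallest bucket
assert _fp8_swap_ab_round_rule(100) == 96    # floored to a multiple of 8
assert _fp8_swap_ab_round_rule(300) == 256   # floored to a multiple of 128
```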
903-907: Wire the round rule into the DynamicTensorSpec.
Without a round rule, AutoTuner may map small M unevenly. Use the rule introduced above.
```diff
-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
-        tune_max_num_tokens=4096,
-    )
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, fp8_swap_ab_gen_tuning_buckets, _fp8_swap_ab_round_rule), ),
+        tune_max_num_tokens=4096,
+    )
```
913-920: Make swap_ab truly tunable (align with PR intent).
The PR claims swap_ab remains tunable, but only [0] is returned, so it isn’t. Return both tactics.
```diff
-        # Encode swap_ab as False (0) and True (1). Currently only add one tactic here.
-        return [0]
+        # 0 => swap_ab=False, 1 => swap_ab=True
+        return [0, 1]
```
921-941: Propagate tactic to deep_gemm and fix typo.
Forward ignores tactic and never sets swap_ab; also a minor typo “detemrmine”. Derive swap_ab from tactic and pass into deep_gemm.
```diff
         output = torch.empty(
             (input.size(0), weight.size(0)),
             device=input.device,
             dtype=self.output_dtype,
         )
-        # TODO: add swap_ab=tactic == 0 to detemrmine the swap_ab value
-        # Treat the default tactic=-1 as swap_ab=False
+        # Treat the default tactic=-1 as swap_ab=False; tactic==1 => swap_ab=True
+        swap_ab = (tactic == 1)
         deep_gemm.fp8_gemm_nt(
             (a, a_sf),
             (weight, weight_scale),
             output,
+            swap_ab=swap_ab,
             disable_ue8m0_cast=self.disable_ue8m0_cast,
         )
```
To confirm the argument is supported across the repo, run:
```bash
#!/bin/bash
# Verify deep_gemm.fp8_gemm_nt supports 'swap_ab' and check all call sites.
set -euo pipefail

echo "== Search for fp8_gemm_nt definition or bindings =="
rg -nP -C3 '\bfp8_gemm_nt\b' --type cpp --type c --type h --type hpp --type py

echo "== Search for call sites of fp8_gemm_nt =="
rg -nP -C2 'fp8_gemm_nt\s*\('
```
🧹 Nitpick comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
909-912: Rename class to PascalCase for consistency with guidelines.
Python class names should use PascalCase. Rename and update references.
```diff
-class fp8SwapABGemmRunner(TunableRunner):
+class FP8SwapABGemmRunner(TunableRunner):
     tuning_config = TuningConfig(
         ...
     )
```
Also update references below (outside this hunk):
```diff
-    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
+    fp8_swap_ab_gemm_runner = FP8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
     )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+    FP8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
```
945-969: Minor: keep class name change consistent and set tune_max before choose_one.
Use the PascalCase class name consistently; also setting tune_max before choose_one is good, so keep it that way.
```diff
-    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
+    fp8_swap_ab_gemm_runner = FP8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
     )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+    FP8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2 hunks)
- tensorrt_llm/_torch/modules/linear.py (1 hunks)
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py
- tensorrt_llm/_torch/modules/linear.py
🔇 Additional comments (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
7-8: Imports look correct and align with downstream usage.
fp8_utils and deep_gemm are both used in the new FP8 path; no import leaks or shadowing.
971-981: Fake registration signature and shape inference LGTM.
Signature mirrors the real op; output [M, N] with N = weight.size(0) matches the fp8_gemm_nt (NT path) expectation.
PR_Github #16131 [ run ] triggered by Bot
PR_Github #16122 [ run ] completed with state
PR_Github #16131 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #16138 [ run ] triggered by Bot
PR_Github #16138 [ run ] completed with state
LGTM
Because deep_gemm.fp8_gemm_nt will trigger many JIT processes during the inference phase, we need to sweep these shapes ahead of time. Apply the AutoTuner framework to achieve this and retain the potential capability to tune the swap_ab flag.
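As an illustration of the ahead-of-time warm-up this enables, a minimal sketch (the helper name, import path, and shapes are assumptions for this sketch, not code from this PR):

```python
import torch
from tensorrt_llm._torch.autotuner import autotune  # import path assumed

def prewarm_fp8_swap_ab_gemm(weight: torch.Tensor, weight_scale: torch.Tensor,
                             hidden_size: int, max_num_tokens: int = 4096) -> None:
    """Run the op once under autotune() so deep_gemm kernels are JIT-compiled before serving."""
    a = torch.randn(max_num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
    # Inside autotune(), the AutoTuner sweeps its tuning buckets (and candidate
    # tactics such as swap_ab) and caches the compiled kernels for later reuse.
    with autotune():
        torch.ops.trtllm.fp8_swap_ab_gemm(a, weight, weight_scale)
```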
Summary by CodeRabbit
New Features
Performance
Refactor
Tests