[None][fix] Complete the last missing allreduce op in Llama3/4. #6850
Conversation
📝 Walkthrough

Disable POST_MLP_FUSION and POST_MOE_FUSION for the final decoder layer; change post-fusion gating to consult the next layer's layernorm and next attention/quantization for scale and fusion_op selection; adjust all-reduce/moe_allreduce paths and unpacking, and preserve PRE_MLP_FUSION intent.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Input as hidden_states
    participant Layer as LlamaDecoderLayer
    participant NextLN as next_layer_layernorm
    participant NextAttn as next_attn
    participant Quant as Quant (nvfp4/FP8)
    participant Fusion as AllReduce/MOE_AllReduce
    Input->>Layer: forward(hidden_states)
    Layer->>Layer: determine fusion path (PRE/POST/POST_MOE)
    alt POST path and no NextLN (final layer)
        Layer->>Fusion: all_reduce(fusion_op=None)
    else POST path with NextLN
        Layer->>NextAttn: check presence
        alt NextAttn present and Quant active
            NextAttn->>Layer: provide scale (qkv_proj.input_scale)
            Layer->>Fusion: all_reduce(fusion_op=..., scale=provided)
        else
            Layer->>Fusion: all_reduce(fusion_op=RESIDUAL_RMS_NORM, scale=None)
        end
    end
    Fusion-->>Layer: unpack (Fp4QuantizedTensor or (hidden, residual))
    Layer-->>Input: output
```
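To make the gating in the diagram concrete, here is a minimal, self-contained Python sketch of the decision it describes. The enum members, the `choose_post_fusion` helper, and the `quant_is_nvfp4` flag are illustrative stand-ins under stated assumptions, not the actual TensorRT-LLM API.

```python
# Illustrative sketch only: simplified stand-ins for the fusion gating logic,
# not the real tensorrt_llm enums or layer objects.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class FusionOp(Enum):
    NONE = auto()                           # plain all-reduce, no fused epilogue
    RESIDUAL_RMS_NORM = auto()              # fused residual-add + RMSNorm
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = auto()  # fused norm + NVFP4 quantization


@dataclass
class FusionDecision:
    fusion_op: FusionOp
    scale: Optional[float]  # placeholder for the real next_attn.qkv_proj.input_scale


def choose_post_fusion(next_layer_layernorm, next_attn, quant_is_nvfp4: bool) -> FusionDecision:
    """Pick the fusion_op/scale pair for the post-MLP/post-MOE all-reduce."""
    if next_layer_layernorm is None:
        # Final decoder layer: run a pure all-reduce with fusion disabled.
        return FusionDecision(FusionOp.NONE, None)
    if next_attn is not None and quant_is_nvfp4:
        # A quantized next attention exists: fuse norm + quant and pass its input scale.
        return FusionDecision(FusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
                              getattr(next_attn, "input_scale", 1.0))
    # Next layer exists but no quantized attention: fused residual RMSNorm, no scale.
    return FusionDecision(FusionOp.RESIDUAL_RMS_NORM, None)


if __name__ == "__main__":
    print(choose_post_fusion(None, None, False))         # last layer -> plain all-reduce
    print(choose_post_fusion(object(), None, False))     # next layer, unquantized
    print(choose_post_fusion(object(), object(), True))  # next quantized attention
```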
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
/bot run --disable-fail-fast --add-multi-gpu-test
Actionable comments posted: 0
🧹 Nitpick comments (2)
tensorrt_llm/_torch/models/modeling_llama.py (2)
710-714: Consider removing or guarding debug prints for production
The debug prints at lines 711-714 and 765-768 are useful for debugging but should be either removed or properly guarded for production use. Consider using a debug flag or the logging framework instead of direct prints.

```diff
-        if self.mapping.tp_rank == 0:
-            print(
-                f"{self.layer_idx} pre_mlp_fusion_op: {self.pre_mlp_fusion_op}"
-            )
+        if self.mapping.tp_rank == 0 and logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                f"{self.layer_idx} pre_mlp_fusion_op: {self.pre_mlp_fusion_op}"
+            )
```
765-768: Remove or properly guard debug print statements
Similar to the earlier comment, these debug prints should be properly handled for production code.

```diff
-        if self.mapping.tp_rank == 0:
-            print(
-                f"{self.layer_idx} post_mlp_fusion_op: {self.post_mlp_fusion_op}"
-            )
+        if self.mapping.tp_rank == 0 and logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                f"{self.layer_idx} post_mlp_fusion_op: {self.post_mlp_fusion_op}"
+            )
```
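As a side note, the replacement suggested in both diffs assumes a module-level `logger` is available. A minimal standard-library setup is sketched below; the real `modeling_llama.py` may already ship an equivalent TensorRT-LLM logger, so treat this as an assumption.

```python
# Minimal stdlib logger setup assumed by the diffs above (illustrative only;
# the module may already define its own logger).
import logging

logger = logging.getLogger(__name__)

# Example of the guarded debug pattern from the suggestion, with placeholder values.
layer_idx, post_mlp_fusion_op = 0, "RESIDUAL_RMS_NORM"
if logger.isEnabledFor(logging.DEBUG):
    logger.debug(f"{layer_idx} post_mlp_fusion_op: {post_mlp_fusion_op}")
```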
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/models/modeling_llama.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/models/modeling_llama.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/models/modeling_llama.py
🔇 Additional comments (6)
tensorrt_llm/_torch/models/modeling_llama.py (6)
443-443: Consistent fusion gating for MOE layers
Good consistency - the POST_MOE_FUSION also excludes the last layer, matching the MLP fusion behavior. This ensures proper allreduce operation for MOE-based models as well.

554-562: Improved fusion decision logic using next_layer_layernorm
The change from checking `next_attn` to `next_layer_layernorm` for determining fusion context is more robust. The logic correctly handles the last decoder layer by setting the fusion operation to `RESIDUAL_RMS_NORM` without scale when there's no next layer.

661-662: Consistent implementation in LlamaDecoderLayer
The POST_MLP_FUSION gating is consistently applied in the regular LlamaDecoderLayer class, maintaining consistency across decoder implementations.

757-763: Consistent fusion context detection
The change to use `next_layer_layernorm` instead of direct next-layer checks provides a cleaner and more consistent approach to determining fusion context. The fallback to `RESIDUAL_RMS_NORM` for the last layer is appropriate.

778-778: Improved condition for NVFP4 handling
The condition now correctly checks for `next_attn` existence before checking for NVFP4 quantization, preventing potential attribute errors when `next_attn` is None on the last layer.
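A tiny sketch of why the guard order matters; every name below is a placeholder, not the real module state:

```python
# On the last decoder layer next_attn is None, so its attributes must not be
# read before the existence check; `and` short-circuits the access.
next_attn = None        # what the last decoder layer sees
nvfp4_enabled = True    # placeholder for "NVFP4 quantization is active"

# Unsafe ordering would raise AttributeError here:
# scale = next_attn.qkv_proj.input_scale if nvfp4_enabled else None

# Safe ordering: check existence first, then the quantization mode.
scale = (next_attn.qkv_proj.input_scale
         if next_attn is not None and nvfp4_enabled else None)
print(scale)  # -> None on the last layer
```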
427-427: Please confirm performance and behavior impact of disabling post-MLP fusion on the final decoder layer
We searched the repo and found:

- No test files referencing `POST_MLP_FUSION` or `POST_MOE_FUSION` under `tests/`
- No performance benchmarks for Llama models under `benchmark/`
- Fusion/allreduce briefly mentioned across various `README.md` files, but without details on expected behavior

Without existing tests or benchmarks, please:
- Add or update a unit/integration test to cover the fusion toggle on the last layer (see the sketch after this list)
- Run your Llama performance benchmarks to detect any regressions
- Document the change in behavior (e.g., update README or design docs)
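Picking up the first bullet above, here is a hedged, pytest-style sketch of such a test. It is written against the hypothetical `choose_post_fusion`/`FusionOp` stand-ins from the earlier sketch (assumed importable from a `fusion_gating` module), not against the real `modeling_llama` internals.

```python
# Pytest-style sketch only: `fusion_gating` is a hypothetical module holding
# the stand-ins from the earlier sketch, not an existing TensorRT-LLM target.
from fusion_gating import FusionOp, choose_post_fusion


def test_last_layer_falls_back_to_plain_allreduce():
    # No next_layer_layernorm means final decoder layer: fusion must be off.
    decision = choose_post_fusion(next_layer_layernorm=None, next_attn=None,
                                  quant_is_nvfp4=False)
    assert decision.fusion_op is FusionOp.NONE
    assert decision.scale is None


def test_intermediate_layer_keeps_fused_residual_rmsnorm():
    # A next layer exists but its attention is not quantized.
    decision = choose_post_fusion(next_layer_layernorm=object(), next_attn=None,
                                  quant_is_nvfp4=False)
    assert decision.fusion_op is FusionOp.RESIDUAL_RMS_NORM
    assert decision.scale is None
```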
PR_Github #15054 [ run ] triggered by Bot
f9c3b18 to 8103f90
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #15058 [ run ] triggered by Bot
PR_Github #15054 [ run ] completed with state
Actionable comments posted: 0
🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_llama.py (1)
710-710: Remove unnecessary blank line.
There's an extra blank line that can be removed for consistency with the surrounding code style.
```diff
-
         all_reduce_output = self.all_reduce(
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/models/modeling_llama.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/models/modeling_llama.py (7)
427-427: LGTM! Correctly disables POST_MLP_FUSION for the last decoder layer.
This change properly prevents post-MLP fusion on the final decoder layer of Llama4, which ensures the missing allreduce operation is properly executed.

443-443: LGTM! Correctly disables POST_MOE_FUSION for the last decoder layer.
This change properly prevents post-MOE fusion on the final decoder layer of Llama4 MOE models, ensuring the missing allreduce operation is properly executed.

552-562: Logical improvement: Next-layer context detection via layernorm.
The change from `next_attn`-based gating to a `next_layer_layernorm` presence check is a more robust way to detect the existence of a next layer. The logic correctly handles the last-layer case by setting the fusion op to `RESIDUAL_RMS_NORM` when no scale is needed.

591-596: Proper unpacking of allreduce output based on quantization mode.
The conditional unpacking correctly handles different quantization scenarios - FP4 quantization with a next attention layer vs. other cases. This ensures the correct data types are propagated through the model.
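For reference, a simplified sketch of the two unpacking shapes being discussed; `Fp4QuantizedTensor` mirrors the name used in the review, but the field names and the three-element fused-FP4 output are assumptions for illustration, not the exact TensorRT-LLM return values.

```python
# Simplified sketch of the conditional unpacking; shapes and field names are
# assumptions for illustration only.
from dataclasses import dataclass
from typing import Any


@dataclass
class Fp4QuantizedTensor:
    fp4_tensor: Any
    scaling_factor: Any


def unpack_allreduce_output(output, fused_fp4_active: bool):
    """Normalize both all-reduce return shapes to (hidden_states, residual)."""
    if fused_fp4_active:
        # Fused norm + NVFP4 quant path: quantized activation, its scale, residual.
        act_fp4, act_sf, residual = output
        hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
    else:
        # Plain fused-norm (or unfused) path: just hidden_states and residual.
        hidden_states, residual = output
    return hidden_states, residual


print(unpack_allreduce_output(("h", "r"), fused_fp4_active=False))
print(unpack_allreduce_output(("q", "sf", "r"), fused_fp4_active=True))
```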
661-662: LGTM! Correctly disables POST_MLP_FUSION for the last decoder layer in LlamaModel.
This change ensures consistency between Llama3 and Llama4 models by preventing post-MLP fusion on the final decoder layer, fixing the missing allreduce operation issue.

753-760: Logical improvement: Next-layer context detection for standard Llama models.
The change properly handles the last-layer case by checking for `next_layer_layernorm` presence and adjusting the fusion operation and scale accordingly. This ensures consistent behavior across both Llama3 and Llama4 architectures.

770-774: Proper conditional unpacking based on quantization state.
The code correctly unpacks the allreduce output based on whether the next attention layer exists and FP4 quantization is enabled, maintaining consistency with the Llama4 implementation.
PR_Github #15058 [ run ] completed with state
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #15088 [ run ] triggered by Bot
PR_Github #15088 [ run ] completed with state
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #15195 [ run ] triggered by Bot
PR_Github #15195 [ run ] completed with state
8103f90 to 8f1c511
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
8f1c511 to d32194e
/bot run --disable-fail-fast --add-multi-gpu-test
Actionable comments posted: 0
🧹 Nitpick comments (4)
tensorrt_llm/_torch/models/modeling_llama.py (4)
573-573: Consider making MIN_LATENCY_MODE configurable
The comment indicates that MIN_LATENCY_MODE is hardcoded to False. Consider making this configurable through environment variables or configuration parameters to allow flexibility in deployment scenarios.
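One possible way to do that is an environment-variable override, sketched below; `TRTLLM_MIN_LATENCY_MODE` is a hypothetical name, not an existing TensorRT-LLM setting.

```python
# Hypothetical env-var override for the hardcoded flag; the variable name is
# an assumption for illustration only.
import os


def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret '1'/'true'/'yes' (case-insensitive) as True."""
    return os.environ.get(name, str(int(default))).strip().lower() in ("1", "true", "yes")


MIN_LATENCY_MODE = _env_flag("TRTLLM_MIN_LATENCY_MODE", default=False)
```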
719-719: Remove extra blank line
There's an unnecessary blank line that should be removed for consistency.
Apply this diff:
```diff
-
         all_reduce_output = self.all_reduce(
```
762-762: Remove trailing whitespace
There's trailing whitespace on this line that should be removed.
Apply this diff:
```diff
-                hidden_states, residual) 
+                hidden_states, residual)
```
85-86: Fix line continuation formatting
The line continuation is broken across two lines in an awkward way. Consider reformatting for better readability.
Apply this diff:
```diff
-            rope_fusion=not self.
-            use_qk_norm,  # Llama4 uses qk_norm after RoPE, so it is not possible to fuse RoPE into the attention OP with qk_norm.
+            rope_fusion=not self.use_qk_norm,  # Llama4 uses qk_norm after RoPE, so it is not possible to fuse RoPE into the attention OP with qk_norm.
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/models/modeling_llama.py (3 hunks)
🔇 Additional comments (5)
tensorrt_llm/_torch/models/modeling_llama.py (5)
552-562: Ensure correct handling of missing allreduce op on the last decoder layer for Llama4
The implementation correctly addresses the missing allreduce operation on the last decoder layer. When `next_layer_layernorm` is None (indicating the last layer), it performs a pure allreduce with `fusion_op=None` to disable fusion. This is the appropriate fix for the issue.

564-571: Good implementation of scale adjustment for the last layer
The logic correctly handles the case where `next_layer_layernorm` exists but could be the last decoder layer. It appropriately adjusts the scale based on the presence of `next_attn` and quantization settings, ensuring proper handling of edge cases.

574-606: Correct unpacking logic for different quantization scenarios
The unpacking logic correctly handles both the min-latency MOE path and the standard allreduce path, with appropriate handling for NVFP4 quantization. The implementation properly differentiates between scenarios based on the presence of `next_attn` and quantization modes.

763-771: Good implementation of POST_MLP_FUSION for LlamaDecoderLayer
The implementation correctly mirrors the Llama4DecoderLayer logic for handling the missing allreduce op on the last decoder layer. When `next_layer_layernorm` is None, it properly performs a pure allreduce with `fusion_op=None`.

773-795: Consistent scale adjustment logic across both decoder layer implementations
The scale adjustment and unpacking logic for LlamaDecoderLayer correctly mirrors the Llama4DecoderLayer implementation, ensuring consistent behavior across both model types when handling the last decoder layer.
PR_Github #15266 [ run ] triggered by Bot
PR_Github #15266 [ run ] completed with state
…IA#6850) The allreduce op of the last decoder layer is missing in some circumstances for the models Llama3 and Llama4. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
…IA#6850) The allreduce op of the last decoder layer is missing in some circumstances for the models Llama3 and Llama4. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
The allreduce op of the last decoder layer is missing in some circumstances for the models Llama3 and Llama4.
Summary by CodeRabbit
Bug Fixes
Chores