Context Parallel w/ Ring & Ulysses & Unified Attention #11941
Conversation
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Great work! Just a few questions and nits.
@contextlib.contextmanager
def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
Hmm wondering if we need this additional context manager. I think the name implies that it's parallelizing the components, when it's really a validation step.
Actually, this function is what sets the dispatcher to perform context-parallel templated attention instead of following the normal attention path in the non-CP case:
_AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config
Without this, we would have to perform the assignment in the pre-forward hook of the model. But that has compatibility issues with torch dynamo (tracing fails with a setattr/getattr-related error). It was our previous approach in this commit, but it had many compatibility issues for distributed training. Instead, explicitly doing this outside the forward is ideal for setting up all the required information about parallelism.
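For intuition, here is a minimal sketch of the pattern being described: a context manager that installs the model's parallel config on the attention dispatcher for the duration of the call and then restores the previous state. This is a simplified illustration, not the actual implementation; the stand-in registry class and the save/restore details are assumptions.

import contextlib

class _AttentionBackendRegistry:
    # Stand-in for diffusers' internal registry, only to make the sketch self-contained.
    _parallel_config = None

@contextlib.contextmanager
def enable_parallelism(model_or_pipeline):
    # Point the dispatcher at the model's parallel config so the templated
    # context-parallel attention path is taken instead of the normal one.
    previous = _AttentionBackendRegistry._parallel_config
    _AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config
    try:
        yield model_or_pipeline
    finally:
        # Restore the old state so non-CP calls outside the context are unaffected.
        _AttentionBackendRegistry._parallel_config = previous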
Hmm but setting parallel_config in this way leads to this issue, no?
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Great work as always @a-r-r-o-w. Good to merge once tests are passing.
Thanks @DN6! I'm AFK from my personal laptop so can't make changes for another 2-3 days. Sorry about the delay!
Co-authored-by: Aryan <aryan@huggingface.co>
@DN6 I'm unable to push any changes to this branch since it's on the official repo instead of my personal fork. I think if you added an entry to the TOC, we should be good to merge.
def _wrapped_flash_attn_3_original(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
I will follow suit and make it work with FA3 Hub as well.
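As background on the decorator in the snippet above: torch.library.custom_op (PyTorch >= 2.4) registers a function as a custom operator so that wrappers around external kernels like FA3 stay traceable by torch.compile, and register_fake supplies the shape-only meta implementation. Below is a minimal sketch of that pattern with illustrative names and a placeholder body instead of the real FA3 call.

import torch

@torch.library.custom_op("mylib::wrapped_attn", mutates_args=(), device_types="cuda")
def wrapped_attn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Placeholder body: the real wrapper would invoke the flash-attn-3 kernel here.
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

@torch.library.register_fake("mylib::wrapped_attn")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Fake (meta) kernel: only shape/dtype propagation, used during compile tracing.
    return torch.empty_like(query)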
Just pushed an entry to the TOC. Cc: @stevhliu. For now, I have added it to the "Inference optimization" section, but feel free to change that in your documentation PR. @a-r-r-o-w sorry for the delay here. We should be able to merge this now. Thanks for setting the foundations here :)
#12206)
* Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op
  - Add hasattr() check for torch.library.custom_op and register_fake
  - These functions were added in PyTorch 2.4, causing import failures in 2.3.1
  - Both decorators and functions are now properly guarded with version checks
  - Maintains backward compatibility while preserving functionality
  Fixes #12195
* Use dummy decorators approach for PyTorch version compatibility
  - Replace hasattr check with version string comparison
  - Add no-op decorator functions for PyTorch < 2.4.0
  - Follows pattern from #11941 as suggested by reviewer
  - Maintains cleaner code structure without indentation changes
* Update src/diffusers/models/attention_dispatch.py: update all the decorator usages (Co-authored-by: Aryan <contact.aryanvs@gmail.com>)
* Update src/diffusers/models/attention_dispatch.py (Co-authored-by: Aryan <contact.aryanvs@gmail.com>)
* Update src/diffusers/models/attention_dispatch.py (Co-authored-by: Aryan <contact.aryanvs@gmail.com>)
* Update src/diffusers/models/attention_dispatch.py (Co-authored-by: Aryan <contact.aryanvs@gmail.com>)
* Move version check to top of file and use private naming as requested
* Apply style fixes
---------
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
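For context, a rough sketch of the no-op decorator pattern that commit describes: version-gate the torch.library decorators once at module level so they can be applied unconditionally on older PyTorch. The alias names and exact signatures here are illustrative assumptions, not the actual diffusers code.

import torch
from packaging import version

_TORCH_HAS_CUSTOM_OP = version.parse(torch.__version__).release >= (2, 4)

if _TORCH_HAS_CUSTOM_OP:
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:
    # No-op stand-ins so the decorators can still be applied on PyTorch < 2.4
    # without changing the indentation of the decorated functions.
    def _custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        return wrap if fn is None else fn

    def _register_fake(op, fn=None, /, *, lib=None):
        def wrap(func):
            return func
        return wrap if fn is None else fn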
Adds support for Ring, Ulysses, and Unified attention natively. For a minimal PoC, I've limited changes to Flux.
Supported attention backends with CP: cuDNN, FA2, Sage.
Requires #11916 to be merged first.
Minimal example
Note: the examples here are not up-to-date! Please refer to the official examples once the docs are updated.
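In the meantime, here is a rough, hypothetical sketch of the overall shape of a context-parallel Flux run. Only enable_parallelism and the supported backends come from this PR; the import path for the context manager and the step that attaches the ring/ulysses config are assumptions and may not match the final API.

import torch
import torch.distributed as dist
from diffusers import FluxPipeline
from diffusers import enable_parallelism  # assumed import path for the context manager from this PR

# Intended to be launched with torchrun --nproc-per-node=<world_size>; every rank holds
# a full copy of the weights and the sequence is sharded across ranks by the CP hooks.
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Hypothetical step: attach the ring/ulysses parallel config to the transformer
# (the exact helper name and arguments are assumptions; see the official examples).
# pipe.transformer.set_context_parallel_config(ring_degree=2)

with enable_parallelism(pipe):
    image = pipe("a photo of a cat", num_inference_steps=28).images[0]

if rank == 0:
    image.save("output.png")
dist.destroy_process_group()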
Wan
Qwen
LTXVideo
Benchmarks
Flux code
TODO: link to blog post
Explanation
Each model should define a _cp_plan attribute that contains information on how to shard/gather tensors at different stages of the forward. Let's try to understand with an example using QwenImage (a sketch of such a plan is shown after the list below). The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be split/gathered according to this at the respective module level. Here, the following happens:
""
: we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before the actual forward logic of theQwenImageTransformer2DModel
is run, we will split the inputs)"pos_embed"
: we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs), we can individually specify how they should be split"proj_out"
: before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear layer forward has run).ContextParallelInput: specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
ContextParallelOutput: specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
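For concreteness, here is a hypothetical sketch of what such a plan could look like for the QwenImage transformer. ContextParallelInput and ContextParallelOutput are the helper classes named above, but the keyword arguments (split_dim, split_output, gather_dim) and the exact entries are assumptions for illustration, not the actual implementation.

# Hypothetical _cp_plan for QwenImageTransformer2DModel; argument names are assumptions.
_cp_plan = {
    # "": split the model's sequence-like inputs along the sequence dim in the pre-forward hook
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, split_output=False),
        "encoder_hidden_states": ContextParallelInput(split_dim=1, split_output=False),
    },
    # "pos_embed": split each of the two RoPE outputs (image & text freqs) individually,
    # in the post-forward hook, i.e. after the layer has produced them
    "pos_embed": {
        0: ContextParallelInput(split_dim=0, split_output=True),
        1: ContextParallelInput(split_dim=0, split_output=True),
    },
    # "proj_out": gather the full sequence on every rank in the post-forward hook
    "proj_out": ContextParallelOutput(gather_dim=1),
}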