Context Parallel w/ Ring & Ulysses & Unified Attention by a-r-r-o-w · Pull Request #11941 · huggingface/diffusers · GitHub
Merged
Changes from all commits
Commits
62 commits
d7b9e42
update
a-r-r-o-w Jul 14, 2025
7e97e43
update
a-r-r-o-w Jul 14, 2025
ecabd2a
add coauthor
a-r-r-o-w Jul 14, 2025
ff21b7f
improve test
a-r-r-o-w Jul 14, 2025
b8f7fe6
handle ip adapter params correctly
a-r-r-o-w Jul 14, 2025
17b678f
Merge branch 'main' into to-single-file/flux
a-r-r-o-w Jul 15, 2025
0cda91d
fix chroma qkv fusion test
a-r-r-o-w Jul 15, 2025
bc64f12
fix fastercache implementation
a-r-r-o-w Jul 15, 2025
a0b276d
fix more tests
a-r-r-o-w Jul 15, 2025
c141520
fight more tests
a-r-r-o-w Jul 15, 2025
4dcd672
add back set_attention_backend
a-r-r-o-w Jul 15, 2025
576da52
update
a-r-r-o-w Jul 15, 2025
e909b73
update
a-r-r-o-w Jul 15, 2025
1e7217f
make style
a-r-r-o-w Jul 15, 2025
4f52e34
make fix-copies
a-r-r-o-w Jul 15, 2025
d9c1683
make ip adapter processor compatible with attention dispatcher
a-r-r-o-w Jul 15, 2025
a73cb39
refactor chroma as well
a-r-r-o-w Jul 15, 2025
1e6b1c5
remove rmsnorm assert
a-r-r-o-w Jul 16, 2025
251bb61
minify and deprecate npu/xla processors
a-r-r-o-w Jul 16, 2025
84d2c84
Merge branch 'main' into to-single-file/flux
a-r-r-o-w Jul 16, 2025
51fed50
update
a-r-r-o-w Jul 16, 2025
9f37b87
Merge branch 'main' into to-single-file/flux
a-r-r-o-w Jul 16, 2025
7973626
refactor
a-r-r-o-w Jul 16, 2025
f859fdf
refactor; support flash attention 2 with cp
a-r-r-o-w Jul 16, 2025
e76fc94
fix
a-r-r-o-w Jul 16, 2025
171152f
support sage attention with cp
a-r-r-o-w Jul 16, 2025
62f164d
make torch compile compatible
a-r-r-o-w Jul 16, 2025
731b3bb
Merge branch 'to-single-file/flux' into attn-dispatcher-cp-and-training
a-r-r-o-w Jul 16, 2025
ff8ef45
Merge branch 'main' into attn-dispatcher-cp-and-training
a-r-r-o-w Jul 17, 2025
26a5a5c
update
a-r-r-o-w Jul 17, 2025
1ffc03e
Merge branch 'main' into attn-dispatcher-cp-and-training
a-r-r-o-w Jul 29, 2025
fa5d017
refactor
a-r-r-o-w Jul 30, 2025
215104f
update
a-r-r-o-w Aug 8, 2025
c777184
Merge branch 'main' into attn-dispatcher-cp-and-training
a-r-r-o-w Aug 14, 2025
256d5a9
refactor
a-r-r-o-w Aug 14, 2025
27e1d27
refactor
a-r-r-o-w Aug 14, 2025
cca5381
add ulysses backward
a-r-r-o-w Aug 14, 2025
768d0ea
try to make dreambooth script work; accelerator backward not playing …
a-r-r-o-w Aug 14, 2025
0018b62
Revert "try to make dreambooth script work; accelerator backward not …
a-r-r-o-w Aug 14, 2025
2065acc
Merge branch 'main' into attn-dispatcher-cp-and-training
a-r-r-o-w Aug 27, 2025
da78c5d
workaround compilation problems with triton when doing all-to-all
a-r-r-o-w Aug 27, 2025
f4c1b4e
support wan
a-r-r-o-w Aug 27, 2025
a820bfd
handle backward correctly
a-r-r-o-w Aug 27, 2025
ae5a707
support qwen
a-r-r-o-w Sep 1, 2025
6bf744c
Merge branch 'main' into attn-dispatcher-cp-and-training
a-r-r-o-w Sep 1, 2025
db2efa3
support ltx
a-r-r-o-w Sep 1, 2025
bb07c04
make fix-copies
a-r-r-o-w Sep 1, 2025
2aec312
Update src/diffusers/models/modeling_utils.py
a-r-r-o-w Sep 3, 2025
088d909
apply review suggestions
a-r-r-o-w Sep 4, 2025
f35483a
update docs
a-r-r-o-w Sep 4, 2025
c88fc99
add explanation
a-r-r-o-w Sep 4, 2025
e569785
make fix-copies
a-r-r-o-w Sep 4, 2025
b85c26c
add docstrings
a-r-r-o-w Sep 4, 2025
4b2fcc1
support passing parallel_config to from_pretrained
a-r-r-o-w Sep 4, 2025
bc9fc27
Merge branch 'main' into attn-dispatcher-cp-and-training
a-r-r-o-w Sep 4, 2025
498b191
apply review suggestions
a-r-r-o-w Sep 11, 2025
a5d5ef4
make style
a-r-r-o-w Sep 11, 2025
e0c7580
Merge branch 'main' into attn-dispatcher-cp-and-training
a-r-r-o-w Sep 11, 2025
2ce098b
update
a-r-r-o-w Sep 11, 2025
d53f7ef
Update docs/source/en/api/parallel.md
DN6 Sep 24, 2025
f4b374c
up
sayakpaul Sep 24, 2025
8884d97
Merge branch 'main' into attn-dispatcher-cp-and-training
sayakpaul Sep 24, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -70,6 +70,8 @@
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compiling and offloading quantized models
- local: api/parallel
title: Parallel inference
- title: Community optimizations
sections:
- local: optimization/pruna
24 changes: 24 additions & 0 deletions docs/source/en/api/parallel.md
@@ -0,0 +1,24 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Parallelism

Parallelism strategies speed up diffusion transformers by distributing computation across multiple devices, enabling faster inference and training.

## ParallelConfig

[[autodoc]] ParallelConfig

## ContextParallelConfig

[[autodoc]] ContextParallelConfig

[[autodoc]] hooks.apply_context_parallel
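An illustrative usage sketch, not part of this diff: a minimal way the new config could be wired up, assuming the `ring_degree` kwarg name and that `from_pretrained` accepts `parallel_config` (per the "support passing parallel_config to from_pretrained" commit).

# Minimal sketch; `ring_degree` and `parallel_config=` on `from_pretrained` are assumptions.
import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig, FluxTransformer2DModel

# Launch one process per GPU, e.g. `torchrun --nproc-per-node=2 example.py`.
dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank())

cp_config = ContextParallelConfig(ring_degree=dist.get_world_size())  # assumed kwarg
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,  # per the commit adding parallel_config support to from_pretrained
).to(device)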
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -202,6 +202,7 @@
"CogView4Transformer2DModel",
"ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
"ContextParallelConfig",
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
@@ -229,6 +230,7 @@
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
"ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
"QwenImageControlNetModel",
@@ -888,6 +890,7 @@
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
ContextParallelConfig,
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
@@ -915,6 +918,7 @@
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageControlNetModel,
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -16,6 +16,7 @@


if is_torch_available():
from .context_parallel import apply_context_parallel
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
297 changes: 297 additions & 0 deletions src/diffusers/hooks/context_parallel.py
@@ -0,0 +1,297 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..models._modeling_parallel import (
ContextParallelConfig,
ContextParallelInput,
ContextParallelModelPlan,
ContextParallelOutput,
)
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__) # pylint: disable=invalid-name

_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"


# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
cached_parameter_indices: Dict[str, int] = None
Review comment (Member): (nit): No need for this PR, but might make sense to introduce help for args like this.

_cls: Type = None

def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}

if identifier in kwargs:
return kwargs[identifier], True, None

if self.cached_parameter_indices is not None:
index = self.cached_parameter_indices.get(identifier, None)
if index is None:
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
return args[index], False, index

if self._cls is None:
raise ValueError("Model class is not set for metadata.")

parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

if identifier not in self.cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

index = self.cached_parameter_indices[identifier]

if index >= len(args):
raise ValueError(f"Expected {index} arguments but got {len(args)}.")

return args[index], False, index


def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ContextParallelConfig,
plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]

logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

for m in submodule:
if isinstance(cp_model_plan, dict):
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
if isinstance(cp_model_plan, ContextParallelOutput):
cp_model_plan = [cp_model_plan]
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry = HookRegistry.check_if_exists_or_initialize(m)
registry.register_hook(hook, hook_name)


def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]

for m in submodule:
registry = HookRegistry.check_if_exists_or_initialize(m)
if isinstance(cp_model_plan, dict):
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry.remove_hook(hook_name)


class ContextParallelSplitHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
self.module_forward_metadata = None

def initialize_hook(self, module):
cls = unwrap_module(module).__class__
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
return module

def pre_forward(self, module, *args, **kwargs):
args_list = list(args)

for name, cpm in self.metadata.items():
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
continue

# Maybe the parameter was passed as a keyword argument
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
name, args_list, kwargs
)

if input_val is None:
continue

# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
# the output instead of input for a particular layer by setting split_output=True
if isinstance(input_val, torch.Tensor):
input_val = self._prepare_cp_input(input_val, cpm)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
)
sharded_input_val = []
for i, x in enumerate(input_val):
if torch.is_tensor(x) and not cpm[i].split_output:
x = self._prepare_cp_input(x, cpm[i])
sharded_input_val.append(x)
input_val = sharded_input_val
else:
raise ValueError(f"Unsupported input type: {type(input_val)}")

if is_kwarg:
kwargs[name] = input_val
elif index is not None and index < len(args_list):
args_list[index] = input_val
else:
raise ValueError(
f"An unexpected error occurred while processing the input '{name}'. Please open an "
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
f"example along with the full stack trace."
)

return tuple(args_list), kwargs

def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

if not is_tensor and not is_tensor_list:
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

output = [output] if is_tensor else list(output)
for index, cpm in self.metadata.items():
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
continue
if index >= len(output):
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
current_output = output[index]
current_output = self._prepare_cp_input(current_output, cpm)
output[index] = current_output

return output[0] if is_tensor else tuple(output)

def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


class ContextParallelGatherHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config

def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)

if is_tensor:
output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

output = list(output)

if len(output) != len(self.metadata):
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

return output[0] if is_tensor else tuple(output)


class AllGatherFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, dim, group):
ctx.dim = dim
ctx.group = group
ctx.world_size = torch.distributed.get_world_size(group)
ctx.rank = torch.distributed.get_rank(group)
return funcol.all_gather_tensor(tensor, dim, group=group)

@staticmethod
def backward(ctx, grad_output):
grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
return grad_chunks[ctx.rank], None, None


class EquipartitionSharder:
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
# because the alternate case has not yet been tested/required for any model.
assert tensor.size()[dim] % mesh.size() == 0, (
"Tensor size along dimension to be sharded must be divisible by mesh size"
)

# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]

return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]

@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
return tensor


def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
Review comment (Member): Could we leverage

def get_submodule_by_name(root_module, module_path: str):

(no worries if not, just wanted to find ways to reduce LoC)

Reply (Contributor Author): I'm also supporting wildcards here, for example the Wan CP plan:

_cp_plan = {
    "rope": {
        0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
        1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
    },
    "blocks.0": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    },
    "blocks.*": {
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    },
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}

I think it's much cleaner to use a custom implementation specific to CP to allow for such cases.

if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")
return _find_submodule_by_name(model, name)


def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name == "":
return model
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
if first_atom == "*":
if not isinstance(model, torch.nn.ModuleList):
raise ValueError("Wildcard '*' can only be used with ModuleList")
submodules = []
for submodule in model:
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
if not isinstance(subsubmodules, list):
subsubmodules = [subsubmodules]
submodules.extend(subsubmodules)
return submodules
else:
if hasattr(model, first_atom):
submodule = getattr(model, first_atom)
return _find_submodule_by_name(submodule, remaining_name)
else:
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -25,6 +25,7 @@
_import_structure = {}

if is_torch_available():
_import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
@@ -119,6 +120,7 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from ._modeling_parallel import ContextParallelConfig, ParallelConfig
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
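A quick check, not part of the diff, of the import paths these re-exports enable, together with a plan in the same shape as the Wan `_cp_plan` from the review thread; calling `apply_context_parallel` still requires an initialized process group, and `ring_degree` remains an assumed kwarg.

from diffusers import ContextParallelConfig, ParallelConfig   # new top-level exports
from diffusers.hooks import apply_context_parallel            # new hooks export
from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput

# Same shape as the Wan plan discussed above: shard rope outputs and block inputs
# along the sequence dimension, gather the final projection output.
plan = {
    "rope": {
        0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
        1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
    },
    "blocks.0": {"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False)},
    "blocks.*": {"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False)},
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}

# With torch.distributed initialized (e.g. under torchrun), one would then do:
#   config = ContextParallelConfig(ring_degree=2)  # assumed kwarg
#   apply_context_parallel(transformer, config, plan)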