Context Parallel w/ Ring & Ulysses & Unified Attention #11941
@@ -0,0 +1,24 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Parallelism

Parallelism strategies distribute computation across multiple devices, speeding up inference and training for diffusion transformers.

## ParallelConfig

[[autodoc]] ParallelConfig

## ContextParallelConfig

[[autodoc]] ContextParallelConfig

[[autodoc]] hooks.apply_context_parallel
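To make the idea concrete before the implementation below, here is a minimal, framework-agnostic sketch of the core context-parallel pattern (plain PyTorch, not the diffusers API added in this PR): each rank keeps only its slice of the sequence dimension and all-gathers when a full-sequence tensor is needed. It assumes torch.distributed has already been initialized, e.g. via torchrun.

```python
import torch
import torch.distributed as dist


def shard_sequence(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Each rank keeps an equal slice of the sequence dimension.
    return x.chunk(dist.get_world_size(), dim=dim)[dist.get_rank()]


def gather_sequence(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Reassemble the full sequence from every rank's slice.
    chunks = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(chunks, x.contiguous())
    return torch.cat(chunks, dim=dim)
```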
@@ -0,0 +1,297 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..models._modeling_parallel import (
    ContextParallelConfig,
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
)
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Dict[str, int] = None
    _cls: Type = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}

        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is not None:
            index = self.cached_parameter_indices.get(identifier, None)
            if index is None:
                raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
            return args[index], False, index

        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")

        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        if identifier not in self.cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

        index = self.cached_parameter_indices[identifier]

        if index >= len(args):
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")

        return args[index], False, index
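For illustration, the parameter lookup above can be exercised directly with the dataclass defined here; the toy module below is a hypothetical stand-in (not part of this PR) and only shows how a named `forward` argument is resolved from positional args via signature inspection.

```python
import torch

class ToyBlock(torch.nn.Module):  # hypothetical module, used only for this sketch
    def forward(self, hidden_states, encoder_hidden_states=None):
        return hidden_states

meta = ModuleForwardMetadata(_cls=ToyBlock)
args = (torch.randn(1, 8, 16), torch.randn(1, 4, 16))
value, is_kwarg, index = meta._get_parameter_from_args_kwargs("encoder_hidden_states", args, {})
# value is args[1], is_kwarg is False, index is 1; the signature indices are now cached,
# so subsequent lookups skip the inspect.signature call.
```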
def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ContextParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(
                        f"Expected all elements of cp_model_plan to be ContextParallelOutput, but got {cp_model_plan}"
                    )
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)


def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        for m in submodule:
            registry = HookRegistry.check_if_exists_or_initialize(m)
            if isinstance(cp_model_plan, dict):
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry.remove_hook(hook_name)
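A minimal usage sketch of the two functions above, assuming a `transformer` module and a `parallel_config` (a `ContextParallelConfig`) have already been constructed elsewhere; neither construction is shown in this diff. The plan structure mirrors the Wan plan quoted in the review discussion further below, and the import path is inferred from the relative import at the top of this file.

```python
from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput

# Hypothetical plan: shard the sequence dimension of the block inputs on the way in,
# gather the final projection output on the way out.
plan = {
    "blocks.0": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    },
    "blocks.*": {
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    },
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}

apply_context_parallel(transformer, parallel_config, plan)   # register split/gather hooks
# ... run context-parallel forward passes under torch.distributed ...
remove_context_parallel(transformer, plan)                   # detach the hooks again
```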
class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        cls = unwrap_module(module).__class__
        self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
        return module

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for name, cpm in self.metadata.items():
            if isinstance(cpm, ContextParallelInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, args_list, kwargs
            )

            if input_val is None:
                continue

            # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
            # the output instead of input for a particular layer by setting split_output=True
            if isinstance(input_val, torch.Tensor):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val
            else:
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)
        is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = [output] if is_tensor else list(output)
        for index, cpm in self.metadata.items():
            if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
                continue
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output

        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
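A note on the `split_output` path above: a split hook's metadata can be keyed either by `forward()` parameter name (sharded in `pre_forward`) or by output index with `split_output=True` (sharded in `post_forward`). The snippet below is illustrative only, with the import path inferred from the relative import at the top of this file.

```python
from diffusers.models._modeling_parallel import ContextParallelInput

# Keyed by parameter name: the named input is sharded before the module runs.
input_keyed = {
    "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
}

# Keyed by output index with split_output=True: the corresponding outputs are sharded
# after the module runs (used for modules like RoPE whose outputs carry the sequence dim).
output_keyed = {
    0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
    1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
}
```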
class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            output = [output]
        elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = list(output)

        if len(output) != len(self.metadata):
            raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

        return output[0] if is_tensor else tuple(output)
class AllGatherFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, dim, group):
        ctx.dim = dim
        ctx.group = group
        ctx.world_size = torch.distributed.get_world_size(group)
        ctx.rank = torch.distributed.get_rank(group)
        return funcol.all_gather_tensor(tensor, dim, group=group)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of an all-gather is the slice of the incoming gradient that
        # corresponds to this rank's original shard along the gathered dimension.
        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
        return grad_chunks[ctx.rank], None, None
class EquipartitionSharder:
    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
        # because the alternate case has not yet been tested/required for any model.
        assert tensor.size()[dim] % mesh.size() == 0, (
            "Tensor size along dimension to be sharded must be divisible by mesh size"
        )

        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]

        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
        return tensor
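A hedged round-trip sketch of the sharder above. It assumes the script is launched with torchrun on CUDA devices and that `torch.distributed.device_mesh.init_device_mesh` is available; the mesh construction and tensor shapes are assumptions, only `EquipartitionSharder` comes from this PR.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
world_size = dist.get_world_size()
mesh = init_device_mesh("cuda", (world_size,))

# The sharded dimension (dim=1) must be divisible by the mesh size.
x = torch.randn(2, 16 * world_size, 64, device="cuda")
local = EquipartitionSharder.shard(x, 1, mesh)       # each rank keeps a (2, 16, 64) slice
full = EquipartitionSharder.unshard(local, 1, mesh)  # all-gathered back to (2, 16 * world_size, 64)
assert full.shape == x.shape
```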
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name.count("*") > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)

Review comment: Could we leverage […]? (no worries if not, just wanted to find ways to reduce LoC)

Reply: I'm also supporting wildcards here, for example the Wan CP plan:

    _cp_plan = {
        "rope": {
            0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
            1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
        },
        "blocks.0": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.*": {
            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

I think much cleaner to use a custom implementation specific to CP to allow for such cases.
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name == "":
        return model
    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
    if first_atom == "*":
        if not isinstance(model, torch.nn.ModuleList):
            raise ValueError("Wildcard '*' can only be used with ModuleList")
        submodules = []
        for submodule in model:
            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
            if not isinstance(subsubmodules, list):
                subsubmodules = [subsubmodules]
            submodules.extend(subsubmodules)
        return submodules
    else:
        if hasattr(model, first_atom):
            submodule = getattr(model, first_atom)
            return _find_submodule_by_name(submodule, remaining_name)
        else:
            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
Review comment (nit): No need for this PR, but might make sense to introduce help for args like this.
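To illustrate the wildcard behavior discussed above, here is a minimal sketch; the `TinyModel` is a hypothetical stand-in for a diffusion transformer, and only `_get_submodule_by_name` comes from this PR.

```python
import torch

class TinyModel(torch.nn.Module):  # hypothetical model, used only for this sketch
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(3)])
        self.proj_out = torch.nn.Linear(8, 8)

model = TinyModel()
assert _get_submodule_by_name(model, "proj_out") is model.proj_out   # plain dotted lookup
assert len(_get_submodule_by_name(model, "blocks.*")) == 3           # wildcard expands the ModuleList
```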