KEMBAR78
Context Parallel w/ Ring & Ulysses & Unified Attention by a-r-r-o-w · Pull Request #11941 · huggingface/diffusers · GitHub
Skip to content

Conversation

@a-r-r-o-w
Copy link
Contributor

@a-r-r-o-w a-r-r-o-w commented Jul 16, 2025

Adds support for ring, ulysses and unified attention natively. For a minimal PoC, I've limited changes to Flux.

Supported attention backends with CP: cuDNN, FA2, Sage.

Requires #11916 to be merged first.

Minimal example

Note: the examples here are not up-to-date! Please refer to the official examples once the docs are updated.

import torch
from diffusers import FluxPipeline
from diffusers import ParallelConfig, enable_parallelism

# Minimal Flux example: every rank runs this same script (one process per GPU).
try:
    # Join the NCCL process group and bind this process to one local GPU
    # (rank modulo the number of visible devices).
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(device)
    # Only the transformer is parallelized; the attention backend is selected afterwards.
    pipe.transformer.parallelize(config=ParallelConfig(ulysses_degree=2))
    # pipe.transformer.set_attention_backend("_native_cudnn")
    pipe.transformer.set_attention_backend("flash")
    # pipe.transformer.compile(fullgraph=True)
    # pipe.transformer.compile(fullgraph=True, mode="max-autotune")

    prompt = "A cat holding a sign that says 'hello world'"
    
    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    with enable_parallelism(pipe):
        # The 2-step result is overwritten by the 28-step call below
        # (presumably a warm-up run -- confirm with the author).
        image = pipe(prompt, num_inference_steps=2, guidance_scale=4.0, generator=generator).images[0]
        image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0, generator=generator).images[0]
    
    # Only rank 0 writes the result to disk.
    if rank == 0:
        image.save("output.png")

except Exception as e:
    # Report the failure, drop into the distributed debugger, then re-raise.
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise

finally:
    # Tear the process group down whether or not the run succeeded.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
Wan
import torch
from diffusers import AutoencoderKLWan, WanPipeline, ParallelConfig, enable_parallelism
from diffusers.utils import export_to_video
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

# Minimal Wan text-to-video example: every rank runs this same script (one process per GPU).
try:
    # Join the NCCL process group and bind this process to one local GPU
    # (rank modulo the number of visible devices).
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)
    # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    # The VAE is loaded in float32 while the rest of the pipeline is loaded in bfloat16.
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
    # "cuda" resolves to the device selected by torch.cuda.set_device above.
    pipe.to("cuda")

    # Only the transformer is parallelized; the attention backend is selected afterwards.
    pipe.transformer.parallelize(config=ParallelConfig(ulysses_degree=2))
    pipe.transformer.set_attention_backend("flash")

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    with enable_parallelism(pipe):
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=480,
            width=832,
            num_frames=81,
            guidance_scale=5.0,
            generator=generator,
        ).frames[0]
    
    # Only rank 0 writes the result to disk.
    if rank == 0:
        export_to_video(output, "output.mp4", fps=16)

except Exception as e:
    # Report the failure, drop into the distributed debugger, then re-raise.
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise

finally:
    # Tear the process group down whether or not the run succeeded.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
Qwen
import torch
from diffusers import QwenImagePipeline, ParallelConfig, enable_parallelism


# Minimal Qwen-Image example: every rank runs this same script (one process per GPU).
try:
    # Join the NCCL process group and bind this process to one local GPU
    # (rank modulo the number of visible devices).
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)
    
    pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
    # "cuda" resolves to the device selected by torch.cuda.set_device above.
    pipe.to("cuda")

    # Only the transformer is parallelized; the attention backend is selected afterwards.
    pipe.transformer.parallelize(config=ParallelConfig(ulysses_degree=2))
    pipe.transformer.set_attention_backend("flash")
    
    prompt = "A cat holding a sign that says 'hello world'"
    
    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    with enable_parallelism(pipe):
        image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
    
    # Only rank 0 writes the result to disk.
    if rank == 0:
        image.save("output.png")

except Exception as e:
    # Report the failure, drop into the distributed debugger, then re-raise.
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise

finally:
    # Tear the process group down whether or not the run succeeded.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
LTXVideo
import torch
from diffusers import LTXPipeline, ParallelConfig, enable_parallelism
from diffusers.utils import export_to_video

# Minimal LTX-Video example: every rank runs this same script (one process per GPU).
try:
    # Join the NCCL process group and bind this process to one local GPU
    # (rank modulo the number of visible devices).
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
    # "cuda" resolves to the device selected by torch.cuda.set_device above.
    pipe.to("cuda")

    # Only the transformer is parallelized (here with ulysses_degree=4);
    # the attention backend is selected afterwards.
    pipe.transformer.parallelize(config=ParallelConfig(ulysses_degree=4))
    pipe.transformer.set_attention_backend("flash")

    prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    with enable_parallelism(pipe):
        video = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=704,
            height=512,
            num_frames=161,
            num_inference_steps=50,
            generator=generator,
        ).frames[0]

    # Only rank 0 writes the result to disk.
    if rank == 0:
        export_to_video(video, "output.mp4", fps=24)

except Exception as e:
    # Report the failure, drop into the distributed debugger, then re-raise.
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise

finally:
    # Tear the process group down whether or not the run succeeded.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

Benchmarks

Flux code
import argparse
import contextlib
import functools
import pathlib
import math
from dataclasses import dataclass
from typing import Callable, List, Literal, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.profiler._utils
import torch._dynamo.config
import torch._inductor.config
import torch._higher_order_ops.auto_functionalize as af
from torch.profiler import profile, record_function, ProfilerActivity

from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
)
from diffusers.models.modeling_utils import ModelMixin
from kernels import get_kernel
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

# Optional attention kernels. A missing package is reported but is not fatal;
# the corresponding name simply stays undefined in that case.
# Only ImportError is caught so that Ctrl-C / SystemExit raised while the
# extension loads are not silently swallowed (a bare ``except:`` would hide them).
try:
    from flash_attn import flash_attn_func
except ImportError:
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    print("Flash Attention 3 not found.")


def apply_flags():
    torch._dynamo.config.inline_inbuilt_nn_modules = False
    torch._dynamo.config.cache_size_limit = 128
    torch._dynamo.config.error_on_recompile = True

    torch._inductor.config.conv_1x1_as_mm = True
    torch._inductor.config.coordinate_descent_check_all_directions = True
    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.disable_progress = False
    torch._inductor.config.fx_graph_cache = True
    torch._inductor.config.epilogue_fusion = False
    torch._inductor.config.aggressive_fusion = True
    torch._inductor.config.shape_padding = True
    torch._inductor.config.triton.enable_persistent_tma_matmul = True

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
    torch.backends.cudnn.allow_tf32 = True

    af.auto_functionalized_v2._cacheable = True
    af.auto_functionalized._cacheable = True


# dtype used for the rotary-embedding tables (FluxPosEmbed) and for the rotation
# arithmetic inside Attention.forward.
ROPE_PRECISION = torch.bfloat16
SPATIAL_COMPRESSION_RATIO = 8
PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR = 2
T5_SEQUENCE_LENGTH = 512
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
PATCH_SIZE = 1
# Latent grid size for the default resolution. The final division previously used a
# literal 2; it is now spelled with the named constant of the same value
# (NOTE(review): assumes that literal was the pixel-unshuffling factor -- confirm).
LATENT_HEIGHT = DEFAULT_HEIGHT // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR
LATENT_WIDTH = DEFAULT_WIDTH // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR
# Prompt lengths that prepare_t5_embeddings may truncate to: 128, 192, ..., 512.
SUPPORTED_BUCKET_LENGTHS = list(range(128, 512 + 1, 64))
SUPPORTED_GUIDANCE_SCALES = [i / 2 for i in range(41)]  # 0, 0.5, 1.0, ..., 20.0
MIN_INFERENCE_STEPS = 2
MAX_INFERENCE_STEPS = 50
# Attention kernel called by Attention.forward as ATTENTION_OP(query, key, value).
# It is None here, so it must be assigned before any Attention module is run.
ATTENTION_OP = None

BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15
# Slope and intercept of the straight line through
# (BASE_IMAGE_SEQ_LEN, BASE_SHIFT) and (MAX_IMAGE_SEQ_LEN, MAX_SHIFT).
M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN


@dataclass
class ContextParallelOptions:
    ring_degree: int = None
    ulysses_degree: int = None
    mode: Literal["ring", "ulysses", "unified"] = "ring"
    mesh: dist.DeviceMesh | None = None
    convert_to_fp32: bool = True
    attention_op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]] | None = None

    _flattened_mesh: dist.DeviceMesh = None
    _ring_mesh: dist.DeviceMesh = None
    _ulysses_mesh: dist.DeviceMesh = None
    _ring_local_rank: int = None
    _ulysses_local_rank: int = None


cp_options = ContextParallelOptions()


class AdaLayerNormContinuous(torch.nn.Module):
    def __init__(
        self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True
    ):
        super().__init__()

        self.linear = torch.nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        emb = self.linear(emb)
        scale, shift = emb.unsqueeze(1).chunk(2, dim=-1)
        x = self.norm(x)
        x = torch.addcmul(shift, x, 1 + scale)
        return x


class AdaLayerNormZeroSingle(torch.nn.Module):
    """Non-affine LayerNorm modulated by a shift/scale pair, also returning a gate.

    ``linear`` (``embedding_dim -> 3 * embedding_dim``) is created here but is not
    called by ``forward``: ``emb`` must already hold the three modulation chunks.
    """

    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()

        self.linear = torch.nn.Linear(embedding_dim, embedding_dim * 3, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return ``(shift + norm(x) * (1 + scale), gate)``.

        ``emb`` is split along its last dim into shift, scale and gate, in that order.
        """
        shift, scale, gate = torch.chunk(emb, 3, dim=-1)
        modulated = torch.addcmul(shift, self.norm(x), 1 + scale)
        return modulated, gate


class AdaLayerNormZero(torch.nn.Module):
    """Non-affine LayerNorm modulated by a shift/scale pair, returning four extra chunks.

    ``linear`` (``embedding_dim -> 6 * embedding_dim``) is created here but is not
    called by ``forward``: ``emb`` must already hold the six modulation chunks.
    """

    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()

        self.linear = torch.nn.Linear(embedding_dim, embedding_dim * 6, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return ``(modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)``.

        ``emb`` is split along its last dim into six chunks ordered
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp; only the
        first two are applied here, the rest are passed through to the caller.
        """
        chunks = torch.chunk(emb, 6, dim=-1)
        shift_msa, scale_msa = chunks[0], chunks[1]
        gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks[2:]
        modulated = torch.addcmul(shift_msa, self.norm(x), 1 + scale_msa)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp


class Attention(torch.nn.Module):
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
    ):
        super().__init__()

        assert qk_norm == "rms_norm", "Flux uses RMSNorm"

        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_proj_bias = added_proj_bias

        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None)

        if self.fused_projections:
            query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
        else:
            query = self.to_q(hidden_states)
            key = self.to_k(hidden_states)
            value = self.to_v(hidden_states)

        query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value))
        query = self.norm_q(query)
        key = self.norm_k(key)

        if encoder_hidden_states is not None:
            if self.fused_projections:
                query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
            else:
                query_c = self.add_q_proj(encoder_hidden_states)
                key_c = self.add_k_proj(encoder_hidden_states)
                value_c = self.add_v_proj(encoder_hidden_states)

            query_c, key_c, value_c = (
                x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c)
            )
            query_c = self.norm_added_q(query_c)
            key_c = self.norm_added_k(key_c)

            query = torch.cat([query_c, query], dim=1)
            key = torch.cat([key_c, key], dim=1)
            value = torch.cat([value_c, value], dim=1)

        if image_rotary_emb is not None:
            x_real, x_imag = query.unflatten(-1, (-1, 2)).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
            query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query)

            x_real, x_imag = key.unflatten(-1, (-1, 2)).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
            key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key)

        hidden_states, lse = ATTENTION_OP(query, key, value)
        hidden_states = hidden_states.flatten(2)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = torch.split_with_sizes(
                hidden_states,
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]],
                dim=1,
            )
            hidden_states = self.to_out[0](hidden_states)
            encoder_hidden_states = self.to_add_out(encoder_hidden_states)
            return hidden_states, encoder_hidden_states

        return hidden_states

    @torch.no_grad()
    def fuse_projections(self):
        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
        in_features = concatenated_weights.shape[1]
        out_features = concatenated_weights.shape[0]
        self.to_qkv = torch.nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
        self.to_qkv.weight.copy_(concatenated_weights)
        if self.use_bias:
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            self.to_qkv.bias.copy_(concatenated_bias)

        if (
            getattr(self, "add_q_proj", None) is not None
            and getattr(self, "add_k_proj", None) is not None
            and getattr(self, "add_v_proj", None) is not None
        ):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]
            self.to_added_qkv = torch.nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            self.to_added_qkv.weight.copy_(concatenated_weights)
            if self.added_proj_bias:
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)

        for layer in ("to_q", "to_k", "to_v", "to_added_q", "to_added_k", "to_added_v"):
            if hasattr(self, layer):
                module = getattr(self, layer)
                module.to("meta")
                delattr(self, layer)

        self.fused_projections = True


class FluxPosEmbed(torch.nn.Module):
    """Builds rotary-embedding cos/sin tables from per-axis position ids."""

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(freqs_cos, freqs_sin)`` for the given position ids.

        Args:
            ids: rank-2 tensor; column ``i`` holds the positions along axis ``i`` and
                is embedded with ``self.axes_dim[i]`` features.

        Returns:
            A pair of tensors in ``ROPE_PRECISION`` on ``ids.device``. The per-axis
            tables are concatenated along the last dim and then given a new leading
            and a new second-to-last singleton axis.
        """
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                ids[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=ROPE_PRECISION,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        # [None, :, None] inserts the singleton axes used for broadcasting in Attention.forward.
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
        return freqs_cos, freqs_sin


class FluxSingleTransformerBlock(torch.nn.Module):
    """Single-stream block: attention and an MLP run on the same normalised input.

    Both branch outputs are concatenated on the feature axis, projected back to
    ``dim`` and added to the residual, scaled by the gate from the norm layer.
    """

    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = torch.nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = torch.nn.GELU(approximate="tanh")
        # pre_only=True: the attention module has no output projection of its own;
        # proj_out below handles the attention and MLP features together.
        attention_kwargs = {
            "query_dim": dim,
            "dim_head": attention_head_dim,
            "heads": num_attention_heads,
            "out_dim": dim,
            "bias": True,
            "qk_norm": "rms_norm",
            "eps": 1e-6,
            "pre_only": True,
        }
        self.attn = Attention(**attention_kwargs)
        self.proj_out = torch.nn.Linear(dim + self.mlp_hidden_dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Return ``hidden_states + gate * proj_out(cat(attn, mlp))``.

        ``temb`` is the modulation tensor consumed by ``self.norm``;
        ``image_rotary_emb`` is forwarded unchanged to the attention module.
        """
        residual = hidden_states
        modulated, gate = self.norm(residual, emb=temb)

        mlp_branch = self.proj_mlp(modulated)
        mlp_branch = self.act_mlp(mlp_branch)
        attn_branch = self.attn(hidden_states=modulated, image_rotary_emb=image_rotary_emb)

        joined = torch.cat([attn_branch, mlp_branch], dim=2)
        update = self.proj_out(joined)
        return torch.addcmul(residual, gate, update)


class FluxTransformerBlock(torch.nn.Module):
    """Dual-stream block: main and context streams share one joint attention call.

    Each stream has its own modulated norm, gate and feed-forward network.
    """

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        attention_kwargs = {
            "query_dim": dim,
            "added_kv_proj_dim": dim,
            "dim_head": attention_head_dim,
            "heads": num_attention_heads,
            "out_dim": dim,
            "context_pre_only": False,
            "bias": True,
            "qk_norm": qk_norm,
            "eps": eps,
        }
        self.attn = Attention(**attention_kwargs)

        self.norm2 = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply joint attention and per-stream feed-forward updates.

        ``temb`` is split in half along its last dim: the first half modulates the
        main stream, the second half the context stream.

        Returns:
            ``(encoder_hidden_states, hidden_states)`` -- note the context stream comes first.
        """
        main_temb, context_temb = temb.chunk(2, dim=-1)
        main_normed, gate_attn, shift_ff, scale_ff, gate_ff = self.norm1(hidden_states, emb=main_temb)
        ctx_normed, ctx_gate_attn, ctx_shift_ff, ctx_scale_ff, ctx_gate_ff = self.norm1_context(
            encoder_hidden_states, emb=context_temb
        )

        # Joint attention over both streams, followed by the gated residual update of each.
        main_attn, ctx_attn = self.attn(
            hidden_states=main_normed,
            encoder_hidden_states=ctx_normed,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = torch.addcmul(hidden_states, gate_attn, main_attn)
        encoder_hidden_states = torch.addcmul(encoder_hidden_states, ctx_gate_attn, ctx_attn)

        # Second norm of each stream, modulated with that stream's MLP shift/scale.
        main_normed = self.norm2(hidden_states)
        ctx_normed = self.norm2_context(encoder_hidden_states)
        main_normed = torch.addcmul(shift_ff, main_normed, 1 + scale_ff)
        ctx_normed = torch.addcmul(ctx_shift_ff, ctx_normed, 1 + ctx_scale_ff)

        # Feed-forward networks and their gated residual updates.
        main_ff = self.ff(main_normed)
        ctx_ff = self.ff_context(ctx_normed)
        hidden_states = torch.addcmul(hidden_states, gate_ff, main_ff)
        encoder_hidden_states = torch.addcmul(encoder_hidden_states, ctx_gate_ff, ctx_ff)

        return encoder_hidden_states, hidden_states


class FluxTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
    """Flux transformer whose ``forward`` performs one update ``x_t + dt * velocity``.

    ``forward`` reads ``self.adaln_linear`` and ``self.adaln_linear_single``, which
    ``__init__`` does not create; in this file they are attached by
    ``fuse_adaln_linear_``. ``forward`` also does not call ``pos_embed``,
    ``time_text_embed`` or ``context_embedder`` -- presumably the caller applies
    those and passes in the results (confirm against the calling code).
    """

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        # The conditioning embedder depends on whether a guidance embedding is used.
        if guidance_embeds:
            conditioning_embed_cls = CombinedTimestepGuidanceTextProjEmbeddings
        else:
            conditioning_embed_cls = CombinedTimestepTextProjEmbeddings
        self.time_text_embed = conditioning_embed_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = torch.nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)

        # Dual-stream and single-stream blocks take the same constructor arguments.
        block_kwargs = {
            "dim": self.inner_dim,
            "num_attention_heads": num_attention_heads,
            "attention_head_dim": attention_head_dim,
        }
        dual_blocks = [FluxTransformerBlock(**block_kwargs) for _ in range(num_layers)]
        self.transformer_blocks = torch.nn.ModuleList(dual_blocks)

        single_blocks = [FluxSingleTransformerBlock(**block_kwargs) for _ in range(num_single_layers)]
        self.single_transformer_blocks = torch.nn.ModuleList(single_blocks)

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = torch.nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        dt: torch.Tensor,
    ) -> torch.Tensor:
        """Predict a velocity for ``hidden_states`` and return ``hidden_states + dt * velocity``.

        Args:
            hidden_states: input that is embedded by ``x_embedder``; also the base of the update.
            encoder_hidden_states: context stream, passed to the blocks as given.
            conditioning: input of the fused AdaLN linears and of ``norm_out``.
            image_rotary_emb: ``(cos, sin)`` pair forwarded to every block.
            dt: factor multiplied with the predicted velocity.
        """
        latents_in = hidden_states
        hidden_states = self.x_embedder(hidden_states)

        # One fused projection per block family, split into one modulation chunk per block.
        dual_modulation = self.adaln_linear(conditioning).unsqueeze(1)
        dual_modulation = dual_modulation.chunk(self.config.num_layers, dim=-1)
        single_modulation = self.adaln_linear_single(conditioning).unsqueeze(1)
        single_modulation = single_modulation.chunk(self.config.num_single_layers, dim=-1)

        for idx, dual_block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = dual_block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=dual_modulation[idx],
                image_rotary_emb=image_rotary_emb,
            )

        # Single-stream blocks see the context tokens followed by the main tokens.
        num_context_tokens = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for idx, single_block in enumerate(self.single_transformer_blocks):
            hidden_states = single_block(
                hidden_states=hidden_states,
                temb=single_modulation[idx],
                image_rotary_emb=image_rotary_emb,
            )

        # Drop the context tokens again before the output head.
        hidden_states = hidden_states[:, num_context_tokens:, ...]
        hidden_states = self.norm_out(hidden_states, conditioning)
        velocity = self.proj_out(hidden_states)
        return latents_in + dt * velocity


@torch.no_grad()
def fuse_qkv_(model: FluxTransformer2DModel) -> FluxTransformer2DModel:
    """Fuse the q/k/v projections of every ``Attention`` submodule of ``model`` in place.

    Returns:
        The same ``model`` object, as the annotation promises (the function used
        to return ``None``), so calls can be chained.
    """
    for submodule in model.modules():
        if isinstance(submodule, Attention):
            submodule.fuse_projections()
    return model


@torch.no_grad()
def _fuse_linear_layers(layers: List[torch.nn.Linear]) -> torch.nn.Linear:
    """Stack ``layers`` along the output dimension into one new Linear.

    The source layers are moved to the meta device afterwards; ``torch.cat`` has
    already copied their weights and biases by then.
    """
    weight = torch.cat([layer.weight.data for layer in layers], dim=0)
    bias = torch.cat([layer.bias.data for layer in layers], dim=0)
    for layer in layers:
        layer.to("meta")
    out_features, in_features = weight.shape
    fused = torch.nn.Linear(in_features, out_features, bias=True, device=weight.device, dtype=weight.dtype)
    fused.weight.copy_(weight)
    fused.bias.copy_(bias)
    return fused


@torch.no_grad()
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> FluxTransformer2DModel:
    """Fuse all per-block AdaLN linears into two model-level linears, in place.

    * ``model.adaln_linear`` stacks, for each dual-stream block in order,
      ``norm1.linear`` followed by ``norm1_context.linear``.
    * ``model.adaln_linear_single`` stacks ``norm.linear`` of each single-stream block.

    The per-block linears are moved to the meta device and deleted. The stacking
    order matches how ``FluxTransformer2DModel.forward`` chunks the fused outputs.

    Returns:
        The same ``model`` object, as the annotation promises (the function used
        to return ``None``).
    """
    dual_linears = []
    for block in model.transformer_blocks:
        # Main-stream modulation first, then context: FluxTransformerBlock.forward
        # splits its ``temb`` in exactly this order.
        dual_linears.append(block.norm1.linear)
        dual_linears.append(block.norm1_context.linear)
    model.adaln_linear = _fuse_linear_layers(dual_linears)
    for block in model.transformer_blocks:
        del block.norm1.linear, block.norm1_context.linear

    single_linears = [block.norm.linear for block in model.single_transformer_blocks]
    model.adaln_linear_single = _fuse_linear_layers(single_linears)
    for block in model.single_transformer_blocks:
        del block.norm.linear

    return model


def prepare_clip_embeddings(
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    max_length: int = 77,
) -> torch.Tensor:
    """Encode a single prompt with CLIP and return its pooled embedding cast to ``dtype``.

    The prompt is tokenized as a batch of one, padded/truncated to ``max_length``,
    and the token ids are moved to ``device`` before the encoder is called.
    """
    tokenized = tokenizer(
        [prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids.to(device)
    encoder_output = text_encoder(input_ids, output_hidden_states=False)
    return encoder_output.pooler_output.to(dtype)


def prepare_t5_embeddings(
    text_encoder: T5EncoderModel,
    tokenizer: T5TokenizerFast,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    max_length: int = 512,
    enable_prompt_length_bucketing: bool = False,
) -> Tuple[torch.Tensor, int]:
    """Encode a single prompt with the T5 encoder, optionally trimming to a length bucket.

    Args:
        text_encoder: T5 encoder used to produce the embeddings.
        tokenizer: Tokenizer matching ``text_encoder``.
        prompt: The prompt string; it is wrapped into a batch of one.
        device: Device the token ids are moved to before encoding.
        dtype: Dtype the embeddings are cast to.
        max_length: Length the prompt is padded/truncated to by the tokenizer.
        enable_prompt_length_bucketing: When True, the padded ids are cut down to the
            smallest entry of ``SUPPORTED_BUCKET_LENGTHS`` that still holds every
            non-padding token.

    Returns:
        A ``(prompt_embeds, sequence_length)`` tuple, where ``sequence_length`` is the
        number of token positions that were fed to the encoder.
    """
    text_inputs = tokenizer(
        [prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    if enable_prompt_length_bucketing:
        num_text_tokens = text_inputs.attention_mask.sum(dim=1).max().item()
        fitting_buckets = [bucket for bucket in SUPPORTED_BUCKET_LENGTHS if bucket >= num_text_tokens]
        # If no bucket is large enough, keep the full padded sequence instead of
        # silently truncating real tokens to an arbitrary bucket.
        if fitting_buckets:
            # Cap at the padded width so the reported length always matches the
            # number of positions actually passed to the encoder.
            max_length = min(min(fitting_buckets), max_length)
            text_input_ids = text_input_ids[:, :max_length]
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype)
    return prompt_embeds, max_length


@functools.lru_cache(maxsize=8)
def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype):
    """Build the (height * width, 3) position-id table for the latent grid.

    Channel 0 is left at zero, channel 1 holds the row index and channel 2 holds
    the column index of each latent position. Results are memoized per
    ``(height, width, device, dtype)``, so callers share the returned tensor.
    """
    grid_ids = torch.zeros(height, width, 3)
    # Broadcast the row indices down channel 1 and the column indices across channel 2.
    grid_ids[..., 1] = torch.arange(height).unsqueeze(1)
    grid_ids[..., 2] = torch.arange(width).unsqueeze(0)
    flat_ids = grid_ids.reshape(height * width, 3).contiguous()
    return flat_ids.to(device=device, dtype=dtype)


def precompute_guidance_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype):
    """Compute the guidance embedding for every scale in ``SUPPORTED_GUIDANCE_SCALES``.

    Returns a dict keyed by the scale formatted with one decimal place (e.g. ``"4.0"``).
    """
    time_text_embed = transformer.time_text_embed
    embeds = {}
    for scale in SUPPORTED_GUIDANCE_SCALES:
        scale_tensor = torch.full([1], scale, device=device, dtype=torch.float32)
        # The scale is multiplied by 1000 before the sinusoidal projection.
        projected = time_text_embed.time_proj(scale_tensor * 1000.0).to(dtype)
        embeds[f"{scale:.1f}"] = time_text_embed.guidance_embedder(projected)
    return embeds


def precompute_timestep_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype):
    """Precompute the sigma schedule and timestep embeddings for every supported step count.

    Returns a dict mapping ``num_inference_steps`` to ``(sigmas, temb)``, where
    ``sigmas`` has ``num_inference_steps + 1`` entries (a trailing zero is appended)
    and ``temb`` is the timestep embedding of ``sigmas * 1000``.
    """
    time_text_embed = transformer.time_text_embed
    # The shift only depends on the module-level latent size, so compute it once.
    mu = B + (LATENT_HEIGHT * LATENT_WIDTH) * M
    exp_mu = math.exp(mu)

    schedule = {}
    for step_count in range(MIN_INFERENCE_STEPS, MAX_INFERENCE_STEPS + 1):
        linear_sigmas = np.linspace(1.0, 1 / step_count, step_count)
        shifted_sigmas = exp_mu / (exp_mu + (1 / linear_sigmas - 1) ** 1.0)
        sigmas = torch.from_numpy(shifted_sigmas).to(device, dtype=torch.float32)
        sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
        projected = time_text_embed.time_proj((sigmas * 1000.0).to(dtype))
        schedule[step_count] = (sigmas, time_text_embed.timestep_embedder(projected.to(dtype)))
    return schedule


def precompute_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype, save_dir: str):
    """Load the guidance/timestep embedding caches from ``save_dir``, computing them if absent.

    Each cache is looked up independently. A missing cache is recomputed on every
    process, but only written when there is no mesh or the flattened-mesh local
    rank is 0.

    Returns:
        ``(guidance_embeds, timestep_embeds)``.
    """

    def _should_write() -> bool:
        # Evaluated lazily so the mesh is only queried when something must be saved.
        return cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0

    save_dir = pathlib.Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    guidance_path = save_dir / "guidance_embeds.pt"
    timestep_path = save_dir / "timestep_embeds.pt"

    if not guidance_path.exists():
        guidance_embeds = precompute_guidance_embeds(transformer, device, dtype)
        if _should_write():
            torch.save(guidance_embeds, guidance_path.as_posix())
        print(f'Precomputed guidance embeddings saved to "{save_dir.as_posix()}"')
    else:
        guidance_embeds = torch.load(guidance_path, map_location=device, weights_only=True)
        print(f'Loaded precomputed guidance embeddings from "{guidance_path.as_posix()}"')

    if not timestep_path.exists():
        timestep_embeds = precompute_timestep_embeds(transformer, device, dtype)
        if _should_write():
            torch.save(timestep_embeds, timestep_path.as_posix())
        print(f'Precomputed timestep embeddings saved to "{save_dir.as_posix()}"')
    else:
        timestep_embeds = torch.load(timestep_path, map_location=device, weights_only=True)
        print(f'Loaded precomputed timestep embeddings from "{timestep_path.as_posix()}"')

    return guidance_embeds, timestep_embeds


@torch.compile
def pointwise_add3_silu(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Return ``silu(x + y + z)``; compiled so the adds and activation can be fused."""
    summed = x + y + z
    return torch.nn.functional.silu(summed)


def capture_cudagraph(
    model: FluxTransformer2DModel,
    latents: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    conditioning: torch.Tensor,
    image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
    dt: torch.Tensor,
):
    """Warm up ``model`` on a side stream, then capture one forward pass in a CUDA graph.

    Args:
        model: The transformer whose forward pass is captured.
        latents: Example latents; cloned into a static buffer for the capture.
        encoder_hidden_states: Text conditioning; passed to the captured call as-is
            (not cloned).
        conditioning: Example conditioning vector; cloned into a static buffer.
        image_rotary_emb: Rotary embeddings; passed to the captured call as-is
            (not cloned).
        dt: Example step size; cloned into a static buffer.

    Returns:
        ``(graph, static_latents, static_conditioning, static_dt, static_x)``: the
        captured graph, the three static input buffers used during capture, and the
        output produced by the captured forward pass.
    """
    print("Warming up CUDAGraph capture")
    # Run the warmup iterations on a separate stream, ordered after the work
    # already queued on the current stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            _ = model(
                hidden_states=latents,
                encoder_hidden_states=encoder_hidden_states,
                conditioning=conditioning,
                image_rotary_emb=image_rotary_emb,
                dt=dt,
            )
    # Make the current stream wait for the warmup to finish before capturing.
    torch.cuda.current_stream().wait_stream(s)

    print("Capturing CUDAGraph")
    # Only the inputs that change between denoising steps get dedicated static
    # buffers; the caller overwrites them before each replay.
    static_latents = latents.clone()
    static_conditioning = conditioning.clone()
    static_dt = dt.clone()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_x = model(
            hidden_states=static_latents,
            encoder_hidden_states=encoder_hidden_states,
            conditioning=static_conditioning,
            image_rotary_emb=image_rotary_emb,
            dt=static_dt,
        )
    return graph, static_latents, static_conditioning, static_dt, static_x


@torch.inference_mode()
def main(
    model_id: str,
    prompt: str,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
    compile_mode: str,
    output_file: str,
    cache_dir: Optional[str],
    enable_cudagraph: bool,
    enable_prompt_length_bucketing: bool,
    enable_profiling: bool,
    working_dir: str,
    seed: int,
):
    """Run the timed FLUX denoising loop and save the decoded image.

    The function loads all model components, fuses projections in the transformer,
    prepares (and, when a context-parallel mesh is configured, shards) the inputs,
    warms the model up, times ``num_inference_steps`` denoising steps either eagerly
    or through a captured CUDA graph, and finally decodes and saves the image on the
    process that passes the rank check.

    Args:
        model_id: Identifier passed to every ``from_pretrained`` call.
        prompt: Text prompt encoded by both text encoders.
        height: Requested image height in pixels.
        width: Requested image width in pixels.
        num_inference_steps: Number of denoising steps; also the key into the
            precomputed timestep embeddings.
        guidance_scale: Guidance scale; formatted with one decimal to look up the
            precomputed guidance embedding.
        compile_mode: ``torch.compile`` mode for the transformer, or ``"none"``.
        output_file: Path the final image is saved to.
        cache_dir: Optional cache directory forwarded to ``from_pretrained``.
        enable_cudagraph: Replay a manually captured CUDA graph instead of calling
            the transformer directly in the timed loop.
        enable_prompt_length_bucketing: Forwarded to ``prepare_t5_embeddings``.
        enable_profiling: Wrap the timed loop in the profiler and dump a trace.
        working_dir: Directory holding the precomputed embedding caches.
        seed: Seed for the latent noise generator.
    """
    device = "cuda"
    dtype = torch.bfloat16

    # Load the model components
    transformer = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype
    )
    text_encoder = CLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=dtype
    )
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
    tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype)
    image_processor = VaeImageProcessor(
        vae_scale_factor=SPATIAL_COMPRESSION_RATIO * PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR
    )

    # In-place fusions on the transformer, done before moving it to the device.
    fuse_qkv_(transformer)
    fuse_adaln_linear_(transformer)

    # The list is discarded; the comprehension is used only for its side effect of
    # moving each module to the device.
    [x.to(device) for x in (transformer, text_encoder, text_encoder_2, vae)]
    vae.to(memory_format=torch.channels_last)

    if compile_mode != "none":
        transformer = torch.compile(transformer, mode=compile_mode, fullgraph=True, dynamic=True)
        # text_encoder = torch.compile(text_encoder, mode="default", fullgraph=True, dynamic=True)
        # text_encoder_2 = torch.compile(text_encoder_2, mode="default", fullgraph=True, dynamic=True)
        # We don't compile the VAE due to the implementation calling into non-traceable code paths
        # vae.decode = torch.compile(vae.decode, mode="default", fullgraph=True, dynamic=True)

    # Latent, text, guidance and timestep conditioning preparation
    batch_size = 1
    patch_size = transformer.config.patch_size
    latent_height = height // (SPATIAL_COMPRESSION_RATIO * patch_size) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR
    latent_width = width // (SPATIAL_COMPRESSION_RATIO * patch_size) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR

    generator = torch.Generator(device=device).manual_seed(seed)

    guidance_embeds, timestep_embeds = precompute_embeds(transformer, device, dtype, working_dir)

    # Initial noise: one token per latent position.
    latents = torch.randn(
        (batch_size, latent_height * latent_width, transformer.config.in_channels),
        dtype=dtype,
        device=device,
        generator=generator,
    )
    pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, device, dtype)
    encoder_hidden_states, num_text_tokens = prepare_t5_embeddings(
        text_encoder_2, tokenizer_2, prompt, device, dtype, T5_SEQUENCE_LENGTH, enable_prompt_length_bucketing
    )

    # <precompute>
    # Everything here is constant across denoising steps, so it is computed once
    # outside the timed loop.
    guidance_conditioning = guidance_embeds[f"{guidance_scale:.1f}"]
    sigmas, timestep_conditioning = timestep_embeds[num_inference_steps]
    pooled_projections = transformer.time_text_embed.text_embedder(pooled_projections)
    encoder_hidden_states = transformer.context_embedder(encoder_hidden_states)
    # </precompute>

    img_ids = prepare_latent_image_ids(latent_height, latent_width, device=device, dtype=dtype)
    txt_ids = torch.zeros(num_text_tokens, 3).to(device=device, dtype=dtype)

    if cp_options.mesh is not None:
        # Each process keeps only its own slice of the sequence dimension.
        # Note: clone seems to be a must here otherwise there is a recompilation related to storage offsets (which
        # tells you to use torch._dynamo.decorators.mark_unbacked) /shrug
        img_ids = EquipartitionSharder.shard(img_ids, dim=0, mesh=cp_options._flattened_mesh).clone()
        txt_ids = EquipartitionSharder.shard(txt_ids, dim=0, mesh=cp_options._flattened_mesh).clone()
        latents = EquipartitionSharder.shard(latents, dim=1, mesh=cp_options._flattened_mesh).clone()
        encoder_hidden_states = EquipartitionSharder.shard(encoder_hidden_states, dim=1, mesh=cp_options._flattened_mesh).clone()

    # Text ids come first, then image ids, matching the concatenation order here.
    ids = torch.cat([txt_ids, img_ids], dim=0).float()
    image_rotary_emb = transformer.pos_embed(ids)
    # Per-step sigma increments; dt[i] is passed to the transformer at step i.
    dt = sigmas[1:] - sigmas[:-1]

    print("Warming up the model")
    # Two untimed forward passes; their outputs are discarded.
    for _ in range(2):
        conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections)
        _ = transformer(
            hidden_states=latents,
            encoder_hidden_states=encoder_hidden_states,
            conditioning=conditioning,
            image_rotary_emb=image_rotary_emb,
            dt=dt[0],
        )
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # With profiling disabled this is a nullcontext, so `ctx` below is None; it is
    # only dereferenced when enable_profiling is True.
    context = (
        profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True)
        if enable_profiling
        else contextlib.nullcontext()
    )

    if not enable_cudagraph:
        with context as ctx:
            start_event.record()
            for i in range(num_inference_steps):
                conditioning = pointwise_add3_silu(timestep_conditioning[i, :], guidance_conditioning, pooled_projections)
                # The transformer's return value becomes the latents for the next step.
                latents = transformer(
                    hidden_states=latents,
                    encoder_hidden_states=encoder_hidden_states,
                    conditioning=conditioning,
                    image_rotary_emb=image_rotary_emb,
                    dt=dt[i],
                )
            end_event.record()
            torch.cuda.synchronize()
    else:
        # Graph capture happens before start_event is recorded, so it is not timed.
        conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections)
        graph, static_latents, static_conditioning, static_dt, static_x = capture_cudagraph(
            transformer,
            latents,
            encoder_hidden_states,
            conditioning,
            image_rotary_emb,
            dt[0],
        )

        with context as ctx:
            start_event.record()
            # Seed the graph's output buffer with the initial latents so the first
            # copy below feeds them back in as the graph's input.
            static_x.copy_(latents)
            for i in range(num_inference_steps):
                conditioning = pointwise_add3_silu(timestep_conditioning[i, :], guidance_conditioning, pooled_projections)
                # Refresh the static input buffers (previous output, this step's
                # conditioning and dt) before replaying the captured graph.
                torch._foreach_copy_(
                    (static_latents, static_conditioning, static_dt),
                    (static_x, conditioning, dt[i]),
                    non_blocking=True,
                )
                graph.replay()
            end_event.record()
            torch.cuda.synchronize()
        latents = static_x

    # The elapsed time is divided by 1000 and reported with an "s" suffix below.
    total_time = start_event.elapsed_time(end_event) / 1000.0

    if cp_options.mesh is not None:
        # Gather the sequence shards back so the full latents are available for decoding.
        latents = EquipartitionSharder.unshard(latents, dim=1, mesh=cp_options._flattened_mesh)

    # Reporting, decoding and saving happen only without a mesh or on local rank 0
    # of the flattened mesh.
    if cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0:
        print(f"time: {total_time:.2f}s")

        if enable_profiling:
            print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
            ctx.export_chrome_trace("dump_benchmark_flux.json")

        # Split each token's features into (channels, 2, 2), then interleave the two
        # trailing factors with the latent grid to form a channels-first 4D tensor.
        latents = latents.reshape(batch_size, latent_height, latent_width, -1, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        latents = latents.flatten(4, 5).flatten(2, 3)
        latents = latents.to(memory_format=torch.channels_last)

        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        image = image_processor.postprocess(image, output_type="pil")[0]
        image.save(output_file)


class EquipartitionSharder:
    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        assert tensor.size()[dim] % mesh.size() == 0

        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
        
        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
        return tensor


# Reference:
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
def _wait_tensor(tensor):
    if isinstance(tensor, funcol.AsyncCollectiveTensor):
        tensor = tensor.wait()
    return tensor


def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
    """Run an all-to-all over ``group`` on ``x`` and return a result with ``x``'s shape.

    The tensor is flattened before the collective and reshaped afterwards; split
    sizes are passed as ``None``. The result is waited on before returning.
    """
    shape = x.shape
    # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
    # to benchmark triton codegen fails somewhere:
    # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
    # ValueError: Tensors must be contiguous
    x = x.flatten()
    x = funcol.all_to_all_single(x, None, None, group)
    x = x.reshape(shape)
    # Resolve the async collective result, if it is one (see _wait_tensor).
    x = _wait_tensor(x)
    return x


def _templated_ring_attention(query, key, value):
    """Attend the local ``query`` against the key/value shards of every ring rank.

    All ranks' key/value tensors are all-gathered once up front. Attention is then
    computed against one rank's shard at a time (starting with the local one) and
    the partial outputs are merged using their log-sum-exp values.

    Returns:
        ``(out, lse)``. When the ring degree is 1, whatever
        ``cp_options.attention_op`` returns is passed through directly.
    """
    ring_mesh = cp_options.mesh["ring"]
    rank = cp_options._ring_local_rank
    world_size = cp_options.ring_degree

    if world_size == 1:
        return cp_options.attention_op(query, key, value)

    next_rank = (rank + 1) % world_size
    prev_out = prev_lse = None

    # Pack key and value into one flat buffer so a single all-gather fetches both;
    # chunk i of the gathered buffer then holds rank i's packed key/value.
    kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
    kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
    kv_buffer = kv_buffer.chunk(world_size)

    for i in range(world_size):
        # Iteration 0 uses the local key/value; later iterations walk the other
        # ranks in order rank + 1, rank + 2, ... (modulo the ring size).
        if i > 0:
            kv = kv_buffer[next_rank]
            # Unpack: the first key.numel() elements are the key, the rest the value.
            key = kv[:key.numel()].reshape_as(key)
            value = kv[key.numel():].reshape_as(value)
            next_rank = (next_rank + 1) % world_size

        out, lse = cp_options.attention_op(query, key, value)

        if cp_options.convert_to_fp32:
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)

        # Add a trailing dim so lse broadcasts against out in the merge below.
        lse = lse.unsqueeze(-1)
        if prev_out is not None:
            # Merge the running result with this block's result:
            #   out = (1 - s) * prev_out + s * out, with s = sigmoid(lse - prev_lse)
            #   lse = log(exp(prev_lse) + exp(lse))
            out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
            lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
        prev_out = out
        prev_lse = lse

    # Cast back to the query dtype and drop the broadcast dim added above.
    out = out.to(query.dtype)
    lse = lse.squeeze(-1)
    return out, lse


def _templated_ulysses_attention(query, key, value, *, return_lse: bool = False):
    """Attention that trades a sequence split for a head split via all-to-all.

    Inputs are unpacked as ``(B, S_LOCAL, H, D)``. An all-to-all over the
    ``"ulysses"`` mesh group regroups them to ``(B, world_size * S_LOCAL,
    H // world_size, D)``, attention runs on that layout, and a second all-to-all
    restores ``(B, S_LOCAL, H, D)`` for the output.

    Args:
        query, key, value: Local attention inputs.
        return_lse: When True, the log-sum-exp is redistributed as well and
            returned; otherwise ``None`` is returned in its place.

    Returns:
        ``(out, lse)``. NOTE(review): when the ulysses degree is 1 the result of
        ``cp_options.attention_op`` is returned directly, regardless of
        ``return_lse`` -- confirm callers handle both forms.
    """
    ulysses_mesh = cp_options.mesh["ulysses"]
    world_size = cp_options.ulysses_degree
    group = ulysses_mesh.get_group()

    if world_size == 1:
        return cp_options.attention_op(query, key, value)

    B, S_LOCAL, H, D = query.shape
    H_LOCAL = H // world_size
    # Split the heads into world_size groups and move that group axis to the front:
    # (B, S_LOCAL, H, D) -> (world_size, S_LOCAL, B, H_LOCAL, D).
    query, key, value = (
        x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).clone()
        for x in (query, key, value)
    )
    query, key, value = (
        _all_to_all_single(x, group)
        for x in (query, key, value)
    )
    # Merge the leading axis with the sequence axis and restore batch-first order:
    # (world_size, S_LOCAL, B, H_LOCAL, D) -> (B, world_size * S_LOCAL, H_LOCAL, D).
    query, key, value = (
        x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
        for x in (query, key, value)
    )
    out, lse = cp_options.attention_op(query, key, value)
    # Reverse the exchange for the output:
    # (B, world_size * S_LOCAL, H_LOCAL, D) -> (world_size, H_LOCAL, B, S_LOCAL, D).
    out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
    out = _all_to_all_single(out, group)
    # (world_size, H_LOCAL, B, S_LOCAL, D) -> (B, S_LOCAL, H, D).
    out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
    if return_lse:
        # Same redistribution as for `out`, without the trailing head_dim axis.
        lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
        lse = _all_to_all_single(lse, group)
        lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
    else:
        lse = None
    return out, lse


# TODO: currently produces incorrect results (for example, with CP=4, ring=2, ulysses=2, half of the output is the
# expected image and the other half is something completely different)
# def _templated_unified_attention(query, key, value):
#     ulysses_mesh = cp_options.mesh["ulysses"]
#     ulysses_size = ulysses_mesh.size()
#     ulysses_group = ulysses_mesh.get_group()

#     B, S_LOCAL, H, D = query.shape
#     H_LOCAL = H // ulysses_size
#     query, key, value = (
#         x.reshape(B, S_LOCAL, ulysses_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
#         for x in (query, key, value)
#     )
#     query, key, value = (
#         wait_tensor(funcol.all_to_all_single(x, None, None, group=ulysses_group))
#         for x in (query, key, value)
#     )
#     query, key, value = (
#         x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
#         for x in (query, key, value)
#     )
#     out, lse = _templated_ring_attention(query, key, value)
#     out = out.reshape(B, ulysses_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
#     lse = lse.reshape(B, ulysses_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
#     out = wait_tensor(funcol.all_to_all_single(out, None, None, group=ulysses_group))
#     lse = wait_tensor(funcol.all_to_all_single(lse, None, None, group=ulysses_group))
#     out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
#     lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
#     return out, lse


# For fullgraph=True tracing to be compatible
@torch.library.custom_op("flash_attn_3::_flash_attn_forward_original", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Custom-op wrapper around ``flash_attn_3_func`` returning ``(out, lse)``.

    The log-sum-exp has its last two dimensions swapped before being returned.
    """
    attn_out, softmax_lse = flash_attn_3_func(query, key, value)
    return attn_out, softmax_lse.permute(0, 2, 1)


@torch.library.register_fake("flash_attn_3::_flash_attn_forward_original")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation: empty ``out`` shaped like ``query`` and ``lse`` of ``query.shape[:3]``."""
    batch, seq, heads, _head_dim = query.shape
    fake_out = torch.empty_like(query)
    fake_lse = query.new_empty((batch, seq, heads))
    return fake_out, fake_lse


@torch.library.custom_op("flash_attn_3::_flash_attn_forward_hf", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_hf(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Custom-op wrapper around the hub kernel's non-causal ``flash_attn_func``.

    Only the attention output is returned; the second value from the kernel is dropped.
    """
    attn_out, _unused_lse = flash_attn_3_hf.flash_attn_func(query, key, value, causal=False)
    return attn_out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward_hf")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Fake implementation: an empty output laid out like ``query``."""
    fake_out = torch.empty_like(query)
    return fake_out


def _attention_torch_cudnn(query, key, value):
    """Run the cuDNN scaled-dot-product attention op and return ``(out, lse)``.

    Dimensions 1 and 2 of the inputs are swapped before the op and swapped back on
    both results afterwards.
    """
    q, k, v = (t.transpose(1, 2).contiguous() for t in (query, key, value))
    # The op returns nine values; only the output and the log-sum-exp are used.
    out, lse, _, _, _, _, _, _, _ = torch.ops.aten._scaled_dot_product_cudnn_attention(
        query=q,
        key=k,
        value=v,
        attn_bias=None,
        compute_log_sumexp=True,
    )
    return out.transpose(1, 2).contiguous(), lse.transpose(1, 2).contiguous()


def _attention_flash_attn_2(query, key, value):
    """Call ``flash_attn_func`` and return ``(out, lse)`` with lse's last two dims swapped."""
    attn_out, softmax_lse, _ = flash_attn_func(query, key, value, return_attn_probs=True)
    return attn_out, softmax_lse.permute(0, 2, 1)


def _attention_flash_attn_3_original(query, key, value):
    """Dispatch to the ``_wrapped_flash_attn_3_original`` custom op and return its result."""
    return _wrapped_flash_attn_3_original(query, key, value)


def _attention_flash_attn_3_hf(query, key, value):
    """Dispatch to the ``_wrapped_flash_attn_3_hf`` custom op and return its result."""
    return _wrapped_flash_attn_3_hf(query, key, value)


def _download_hf_flash_attn_3():
    """Fetch the ``kernels-community/flash-attn3`` kernel and bind it to the module global."""
    global flash_attn_3_hf
    kernel_module = get_kernel("kernels-community/flash-attn3")
    flash_attn_3_hf = kernel_module


def get_args():
    """Define the benchmark's command-line options and parse them from ``sys.argv``."""
    default_model_id = "black-forest-labs/FLUX.1-dev"
    default_prompt = "The King of Hearts card transforms into a 3D hologram that appears to be made of cosmic energy. As the King emerges, stars and galaxies swirl around him, creating a sense of traveling through the universe. The King's attire is adorned with celestial patterns, and his crown is a glowing star cluster. The hologram floats in front of you, with the background shifting through different cosmic scenes, from nebulae to black holes. Atmosphere: Perfect for space-themed events, science fiction conventions, or futuristic tech expos."

    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Generation settings
    add("--model_id", type=str, default=default_model_id)
    add("--prompt", type=str, default=default_prompt)
    add("--height", type=int, default=1024)
    add("--width", type=int, default=1024)
    add("--num_inference_steps", type=int, default=28)
    add("--guidance_scale", type=float, default=4.0)
    add(
        "--compile_mode",
        type=str,
        default="none",
        choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
    )
    add("--output_file", type=str, default="output.png")
    add("--cache_dir", type=str, default=None)

    # Execution toggles and attention backend
    add("--enable_fp32_rope", action="store_true")
    add("--enable_cudagraph", action="store_true")
    add("--enable_prompt_length_bucketing", action="store_true")
    add("--attention_provider", type=str, default="cudnn", choices=["cudnn", "fa2", "fa3", "fa3_original"])

    # Context-parallel layout
    add("--ring_degree", type=int, default=1)
    add("--ulysses_degree", type=int, default=1)

    # Diagnostics and miscellaneous
    add("--enable_profiling", action="store_true")
    add("--disable_tf32", action="store_true")
    add("--disable_flags", action="store_true")
    add("--working_dir", type=str, default="/tmp/flux_precomputation")
    add("--seed", type=int, default=31337)

    return parser.parse_args()


def setup_config(args):
    """Apply the parsed CLI options to module-level and torch-wide configuration.

    Seeds torch, optionally switches RoPE precision, selects the attention
    implementation, toggles TF32 and the extra flags, and validates the requested
    CUDA-graph/compile combination and step count.

    Raises:
        ValueError: If CUDA graphs are requested with an unsupported compile mode,
            or ``num_inference_steps`` is outside the supported range.
    """
    global ROPE_PRECISION, ATTENTION_OP

    torch.manual_seed(args.seed)

    if args.enable_fp32_rope:
        ROPE_PRECISION = torch.float32

    cudagraph_compatible_modes = ("none", "default", "max-autotune-no-cudagraphs")
    if args.enable_cudagraph and args.compile_mode not in cudagraph_compatible_modes:
        raise ValueError(
            "Only compiled modes 'none', 'default', and 'max-autotune-no-cudagraphs' are supported with CUDAGraphs."
        )

    provider = args.attention_provider
    if provider == "cudnn":
        ATTENTION_OP = _attention_torch_cudnn
    elif provider == "fa2":
        ATTENTION_OP = _attention_flash_attn_2
    elif provider == "fa3":
        # The hub kernel has to be fetched before the wrapper can be used.
        _download_hf_flash_attn_3()
        ATTENTION_OP = _attention_flash_attn_3_hf
    elif provider == "fa3_original":
        ATTENTION_OP = _attention_flash_attn_3_original
    else:
        assert False

    if args.enable_profiling and args.enable_cudagraph:
        torch.profiler._utils._init_for_cuda_graphs()

    if not args.disable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    if not args.disable_flags:
        apply_flags()

    if not (MIN_INFERENCE_STEPS <= args.num_inference_steps <= MAX_INFERENCE_STEPS):
        raise ValueError(f"`num_inference_steps` must be equal or between {MIN_INFERENCE_STEPS} and {MAX_INFERENCE_STEPS}.")


def setup_distributed(ring_degree: int, ulysses_degree: int, compile_mode: str):
    """Initialize the process group and device mesh and select the parallel attention op.

    Populates ``cp_options`` with the mesh, the sub-meshes, local ranks and the
    currently selected base attention op, then rebinds the module-level
    ``ATTENTION_OP`` to the ring or ulysses templated implementation.

    Args:
        ring_degree: Size of the ``"ring"`` mesh dimension.
        ulysses_degree: Size of the ``"ulysses"`` mesh dimension.
        compile_mode: The ``torch.compile`` mode in use; anything other than
            ``"none"`` turns on ``torch._dynamo.config.suppress_errors``.

    Raises:
        ValueError: If ``ring_degree * ulysses_degree`` does not equal the world
            size, or if both degrees are greater than 1 (unified attention is not
            supported yet).
    """
    global ATTENTION_OP

    dist.init_process_group("nccl")
    # Map the global rank onto a local device index so that multi-node runs, where
    # the global rank can exceed the number of GPUs on a node, pick a valid device.
    local_device_index = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(torch.device("cuda", local_device_index))

    if ring_degree * ulysses_degree != dist.get_world_size():
        raise ValueError(f"`ring_degree * ulysses_degree` must equal the world size {dist.get_world_size()}.")

    mesh_names = ["ring", "ulysses"]
    mesh_dims = [ring_degree, ulysses_degree]
    mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names)

    cp_options.ring_degree = ring_degree
    cp_options.ulysses_degree = ulysses_degree
    cp_options.mesh = mesh
    cp_options.convert_to_fp32 = True
    # Keep the base (single-device) attention op; the templated functions below
    # call it for each block of the distributed computation.
    cp_options.attention_op = ATTENTION_OP
    cp_options._flattened_mesh = mesh._flatten()
    cp_options._ring_mesh = mesh["ring"]
    cp_options._ulysses_mesh = mesh["ulysses"]
    cp_options._ring_local_rank = cp_options._ring_mesh.get_local_rank()
    cp_options._ulysses_local_rank = cp_options._ulysses_mesh.get_local_rank()

    if ring_degree > 1 and ulysses_degree > 1:
        raise ValueError("The current implementation is incorrect for unified attention and needs to be fixed.")
        # cp_options.mode = "unified"
        # ATTENTION_OP = _templated_unified_attention
    elif ulysses_degree > 1:
        cp_options.mode = "ulysses"
        ATTENTION_OP = _templated_ulysses_attention
    else:
        cp_options.mode = "ring"
        ATTENTION_OP = _templated_ring_attention

    if compile_mode != "none":
        torch._dynamo.config.suppress_errors = True

    torch._inductor.config.reorder_for_compute_comm_overlap = True


if __name__ == "__main__":
    # Script entry point: parse options, configure, set up distributed state and run.
    args = get_args()

    # Names of the parsed options that are forwarded to main() as keyword arguments.
    main_arg_names = (
        "model_id",
        "prompt",
        "height",
        "width",
        "num_inference_steps",
        "guidance_scale",
        "compile_mode",
        "output_file",
        "cache_dir",
        "enable_cudagraph",
        "enable_prompt_length_bucketing",
        "enable_profiling",
        "working_dir",
        "seed",
    )

    try:
        setup_config(args)
        setup_distributed(args.ring_degree, args.ulysses_degree, args.compile_mode)
        main(**{name: getattr(args, name) for name in main_arg_names})
    except Exception as e:
        print(f"An error occurred: {e}")
        # Drop into the distributed debugger only if the process group came up.
        if dist.is_initialized():
            torch.distributed.breakpoint()
        raise
    finally:
        # Always tear the process group down, whether or not the run succeeded.
        if dist.is_initialized():
            dist.destroy_process_group()
A100 H100
image image

TODO: link to blog post

Explanation

Each model should define a _cp_plan attribute that contains information on how to shard/gather tensors at different stages of the forward. Let's try to understand with an example using QwenImage:

_cp_plan = {
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
    },
    "pos_embed": {
        0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
        1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
    },
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}

The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be split/gathered according to this at the respective module level. Here, the following happens:

  • "": we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
  • "pos_embed": we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs), we can individually specify how they should be split
  • "proj_out": before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear layer forward has run).

ContextParallelInput: specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to

ContextParallelOutput: specifies how to gather the input tensor in the post-forward hook in the layer it is attached to

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w added the roadmap Add to current release roadmap label Jul 16, 2025
Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Just a few questions and nits.



@contextlib.contextmanager
def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm wondering if we need this additional context manager. I think the name implies that it's parallelizing the components, when it's really a validation step.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this function is what sets the dispatcher to perform context parallel templated attention instead of following the normal attention path in non-CP case

_AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config

Without this, we would have to perform the assignment in the pre-forward hook of the model. But that has compatibility issues with torch dynamo (tracing fails with a setattr/getattr-related error). It was our previous approach in this commit, but it had many compatibility issues for distributed training. Instead, explicitly doing this outside the forward is ideal for setting up all the required information about parallelism.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm but setting parallel_config in this way leads to this issue no?

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work as always @a-r-r-o-w. Good to merge once tests are passing.

@stevhliu stevhliu mentioned this pull request Sep 15, 2025
@a-r-r-o-w
Copy link
Contributor Author

Thanks @DN6! I'm AFK from my personal laptop so can't make changes for another 2-3 days. Sorry about the delay!

Co-authored-by: Aryan <aryan@huggingface.co>
@a-r-r-o-w
Copy link
Contributor Author

@DN6 I'm unable to push any changes to this branch since it's on the official repo instead of my personal fork. I think if you added an entry to _toctree.yaml, it should fix the failing test.

def _wrapped_flash_attn_3_original(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will follow suit and make it work with FA3 Hub as well.

@sayakpaul
Copy link
Member

Just pushed an entry to the TOC. Cc: @stevhliu for now, I have added it to "Inference optimization" section. But feel free to change that in your documentation PR.

@a-r-r-o-w sorry for the delay here. We should be able to merge this now. Thanks for setting the foundations here :)

@sayakpaul sayakpaul merged commit dcb6dd9 into main Sep 24, 2025
33 of 34 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Diffusers Roadmap 0.36 Sep 24, 2025
@sayakpaul sayakpaul deleted the attn-dispatcher-cp-and-training branch September 24, 2025 13:33
sayakpaul pushed a commit that referenced this pull request Oct 15, 2025
#12206)

* Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op

- Add hasattr() check for torch.library.custom_op and register_fake
- These functions were added in PyTorch 2.4, causing import failures in 2.3.1
- Both decorators and functions are now properly guarded with version checks
- Maintains backward compatibility while preserving functionality

Fixes #12195

* Use dummy decorators approach for PyTorch version compatibility

- Replace hasattr check with version string comparison
- Add no-op decorator functions for PyTorch < 2.4.0
- Follows pattern from #11941 as suggested by reviewer
- Maintains cleaner code structure without indentation changes

* Update src/diffusers/models/attention_dispatch.py

Update all the decorator usages

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Move version check to top of file and use private naming as requested

* Apply style fixes

---------

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

roadmap Add to current release roadmap

Projects

Development

Successfully merging this pull request may close these issues.

4 participants