KEMBAR78
[core] Pyramid Attention Broadcast by a-r-r-o-w · Pull Request #9562 · huggingface/diffusers · GitHub
Skip to content

Conversation

@a-r-r-o-w
Copy link
Contributor

@a-r-r-o-w a-r-r-o-w commented Oct 1, 2024

What does this PR do?

Adds support for Pyramid Attention Broadcast.

model_id cache_method time model_memory model_max_memory_reserved inference_memory inference_max_memory_reserved
latte none 31.742 12.88 13.189 12.888 14.344
latte fastercache 24.288 12.882 13.17 12.891 22.199
latte pyramid_attention_broadcast 27.531 12.882 13.17 12.891 20.23
cogvideox-1.0 none 245.939 19.66 19.678 19.671 24.426
cogvideox-1.0 fastercache 159.031 19.661 19.678 19.672 40.721
cogvideox-1.0 pyramid_attention_broadcast 184.013 19.661 19.678 19.672 32.57
mochi none 437.327 28.411 28.65 28.421 36.062
mochi fastercache 358.871 28.411 28.648 28.422 36.062
mochi pyramid_attention_broadcast 324.051 28.411 28.65 28.421 52.088
hunyuan_video none 72.628 38.577 38.672 38.587 41.141
hunyuan_video pyramid_attention_broadcast 63.892 38.578 38.672 38.587 44.785
flux none 16.802 31.44 31.451 31.448 32.023
flux pyramid_attention_broadcast 13.719 31.439 31.451 31.447 32.832

Usage

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe._current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
Benchmark code
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    MochiPipeline,
)
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch

if branch.name in ["pyramid-attention-broadcast", "pyramid-attention-rewrite-2"]:
    from diffusers import (
        apply_pyramid_attention_broadcast,
        PyramidAttentionBroadcastConfig,
    )
elif branch.name in ["fastercache"]:
    from diffusers.pipelines.fastercache_utils import apply_fastercache, FasterCacheConfig


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": (
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs) -> None:
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "a cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None

    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_allegro_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            cross_attention_block_skip_range=6,
            spatial_attention_timestep_skip_range=(100, 700),
            cross_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks"],
            cross_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 681),
            low_frequency_weight_update_timestep_range=(99, 641),
            high_frequency_weight_update_timestep_range=(-1, 301),
            spatial_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "none":
        return None


def prepare_cogvideox_1_0_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 681),
            low_frequency_weight_update_timestep_range=(99, 641),
            high_frequency_weight_update_timestep_range=(-1, 301),
            spatial_attention_block_identifiers=["transformer_blocks"],
            attention_weight_callback=lambda _: 0.3,
            tensor_format="BFCHW",
        )
    elif cache_method == "none":
        return None


def prepare_flux_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=4,
            spatial_attention_timestep_skip_range=(-1, 961),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            tensor_format="BCHW",
        )
    elif cache_method == "none":
        return None


def prepare_hunyuan_video_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=4,
            spatial_attention_timestep_skip_range=(99, 941),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            tensor_format="BCFHW",
        )
    elif cache_method == "none":
        return None


def prepare_latte_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            temporal_attention_block_skip_range=3,
            cross_attention_block_skip_range=6,
            spatial_attention_timestep_skip_range=(100, 700),
            temporal_attention_timestep_skip_range=(100, 800),
            cross_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks"],
            temporal_attention_block_identifiers=["temporal_transformer_blocks"],
            cross_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            temporal_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 681),
            temporal_attention_timestep_skip_range=(-1, 681),
            low_frequency_weight_update_timestep_range=(99, 641),
            high_frequency_weight_update_timestep_range=(-1, 301),
            spatial_attention_block_identifiers=["transformer_blocks.*attn1"],
            temporal_attention_block_identifiers=["temporal_transformer_blocks"],
        )
    elif cache_method == "none":
        return None


def prepare_mochi_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(400, 987),
            spatial_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 981),
            low_frequency_weight_update_timestep_range=(301, 961),
            high_frequency_weight_update_timestep_range=(-1, 851),
            unconditional_batch_skip_range=4,
            unconditional_batch_timestep_skip_range=(-1, 975),
            spatial_attention_block_identifiers=["transformer_blocks"],
            attention_weight_callback=lambda _: 0.6,
        )
    elif cache_method == "none":
        return None


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


MODEL_MAPPING = {
    "allegro": {
        "prepare": prepare_allegro,
        "config": prepare_allegro_config,
        "decode": decode_allegro,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "config": prepare_cogvideox_1_0_config,
        "decode": decode_cogvideox_1_0,
    },
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "config": prepare_hunyuan_video_config,
        "decode": decode_hunyuan_video,
    },
    "latte": {
        "prepare": prepare_latte,
        "config": prepare_latte_config,
        "decode": decode_latte,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "config": prepare_mochi_config,
        "decode": decode_mochi,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator("cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str, compile: bool = False):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype, compile=compile)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method)
        
        if cache_method == "pyramid_attention_broadcast":
            config.current_timestep_callback = lambda: pipe._current_timestep
        
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe.transformer, config)
        elif cache_method == "fastercache":
            apply_fastercache(pipe, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        for _ in range(num_warmups):
            run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "compile": compile,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "compile": compile,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype, args.compile)
Flux
Mochi
mochi---dtype-bf16---cache_method-none---compile-False.mp4
mochi---dtype-bf16---cache_method-pyramid_attention_broadcast---compile-False.mp4
Hunyuan Video
hunyuan_video---dtype-bf16---cache_method-none---compile-False.mp4
hunyuan_video---dtype-bf16---cache_method-pyramid_attention_broadcast---compile-False.mp4
CogVideoX-2b T2V
cogvideox_2b.mp4
cogvideox_pab_2b.mp4
CogVideoX-5b T2V
cogvideox_5b.mp4
cogvideox_pab_5b.mp4
CogVideoX-5b I2V
cogvideox_5b_i2v.mp4
cogvideox_pab_5b_i2v.mp4
Latte
latte---dtype-fp16---cache_method-none---compile-False.mp4
latte---dtype-fp16---cache_method-pyramid_attention_broadcast---compile-False.mp4

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@yiyixuxu @sayakpaul

@oahzxl for PAB, @zRzRzRzRzRzRzR for CogVideoX related changes, @maxin-cn for Latte related changes

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

a-r-r-o-w and others added 4 commits October 3, 2024 08:34
Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
@a-r-r-o-w
Copy link
Contributor Author

I can't seem to replicate the results for PAB on CogVideoX-5b T2V or I2V. This is what I get:

CogVideoX-5b T2V
cogvideox_5b.mp4
cogvideox_pab_5b.mp4
CogVideoX-5b I2V
cogvideox_5b_i2v.mp4
cogvideox_pab_5b_i2v.mp4

@oahzxl Would you be able to give this a review when free? I'm unable to figure out what I'm doing wrong that's causing poor results in these cases. Thank you!

@oahzxl
Copy link
Contributor

oahzxl commented Oct 3, 2024

sure, thanks for your code! i guess it may be related with pos embed or encoder concat of 5b model. i can have a look at the code soon!

@oahzxl
Copy link
Contributor

oahzxl commented Oct 3, 2024

hi, i have done some experiments and here are my conclusions:

i first try a simple implementation

the org attention is:

        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

for simplicty, i just add pab's logic here:

        # in init
        self.attn_count = 0
        self.last_attn = None
        
        ...

        # in forward
        if (10 < self.attn_count < 45) and (self.attn_count % 2 != 0):
            attn_hidden_states, attn_encoder_hidden_states = self.last_attn
        else:
            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )
            self.last_attn = attn_hidden_states, attn_encoder_hidden_states

this should be exactly the same as the logic in pab processor.

then i find pab will be numerically unstable with fp16 for cogvideox-5b. so i change to bfloat16, and it works!

output-bf16-new.mp4

->> so the first problem is float16!

  1. then i test pab processor

but fail even if i use bfloat16

i find even i set spatial_attn_skip_range to 1 (which means no broadcast), it will also generate random noise.

->> so i think the second problem is in processor, but no clue for now

hope it can help you!

@a-r-r-o-w
Copy link
Contributor Author

Thank you so much for the investigation! I think I found the bug. This line checks if the processor signature supports a specific keyword arguments before passing them. In this case, since we replace the attention processor with PyramidAttentionBroadcastAttentionProcessor, which only has args and kwargs, it drops the image_rotary_emb kwargs necessary for generation. So, RoPE embeddings are not passed at all causing bad video.

@oahzxl
Copy link
Contributor

oahzxl commented Oct 3, 2024

glad i can help :) !

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review October 3, 2024 19:23
@a-r-r-o-w a-r-r-o-w requested review from stevhliu and yiyixuxu October 4, 2024 06:52
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! The numbers are extremely promising!

I want to brainstorm a bit about how we should incorporate PAB design-wise.

IIUC, PAB can be applied at a model-level and it rejigs the attention computation of the concerned model. IMO, this is a bit similar to how we do QKV fusion. Entry point to QKV fusion can either be from a pipeline or from a model (if there's support).

If this is correct, then I wonder if supporting PAB through a Mixin class makes the most elegant design as opposed to enabling it via set_attn_processor().

Or are we relying on a Mixin because we need to depend on pipeline-level attributes which may not be suitable for a model?

Copy link
Member

@stevhliu stevhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, thanks for adding docs for this method! Same comments apply to latte.md :)

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I think it will be a nice feature, but not very sure about the design.

  • the Attention processor wrapper is not aligned with our design, we should just make custom attention processors (even though we might have to make one for each model that has non default attention processor)
  • For another thing, I think this would also won't be compatible with torch.compile, no? I think we should consider a design similar to https://github.com/huggingface/diffusers/pull/9524/files. We can probably store the attention output cache (a dict) on pipeline and pass as cross_attention_kwargs on each iterations (just putting the ideas here. not something I have already carefully thought through, so it might not work. feel free to brainstorm)

return False

should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
return not should_compute_attention
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be misunderstanding here, but shouldn't we just return should_compute_attention directly here? Why use return not should_compute_attention?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is used to determine whether attention computation should be skipped or not. So, if skip_callback were to return True, it means that should_compute_attention had to have been False, and vice versa, so this is correct.

@a-r-r-o-w
Copy link
Contributor Author

@DN6 Addressed the review comments. Could you give this another look?

@a-r-r-o-w
Copy link
Contributor Author

a-r-r-o-w commented Jan 21, 2025

We need to make some more updates here before merging to address the case of using multiple hooks at once. The current implementation does not really work, if say both FP8 and PAB are enabled together. I will take it up in this PR before merging after layerwise upcasting is merged: #10347

This has already been addressed in group offloading PR but that will take some more time to complete: #10503

@a-r-r-o-w
Copy link
Contributor Author

With the latest changes, it is now possible to use multiple forward-modifying hooks now. Here's an example with FP8 layerwise-upcasting and PAB:

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug

set_verbosity_debug()

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(150, 700),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
pipe.transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
output.mp4

Comment on lines 136 to 149
forward = self._module_ref.forward

fn_ref = FunctionReference()
fn_ref.pre_forward = hook.pre_forward
fn_ref.post_forward = hook.post_forward
fn_ref.old_forward = forward

if hasattr(hook, "new_forward"):
fn_ref.overwritten_forward = forward
fn_ref.old_forward = functools.update_wrapper(
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
)

rewritten_forward = create_new_forward(fn_ref)
Copy link
Contributor Author

@a-r-r-o-w a-r-r-o-w Jan 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 Some major changes made to the hooks addition/removal process to be able to support:

  • adding multiple hooks to affect the forward pass
  • remove hooks arbitrarily (out-of-order is supported as well)

Please take a look when you can. Happy to answer any questions and iterate further if needed

Comment on lines 166 to 176
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]

old_forward = fn_ref.old_forward
if fn_ref.overwritten_forward is not None:
old_forward = fn_ref.overwritten_forward

if index == num_hooks - 1:
self._module_ref.forward = old_forward
else:
self._fn_refs[index + 1].old_forward = old_forward
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may look a bit weird - why are we assigning the function reference of next index to the old forward of the hook being removed?

TLDR; Hook invocation order is the reverse of the order in which they are added. Please take a look at the tests related to invocation for better understanding.

When we add hook A followed by hook B, the execution order of methods looks like:

forward:
  pre_forward of B
    pre_forward of A
      actual_module_orginal_forward
    post_forward of A
  post_forward of B

In this example, the function references would have A-related ones at index 0, and B related ones at index 1. Removing hook A requires pointing the B->old_forward to actual_module_forward. Removing hook B requires pointing module->forward to new_forward(A). We handle both cases here.

Let's take a more complex example to understand better. We add hook A that only has pre/post-forward. We add hook B that has a new_forward implementation. We add hook C that only has pre/post-forward. The invocation order would be

forward:
  pre_forward of C           
    pre_forward of B       /> pre_forward of A
      new_forward of B  --/     actual_module_original_forward
    post_forward of B     \</ post_forward of A
  post_forward of C

In this example, the function references would have A-related ones at index 0, and B related ones at index 1, and C related ones at index 2. Removing hook A requires pointing B->old_forward (since we overwrote the original forward implementation by making use of a new_forward method) to actual_module_original_forward. Removing hook B requires pointing C->old_forward to new_forward(A). Removing hook C requires pointing module->forward to new_forward(B)

On a separate note from the explanation, the invocation design here is very friendly with parellism too IMO, so we can eventually introduce different parallel methods utilizing the same hook design, without being invasive in the actual modeling implementations themselves.

There are a few tests added to make sure execution order is correct.

Comment on lines 164 to 167
if should_compute_attention:
output = self.fn_ref.overwritten_forward(*args, **kwargs)
else:
output = self.state.cache
Copy link
Contributor Author

@a-r-r-o-w a-r-r-o-w Jan 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a hook implements a new_forward method, it can choose to make a call to the forward method it overwrote. The overwritten function is always stored in the overwritten_forward attribute of FunctionReference objects

return output


class HookTests(unittest.TestCase):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some basic fast tests to check simple functionality of the hooks on dummy model. This is necessary to make sure all methods are being invoked correctly, and that the hooks are behaving in a predictable manner when added, or arbitrarily removed out-of-order.

We can also support out-of-order hook addition, but currently there is no such use case, so support has not been added.

a-r-r-o-w and others added 3 commits January 28, 2025 03:51
* update

* update

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
@a-r-r-o-w
Copy link
Contributor Author

I think we're good to merge now and also got the approval from Dhruv after working together on latest changes! Thanks for the patience and the reviews everyone 🤗 Will merge once CI is green and wrap up the open cache PRs

@oahzxl Congratulations on the success of your new work - Data centric parallel! I also really liked reading about the pyramid activation checkpointing that was introduced in VideoSys. Thanks for your patience and help, and also for your work that inspired multiple other papers researching caching mechanism specific to video models. We will be sure to integrate as much as possible to make the methods more easily accessible :)

@a-r-r-o-w a-r-r-o-w merged commit 658e24e into main Jan 27, 2025
15 checks passed
@a-r-r-o-w a-r-r-o-w deleted the pyramid-attention-broadcast branch January 27, 2025 23:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

roadmap Add to current release roadmap

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

7 participants