KEMBAR78
[core] Hunyuan Video by a-r-r-o-w · Pull Request #10136 · huggingface/diffusers · GitHub
Skip to content

Conversation

@a-r-r-o-w
Copy link
Contributor

@a-r-r-o-w a-r-r-o-w commented Dec 5, 2024

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)

As the official weights are not hosted yet, the above code does not still work. We will work with the Hunyuan team to host weights asap. Thanks for being patient!

hunyuan-output.mp4

@a-r-r-o-w
Copy link
Contributor Author

Test code for running inference with original codebase (and to save latents at each step for comparison when doing the conversion):

code
# import argparse
# import torch
# from transformers import (
#     AutoProcessor,
#     LlavaForConditionalGeneration,
# )


# def preprocess_text_encoder_tokenizer(args):

#     processor = AutoProcessor.from_pretrained(args.input_dir)
#     model = LlavaForConditionalGeneration.from_pretrained(
#         args.input_dir,
#         torch_dtype=torch.float16,
#         low_cpu_mem_usage=True,
#     ).to(0)

#     print(model.language_model)
#     print()
#     print(processor.tokenizer)
    
#     model.language_model.save_pretrained(
#         f"{args.output_dir}/text_encoder"
#     )
#     processor.tokenizer.save_pretrained(
#         f"{args.output_dir}/tokenizer"
#     )

# if __name__ == "__main__":

#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         "--input_dir",
#         type=str,
#         required=True,
#         help="The path to the llava-llama-3-8b-v1_1-transformers.",
#     )
#     parser.add_argument(
#         "--output_dir",
#         type=str,
#         default="",
#         help="The output path of the llava-llama-3-8b-text-encoder-tokenizer."
#         "if '', the parent dir of output will be the same as input dir.",
#     )
#     args = parser.parse_args()

#     if len(args.output_dir) == 0:
#         args.output_dir = "/".join(args.input_dir.split("/")[:-1])

#     preprocess_text_encoder_tokenizer(args)

from typing import List, Tuple, Union

import torch
from accelerate import init_empty_weights
from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.hunyuan_video.text_encoder import TextEncoder
from diffusers.utils import export_to_video
from PIL import Image

transformer_state_dict_path = "/raid/aryan/hyvideo-original/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
with init_empty_weights():
    transformer = HunyuanVideoTransformer3DModel(
        in_channels=16,
        out_channels=16,
        mm_double_blocks_depth=20,
        mm_single_blocks_depth=40,
        rope_dim_list=[16, 56, 56],
        hidden_size=3072,
        heads_num=24,
        mlp_width_ratio=4,
        guidance_embed=True,
    )
transformer.load_state_dict(torch.load(transformer_state_dict_path, weights_only=True)["module"], strict=True, assign=True)
transformer = transformer.to("cuda", torch.bfloat16)

vae_state_dict_path = "/raid/aryan/hyvideo-original/hunyuan-video-t2v-720p/vae/pytorch_model.pt"
vae_config = {
    "act_fn": "silu",
    "block_out_channels": [
        128,
        256,
        512,
        512
    ],
    "down_block_types": [
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D"
    ],
    "in_channels": 3,
    "latent_channels": 16,
    "layers_per_block": 2,
    "norm_num_groups": 32,
    "out_channels": 3,
    "sample_size": 256,
    "sample_tsize": 64,
    "up_block_types": [
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D"
    ],
    "scaling_factor": 0.476986,
    "time_compression_ratio": 4,
    "mid_block_add_attention": True,
}
with init_empty_weights():
    vae = AutoencoderKLHunyuanVideo(**vae_config)
vae_state_dict = torch.load(vae_state_dict_path, weights_only=True)
vae_state_dict = {k.replace("vae.", ""): v for k, v in vae_state_dict.items()}
vae.load_state_dict(vae_state_dict, strict=True, assign=True)
vae = vae.to("cuda", torch.float16)


PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
) 
PROMPT_TEMPLATE_ENCODE_VIDEO = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)

prompt_template = {
    "template": PROMPT_TEMPLATE_ENCODE,
    "crop_start": 36,
}
prompt_template_video = {
    "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
    "crop_start": 95,
}

text_len = 256
max_length = text_len + prompt_template_video.get("crop_start", 0)

text_encoder = TextEncoder(
    text_encoder_type="llm",
    max_length=max_length,
    text_encoder_precision="fp16",
    text_encoder_path="/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder",
    tokenizer_type="llm",
    tokenizer_path="/raid/aryan/llava-llama-3-8b-v1_1-extracted/tokenizer",
    prompt_template=prompt_template,
    prompt_template_video=prompt_template_video,
    hidden_state_skip_layer=2,
    apply_final_norm=False,
    reproduce=True,
).to("cuda")
text_encoder_2 = TextEncoder(
    text_encoder_type="clipL",
    max_length=77,
    text_encoder_precision="fp16",
    text_encoder_path="/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder_2",
    tokenizer_type="clipL",
    reproduce=True,
).to("cuda")

scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

pipe = HunyuanVideoPipeline(
    vae=vae,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
    text_encoder_2=text_encoder_2,
)
pipe.to("cuda")
torch.cuda.synchronize()


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (np.ndarray): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top       eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size    eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid

def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
                                   Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    )  # [D/2]
    # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(
            torch.ones_like(freqs), freqs
        )  # complex64     # [S, D/2]
        return freqs_cis


def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
            part and an imaginary part separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """

    grid = get_meshgrid_nd(
        start, *args, dim=len(rope_dim_list)
    )  # [3, W, H, D] / [2, W, H]

    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list
    ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list
    ), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb


def get_rotary_pos_embed(video_length, height, width):
    target_ndim = 3
    ndim = 5 - 2
    # 884
    latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]

    assert all(
        s % transformer.config.patch_size[idx] == 0
        for idx, s in enumerate(latents_size)
    ), (
        f"Latent size(last {ndim} dimensions) should be divisible by patch size({transformer.config.patch_size}), "
        f"but got {latents_size}."
    )
    rope_sizes = [
        s // transformer.config.patch_size[idx] for idx, s in enumerate(latents_size)
    ]

    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
    head_dim = transformer.config.hidden_size // transformer.config.heads_num
    rope_dim_list = transformer.config.rope_dim_list
    if rope_dim_list is None:
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert (
        sum(rope_dim_list) == head_dim
    ), "sum(rope_dim_list) should equal to head_dim of attention layer"
    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_dim_list,
        rope_sizes,
        theta=256,
        use_real=True,
        theta_rescale_factor=1,
    )
    return freqs_cos, freqs_sin


downscale_factor_for_testing = 2
height = 720 // downscale_factor_for_testing // 16 * 16
width = 1280 // downscale_factor_for_testing // 16 * 16
num_frames = 49 # 129
guidance_scale = 1.0
embedded_guidance_scale = 6.0
prompt = "Close-up, A little girl wearing a red hoodie in winter strikes a match. The sky is dark, there is a layer of snow on the ground, and it is still snowing lightly. The flame of the match flickers, illuminating the girl's face intermittently."
negative_prompt = "bad quality, worst quality"
num_inference_steps = 50

freqs_cos, freqs_sin = get_rotary_pos_embed(num_frames, height, width)
output = pipe(
    prompt=prompt,
    height=height,
    width=width,
    video_length=num_frames,
    data_type="video",
    num_inference_steps=num_inference_steps,
    timesteps=None,
    sigmas=None,
    guidance_scale=guidance_scale,
    embedded_guidance_scale=embedded_guidance_scale,
    negative_prompt=negative_prompt,
    num_videos_per_prompt=1,
    generator=torch.Generator().manual_seed(91102),
    output_type="pil",
    freqs_cis=(freqs_cos, freqs_sin),
    enable_tiling=True,
    return_dict=True,
).videos[0]

output = output.permute(1, 2, 3, 0).detach().cpu().numpy()
output = (output * 255).clip(0, 255).astype("uint8")
output = [Image.fromarray(x) for x in output]

export_to_video(output, "output.mp4", fps=8)

@ghunkins
Copy link
Contributor

Thanks for the amazing work @a-r-r-o-w .

I was able to separate the logic out from the TextEncoder so that the original LlamaForCausalLM and CLIPTextModel are used in the pipeline. Code is here: hyvideo.

Additionally, for those looking to run this a bit more sanely:

pip install git+https://github.com/ollanoinc/hyvideo.git

Need to make sure flash-attn is installed as well.

import torch
from hyvideo.diffusion.pipelines.pipeline_hunyuan_video import HunyuanVideoPipeline
from hyvideo.modules.models import HYVideoDiffusionTransformer
from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D

pipe = HunyuanVideoPipeline.from_pretrained(
    'magespace/hyvideo-diffusers',
    transformer=HYVideoDiffusionTransformer.from_pretrained(
        'magespace/hyvideo-diffusers',
        torch_dtype=torch.bfloat16,
        subfolder='transformer'
    ),
    vae=AutoencoderKLCausal3D.from_pretrained(
        'magespace/hyvideo-diffusers',
        torch_dtype=torch.bfloat16,
        subfolder='vae'
    ),
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to('cuda')
pipe.vae.enable_tiling()
output.9.mp4

Let me know if you would like me to open a PR.

@a-r-r-o-w
Copy link
Contributor Author

@ghunkins That's awesome! I only started working on the conversion yesterday because was on a short break so apologies for the delay on the thing people are most excited about. It's also quite a slow process to test after every change in order to verify that the outputs match with 0 difference against original, but this should hopefully be close to completion tomorrow given your implementation!

Your implementation separating out the prompt encoding parts helps a lot really! Will use it as reference here and add you as co-author 🤗

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@hlky hlky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @a-r-r-o-w !

return hidden_states, encoder_hidden_states


class PatchEmbed(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename to PatchEmbed3D to distinguish it from

class PatchEmbed(nn.Module):

@a-r-r-o-w a-r-r-o-w added roadmap Add to current release roadmap close-to-merge labels Dec 15, 2024
@a-r-r-o-w
Copy link
Contributor Author

a-r-r-o-w commented Dec 16, 2024

TODO: Integration tests. Will take this up in a follow-up PR because I was OOMing on our CI L40s (not on GPU, but on CPU) so need to lower resolution much more or look more deeply into what's causing this.

Also we are good to merge this in from my discussion with YiYi. The hunyuan team will merge diffusers-format weights soon but until then revision="refs/pr/18" is required when loading the pipeline

@a-r-r-o-w a-r-r-o-w merged commit aace1f4 into main Dec 16, 2024
15 checks passed
@a-r-r-o-w a-r-r-o-w deleted the hunyuan-video branch December 16, 2024 08:26
@SHYuanBest
Copy link
Contributor

SHYuanBest commented Dec 16, 2024

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)

As the official weights are not hosted yet, the above code does not still work. We will work with the Hunyuan team to host weights asap. Thanks for being patient!

hunyuan-output.mp4

I use the example code and this weight for testing, but got the black results:

output.mp4

@ghunkins
Copy link
Contributor

@SHYuanBest I haven't tested, but typically this occurs due to the presence of NaNs in the VAE decoding step. You can try using torch.bfloat16 or torch.float32.

# option 1: Use float32 for maximum stability
pipe.vae = pipe.vae.to(torch.float32)

# option 2: Use bfloat16 for a balance of stability and memory efficiency
pipe.vae = pipe.vae.to(torch.bfloat16)

@SHYuanBest
Copy link
Contributor

SHYuanBest commented Dec 17, 2024

@SHYuanBest I haven't tested, but typically this occurs due to the presence of NaNs in the VAE decoding step. You can try using torch.bfloat16 or torch.float32.

# option 1: Use float32 for maximum stability
pipe.vae = pipe.vae.to(torch.float32)

# option 2: Use bfloat16 for a balance of stability and memory efficiency
pipe.vae = pipe.vae.to(torch.bfloat16)

I upgrade the torch to 2.5.1, and fix the error (with default dtype). And it seem that VAE not support bf16: "replication_pad3d_cuda" not implemented for 'BFloat16'

output.mp4

@YoadTew
Copy link

YoadTew commented Dec 17, 2024

I am unable to generate 720p videos due to an out-of-memory exception on an 80GB GPU while using this code:

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16, revision='refs/pr/18'
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, revision='refs/pr/18', torch_dtype=torch.float16)
pipe.vae.enable_tiling()
pipe.to("cuda")
pipe.enable_sequential_cpu_offload()

height, width, num_frames, fps = 720, 1280, 129, 24

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=height, width=width, num_frames=num_frames,
    num_inference_steps=50,
).frames[0]
export_to_video(output, "output.mp4", fps=fps)

When generating with the original repository (https://github.com/Tencent/HunyuanVideo), it works and takes about 30 minutes. Any idea what could be causing the difference?

@hlky hlky mentioned this pull request Dec 17, 2024
@Ednaordinary
Copy link
Contributor

Ednaordinary commented Dec 18, 2024

Here's a version that works well in 24gb vram, up to 720x1280 (only up to around 45 frames at this resolution though)

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, BitsAndBytesConfig
import imageio as iio
import math
import numpy as np
import io
import time

torch.manual_seed(42)

prompt_template = {
    "template": (
        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
        "1. The main content and theme of the video."
        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else."
        "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents."
        "4. Background environment, light, style, atmosphere, and qualities."
        "5. Camera angles, movements, and transitions used in the video."
        "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    ),
    "crop_start": 95,
}

def export_to_video_bytes(fps, frames):
    request = iio.core.Request("<bytes>", mode="w", extension=".mp4")
    pyavobject = iio.plugins.pyav.PyAVPlugin(request)
    if isinstance(frames, np.ndarray):
        frames = (np.array(frames) * 255).astype('uint8')
    else:
        frames = np.array(frames)
    new_bytes = pyavobject.write(frames, codec="libx264", fps=fps)
    out_bytes = io.BytesIO(new_bytes)
    return out_bytes

def export_to_video(frames, path, fps):
    video_bytes = export_to_video_bytes(fps, frames)
    video_bytes.seek(0)
    with open(path, "wb") as f:
        f.write(video_bytes.getbuffer())

model_id = "tencent/HunyuanVideo"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", llm_int8_skip_modules=["proj_out", "norm_out"])
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16, revision="refs/pr/18", quantization_config=quantization_config
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16, revision="refs/pr/18")
pipe.scheduler._shift = 7.0
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()

start_time = time.perf_counter()
output = pipe(
    prompt="a cat walks along the sidewalk of a city. The camera follows the cat at knee level. The city has many people and cars moving around, with advertisement billboards in the background",
    height = 720,
    width = 1280,
    num_frames = 45,
    prompt_template=prompt_template,
    num_inference_steps = 15,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
print("Time:", round(time.perf_counter() - start_time, 2), "seconds")
print("Max vram:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB")
output.mp4

Above generated in 12 minutes on a 3090 ti. (Only 15 steps in the above video, run 30 for original quality)

NF4 quantization seems to have a much lesser effect than on mochi. Saving and loading pretrained will help a lot here, since the above script quantizes on every run. Also, it seems guidance scale has next to no effect on this model (it seems flux similar, in that guidance scale is likely distilled since 0 guidance does not achieve 2x speed)

@a-r-r-o-w
Copy link
Contributor Author

@YoadTew The original implementation uses flash attention. If you try using the torch variant they provide, it will lead to:

  • OOM as well
  • Bad results because attention mask for padding tokens is not used

You could try testing with the different torch attention backends. I think flash attention backend should work, but I haven't tried it myself yet. https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html

@a-r-r-o-w
Copy link
Contributor Author

@Ednaordinary The model that is currently released is guidance-distilled, meaning it has been trained from another variant of itself (which works with CFG), to try and predict what the outputs would be given a certain CFG scale and empty negative prompt. This allows 2x the speed and half the memory requirements out-of-the-box, since there is no unconditional (negative prompt) latent stream any more to apply CFG to. Instead they take the guidance value you provide and use that as an additional information to condition the model on (similar to Flux-Dev, which is guidance-distilled as well).

Also, awesome work showcasing the use of BnB 🤗

@AshD
Copy link

AshD commented Dec 19, 2024

Is there a way to spread out the transformer's GPU memory across multiple GPUs. I tried with device map - "balanced" but got an error
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:2! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* copy transformer

* copy vae

* copy pipeline

* make fix-copies

* refactor; make original code work with diffusers; test latents for comparison generated with this commit

* move rope into pipeline; remove flash attention; refactor

* begin conversion script

* make style

* refactor attention

* refactor

* refactor final layer

* their mlp -> our feedforward

* make style

* add docs

* refactor layer names

* refactor modulation

* cleanup

* refactor norms

* refactor activations

* refactor single blocks attention

* refactor attention processor

* make style

* cleanup a bit

* refactor double transformer block attention

* update mochi attn proc

* use diffusers attention implementation in all modules; checkpoint for all values matching original

* remove helper functions in vae

* refactor upsample

* refactor causal conv

* refactor resnet

* refactor

* refactor

* refactor

* grad checkpointing

* autoencoder test

* fix scaling factor

* refactor clip

* refactor llama text encoding

* add coauthor

Co-Authored-By: "Gregory D. Hunkins" <greg@ollano.com>

* refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device

Note: The following line diverges from original behaviour. We create the grid on the device, whereas
original implementation creates it on CPU and then moves it to device. This results in numerical
differences in layerwise debugging outputs, but visually it is the same.

* use diffusers timesteps embedding; diff: 0.10205078125

* rename

* convert

* update

* add tests for transformer

* add pipeline tests; text encoder 2 is not optional

* fix attention implementation for torch

* add example

* update docs

* update docs

* apply suggestions from review

* refactor vae

* update

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Co-authored-by: hlky <hlky@hlky.ac>

* make fix-copies

* update

---------

Co-authored-by: "Gregory D. Hunkins" <greg@ollano.com>
Co-authored-by: hlky <hlky@hlky.ac>
@a-r-r-o-w a-r-r-o-w mentioned this pull request Jun 2, 2025
6 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

close-to-merge roadmap Add to current release roadmap

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants