KEMBAR78
Fix HunyuanVideo produces NaN on PyTorch<2.5 by hlky · Pull Request #10482 · huggingface/diffusers · GitHub
Skip to content

Conversation

hlky
Copy link
Contributor

@hlky hlky commented Jan 7, 2025

What does this PR do?

NaN tracked to

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, : -encoder_hidden_states.shape[1]],
hidden_states[:, -encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

Specifically, some elements of encoder_hidden_states.

The dimensions of query, key, value and mask are large which suggests versions <2.5 used 32-bit indexing, this tracks with #10314 if ROCm versions are still using 32-bit indexing, this may also close that issue, awaiting confirmation from user.

Tested on CUDA 2.4.1

output.mp4
Code
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
  model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to("cuda")
pipe.vae.enable_tiling()

output = pipe(
  prompt="A cat walks on the grass, realistic",
  height=320,
  width=512,
  num_frames=61,
  num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)

There's also a small performance increase

2.4.1 with fix 2.5.1 2.5.1 with fix
30/30 [01:56<00:00, 3.88s/it] 30/30 [02:04<00:00, 4.16s/it] [01:56<00:00, 3.89s/it]

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @a-r-r-o-w

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Copy link
Contributor

a-r-r-o-w commented Jan 7, 2025

Oh wow, this is very cool 🤯 It maybe is saving some extra memory too now

Just to confirm, the results before and after are numerically the same, no? Can take a look too if not matched yet 🤗

So, just using a big attention mask is not supported/buggy for < 2.5.1?

@yiyixuxu yiyixuxu merged commit 01bd796 into huggingface:main Jan 7, 2025
12 checks passed
@Nerogar
Copy link
Contributor

Nerogar commented Jan 12, 2025

This change broke batching again. It was previously fixed in #10454

DN6 pushed a commit that referenced this pull request Jan 15, 2025
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants