refactor rotary embedding 3: so it is not on cpu #9307
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks!
@yiyixuxu would it be possible to trigger a torch.compile() run with PyTorch nightly to verify if this helps with the CUDAGraph issue? Code is in #9299 (comment).
Cc'ing @cpuhrsch, maybe you would like to review it?
@sayakpaul yeah, I tested it and it is fine.
@sayakpaul

```python
import torch
import torch.utils.benchmark as benchmark
import gc
import time

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# Precomputed prompt embeddings (the text encoders were not loaded above).
prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")


def run_inference(pipe):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    )


flush()
time = benchmark_fn(run_inference, pipe)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GiB
print(f"Execution time: {time} sec")
print(f"Memory: {memory} GiB")
```
```python
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
freqs = freqs.to(pos.device)
```
I'd expect this to cause a sync as well, since by default arange allocates on the CPU. Two ways to mitigate could be:
a) use pin_memory() on freqs ahead of time and set non_blocking=True, or
b) do the arange on the GPU right away (i.e. torch.arange([...], device=pos.device)).
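A minimal sketch of the two options, with placeholder values standing in for get_1d_rotary's arguments (dim, theta, linear_factor, and the positions are all hypothetical here):

```python
import torch

# Placeholder values standing in for the function's arguments.
dim, theta, linear_factor = 128, 10000.0, 1.0
freqs_dtype = torch.float32
pos = torch.arange(512, device="cuda")  # positions assumed to already live on the GPU

# Option (a): keep the arange on the CPU, but pin it and copy asynchronously.
freqs_cpu = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor
freqs = freqs_cpu.pin_memory().to(pos.device, non_blocking=True)

# Option (b): allocate the arange directly on pos.device, so no CPU tensor
# (and no host-device sync) is involved at all.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) / linear_factor

freqs = torch.outer(pos.to(freqs_dtype), freqs)  # [S, D/2]
```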
ohhh let's do torch.arange([...], device=pos.device)
@yiyixuxu that looks reasonable but I'd call
```diff
 if isinstance(pos, int):
-    pos = np.arange(pos)
+    pos = torch.arange(pos)
```
This should also be passed a device argument to allocate it on the GPU. If this isn't on the GPU, then neither will the following tensors be.
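A minimal sketch of what that suggestion could look like; the helper name and its device argument are illustrative, not the actual diffusers API:

```python
import torch

def make_positions(pos, device=None):
    # Hypothetical helper: when pos is an int, build the position ids directly
    # on the target device so the tensors derived from them stay on the GPU.
    if isinstance(pos, int):
        pos = torch.arange(pos, device=device)
    return pos

pos = make_positions(1024, device="cuda")  # no CPU allocation, no host-to-device copy
```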
Changes get_1d_rotary to accept pos as torch tensors.
Fixes #9299
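Putting the review comments together, here is a condensed sketch of the direction the PR describes (the corresponding function in diffusers is get_1d_rotary_pos_embed and has more options, e.g. use_real and ntk_factor; treat this as illustrative, not the actual implementation):

```python
import torch

def get_1d_rotary_sketch(dim, pos, theta=10000.0, linear_factor=1.0, freqs_dtype=torch.float32):
    """Illustrative only: pos is an int or a torch tensor; numpy is never involved."""
    if isinstance(pos, int):
        # Per the review above, a device argument could also be threaded through
        # here so the positions are allocated on the GPU directly.
        pos = torch.arange(pos)
    # Inverse frequencies are created on pos.device -> no CPU tensor, no sync.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
    freqs = torch.outer(pos.to(freqs_dtype), freqs)  # [S, D/2]
    return freqs.cos(), freqs.sin()  # cos/sin form of the rotary embedding
```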