🐛 Describe the bug
Here is an example where fusion causes a peak-memory increase for only a very marginal perf win.
```python
import torch
from triton.testing import do_bench

@torch.compile
def f(a, b, c):
    return (a @ c).sum(dim=-1) + (b @ c).sum(dim=-1)

a = torch.randn(1024 * 32, 16, device="cuda")
b = torch.randn(1024 * 32, 16, device="cuda")
c = torch.randn(16, 1024 * 32, device="cuda")
f(a, b, c)

torch.cuda.reset_peak_memory_stats()
ms = do_bench(lambda: f(a, b, c))
print(f"{ms=}")
peak_mem = torch.cuda.max_memory_allocated()
print(f"Peak mem {peak_mem / 1e9:.3f} GB")
```
By default, Inductor fuses these two reductions together. This requires the two large 4 GB buffers, `(a @ c)` and `(b @ c)`, to be alive at the same time, which makes the peak memory usage about 8 GB. But if we don't fuse the two reductions, the buffer for `(a @ c)` can be reused for `(b @ c)`, and the peak memory usage drops to about 4 GB.
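For illustration, the unfused schedule described above can be sketched in eager mode (`f_unfused` is a hypothetical name, not part of the repro):

```python
import torch

def f_unfused(a, b, c):
    # Materialize and reduce (a @ c) first; its large intermediate buffer
    # is freed before (b @ c) is allocated, so the caching allocator can
    # reuse the same 4 GB block instead of holding two at once.
    out = (a @ c).sum(dim=-1)
    out = out + (b @ c).sum(dim=-1)
    return out
```

This is what the non-fused Inductor schedule effectively does: only one matmul intermediate is live at any point.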
A similar pattern shows up when we chunk the linear-cross-entropy loss.
Metrics with fusion (the default):

```
ms=7.721065044403076
Peak mem 8.886 GB
```

Metrics without fusing the two reductions:

```
ms=7.749417781829834
Peak mem 4.591 GB
```
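As a back-of-envelope check of the buffer sizes cited above (my arithmetic, not from the report), each intermediate `(a @ c)` or `(b @ c)` is a `(32768, 32768)` float32 tensor:

```python
# Size of one matmul intermediate: (1024*32) x (1024*32) float32 elements.
n = 1024 * 32
buf_bytes = n * n * 4  # 4 bytes per float32 element
print(f"{buf_bytes / 1e9:.3f} GB")  # prints "4.295 GB"
# Fused:   both intermediates alive at once -> ~8.6 GB peak (matches ~8.886 GB)
# Unfused: one intermediate at a time       -> ~4.3 GB peak (matches ~4.591 GB)
```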
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov