
Fusion causes peak memory increase #138685

@shunting314


🐛 Describe the bug

Here is an example where fusion causes a peak memory increase with a very marginal perf win.

import torch
from triton.testing import do_bench

@torch.compile
def f(a, b, c):
    # each (x @ c) materializes a 32768 x 32768 fp32 intermediate (~4 GB)
    # before the row-sum reduces it away
    return (a @ c).sum(dim=-1) + (b @ c).sum(dim=-1)

a = torch.randn(1024 * 32, 16, device="cuda")
b = torch.randn(1024 * 32, 16, device="cuda")
c = torch.randn(16, 1024 * 32, device="cuda")

f(a, b, c)  # warm-up call to trigger compilation before measuring
torch.cuda.reset_peak_memory_stats()
ms = do_bench(lambda: f(a, b, c))
print(f"{ms=}")
peak_mem = torch.cuda.max_memory_allocated()
print(f"Peak mem {peak_mem / 1e9:.3f} GB")

By default Inductor will fuse these 2 reductions together. This requires the two large 4GB buffers for (a @ c) and (b @ c) to be alive at the same time, which makes the peak memory usage about 8GB. But if we don't fuse the 2 reductions, the buffer for (a @ c) can be reused for (b @ c), and the peak memory usage would be about 4GB.
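As a point of comparison, one way to approximate the unfused schedule without touching Inductor's fusion decisions is to compile each matmul-plus-reduction as its own graph, so the two reductions cannot be fused. This is only an illustrative sketch (reduce_matmul and f_unfused are made-up names), not how the unfused numbers below were produced:

@torch.compile
def reduce_matmul(x, c):
    # one matmul followed by a row-sum; the ~4 GB (x @ c) intermediate
    # is freed before this call returns
    return (x @ c).sum(dim=-1)

def f_unfused(a, b, c):
    # the two reductions live in separate graphs, so the caching allocator
    # can reuse the block freed by the first call for the second one
    return reduce_matmul(a, c) + reduce_matmul(b, c)

Benchmarking f_unfused with the same do_bench / max_memory_allocated harness should land near the single-buffer ~4GB case.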

A similar pattern is seen when we chunk the linear + cross-entropy loss (a rough sketch of that pattern follows the metrics below).

Metrics with fusion by default:

ms=7.721065044403076
Peak mem 8.886 GB

Metrics by not fusing the two reductions:

ms=7.749417781829834
Peak mem 4.591 GB
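For context, here is a rough sketch of the chunked linear + cross-entropy pattern mentioned above; the names (chunked_linear_cross_entropy, hidden, weight, targets) are illustrative and not from this issue. The point of chunking is that only one chunk's logits buffer should be live at a time; if the per-chunk reductions get fused, several chunk buffers stay alive together and the memory saving is lost, just like in the repro above.

import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, weight, targets, num_chunks=4):
    # hidden: (tokens, hidden_dim), weight: (vocab, hidden_dim), targets: (tokens,)
    total = hidden.new_zeros(())
    for h, t in zip(hidden.chunk(num_chunks), targets.chunk(num_chunks)):
        # (chunk_tokens, vocab) logits buffer; ideally freed and reused per chunk
        logits = h @ weight.t()
        total = total + F.cross_entropy(logits, t, reduction="sum")
    return total / targets.numel()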


cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov


Labels: module: inductor, oncall: pt2, triaged
