Memory leak when custom autograd Function is used with AC due to early-stopping · Issue #161186 · pytorch/pytorch · GitHub

@soulitzer

Description

🐛 Describe the bug

Originally reported by Alex Nichol in a Slack channel.

```python
import torch
import torch.utils.checkpoint


class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor):
        # Two 4 MiB outputs (2**20 float32 elements each).
        out_0 = torch.zeros(2**20, device=inp.device, dtype=torch.float32)
        out_1 = torch.zeros(2**20, device=inp.device, dtype=torch.float32)
        # Saving the outputs themselves (not just the input) is part of the trigger.
        ctx.save_for_backward(
            inp,
            out_0,
            out_1,
        )
        return out_0, out_1

    @staticmethod
    def backward(ctx, dA, dB):
        # Unpacking saved_tensors triggers AC recomputation; necessary to reproduce.
        _ = ctx.saved_tensors
        return None  # no gradient for inp


def op_fn(inp):
    # MyOp is the last (and only) op in the checkpointed region.
    return MyOp.apply(inp)[0]


dummy_input = torch.nn.Parameter(torch.randn(2**20, device="cuda"))
for i in range(1000):
    full_out = torch.utils.checkpoint.checkpoint(op_fn, dummy_input, use_reentrant=False)
    full_out.sum().backward()
    dummy_input.grad = None  # free gradient memory
    # With the leak, allocated memory keeps growing from one iteration to the next.
    print(i, torch.cuda.memory_allocated() / 1024**2, "MiB")
```

This leak happens when a custom autograd Function is the last operation in a checkpointed region and the Function saves one of its outputs for backward. It happens because, during recomputation in backward, non-reentrant activation checkpointing (AC) raises an exception to exit the recomputed forward early, before the custom autograd Function has performed the cleanup it needs to do after saving tensors.
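To make this failure mode concrete, here is a pure-Python toy model. It is not PyTorch's implementation; every name in it (`StopRecomputation`, `custom_function_apply`, `_live_state`, `_CallState`) is hypothetical. The toy Function registers per-call state, packs its saved values through an AC-style hook, and then releases that state. When the hook raises an early-stop exception on the last needed value, the release is skipped, so one piece of state accumulates per backward pass, mirroring the memory growth in the repro. Wrapping the release in `try`/`finally` is one generic way to make the cleanup exception-safe. Whether an actual fix takes this shape is an assumption.

```python
class StopRecomputation(Exception):
    """Hypothetical stand-in for the signal AC raises to end recomputation early."""


class _CallState:
    """Hypothetical per-call state; default __eq__ compares by identity."""

    def __init__(self, to_save):
        self.to_save = to_save  # keeps the saved values alive while registered


# Hypothetical registry of state that must be released after packing (toy only).
_live_state = []


def custom_function_apply(to_save, pack_hook, use_finally):
    state = _CallState(to_save)
    _live_state.append(state)  # acquired before the saved values are packed
    if use_finally:
        try:
            for t in to_save:
                pack_hook(t)  # packing runs the AC-style hook, which may raise
        finally:
            _live_state.remove(state)  # released even if the hook raised
    else:
        for t in to_save:
            pack_hook(t)
        _live_state.remove(state)  # never reached if the hook raised
    return to_save


def recompute_once(use_finally):
    to_save = ("inp", "out_0", "out_1")
    recomputed = []

    def pack_hook(t):
        recomputed.append(t)
        # Stop as soon as every value needed for backward has been recomputed.
        if len(recomputed) == len(to_save):
            raise StopRecomputation

    try:
        custom_function_apply(to_save, pack_hook, use_finally)
    except StopRecomputation:
        pass  # the checkpoint machinery swallows its own early-stop signal
    return recomputed


def leaked_after(iterations, use_finally):
    _live_state.clear()
    for _ in range(iterations):
        recompute_once(use_finally)
    return len(_live_state)


print(leaked_after(1000, use_finally=False))  # 1000
print(leaked_after(1000, use_finally=True))   # 0
```

Because the Function's save step is where the last needed value gets packed, the exception fires inside it. That is why the leak only appears when the custom Function is the final operation in the checkpointed region.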

Versions

main

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @albanD @gqchen @nikitaved @Varal7 @xmfan

Metadata

Assignees: No one assigned

Labels: actionable, high priority, module: autograd, triaged