Closed
Labels: actionable, high priority, module: autograd (related to torch.autograd and the autograd engine in general), triaged
Description
🐛 Describe the bug
Originally reported by Alex Nichol in the Slack channel.
```python
import torch
import torch.utils.checkpoint


class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor):
        out_0 = torch.zeros(2**20, device=inp.device, dtype=torch.float32)
        out_1 = torch.zeros(2**20, device=inp.device, dtype=torch.float32)
        ctx.save_for_backward(
            inp,
            out_0,
            out_1,
        )
        return out_0, out_1

    @staticmethod
    def backward(ctx, dA, dB):
        _ = ctx.saved_tensors  # this is necessary
        return None


def op_fn(inp):
    return MyOp.apply(inp)[0]


dummy_input = torch.nn.Parameter(torch.randn(2**20, device="cuda"))
for i in range(1000):
    full_out = torch.utils.checkpoint.checkpoint(op_fn, dummy_input, use_reentrant=False)
    full_out.sum().backward()
    dummy_input.grad = None  # free gradient memory
    print(i, torch.cuda.memory_allocated() / 1024**2, "MiB")
```

This leak happens when a custom autograd Function is the last operation in a checkpointed region and that Function saves one of its outputs for backward. It occurs because activation checkpointing (AC) raises an exception to exit the forward early, before the custom autograd Function can perform the cleanup it needs to do when it saves tensors.
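Since the trigger is the custom Function being the *last* op in the checkpointed region, one possible mitigation is to ensure some other op (even a trivial one) follows it inside the region. The sketch below is a CPU-sized, untested adaptation of the repro; the trailing `.clone()` workaround and the `LeakyOp` name are assumptions for illustration, not a confirmed fix from the maintainers:

```python
import torch
import torch.utils.checkpoint


class LeakyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        out = torch.zeros(4, dtype=torch.float32)
        # Saving an output for backward is part of what triggers the leak.
        ctx.save_for_backward(inp, out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        _ = ctx.saved_tensors
        return None  # no gradient flows back to inp


def op_fn(inp):
    # Hypothetical workaround: the trailing .clone() means the custom
    # Function is no longer the final op of the checkpointed region.
    return LeakyOp.apply(inp).clone()


x = torch.nn.Parameter(torch.randn(4))
out = torch.utils.checkpoint.checkpoint(op_fn, x, use_reentrant=False)
out.sum().backward()
```

Whether this actually sidesteps the early-exit exception path would need to be verified against the memory counter in the original repro loop.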
Versions
main
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @albanD @gqchen @nikitaved @Varal7 @xmfan