Closed
Labels: module: autograd, module: docs, module: nn, triaged
Description
🐛 Bug
Since torch==1.8.0, `torch.nn.Module.register_backward_hook` is deprecated in favor of `torch.nn.Module.register_full_backward_hook`. Even with a no-op hook, i.e. one that does not modify the gradient in any way, the full backward hook still changes the input when it requires a gradient: the module no longer returns the same tensor object that was passed in.
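For reference, the hook passed to `register_full_backward_hook` is called as `hook(module, grad_input, grad_output)`, where both gradient arguments are tuples, and returning `None` leaves the gradients unchanged. The sketch below assumes PyTorch 1.8 or later and uses `nn.Identity` plus an illustrative `record_hook` helper (not part of the original report) to show such a no-op hook firing once during the backward pass:

```python
import torch
from torch import nn

calls = []  # one (grad_input, grad_output) entry per hook invocation


def record_hook(module, grad_input, grad_output):
    # No-op hook: it only records the gradients and returns None,
    # so autograd keeps using the original grad_input.
    calls.append((grad_input, grad_output))
    return None


module = nn.Identity()  # forward() returns its input unchanged
module.register_full_backward_hook(record_hook)

x = torch.rand((), requires_grad=True)
y = module(x)
y.backward()  # y is a scalar, so the implicit upstream gradient is 1

print(len(calls), calls[0][1], x.grad)
```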
To Reproduce
Registering a no-op backward hook with `register_backward_hook` on an `nn.Module` leaves the input as is:

```python
import torch
from torch import nn

input_without_grad = torch.rand(())
input_with_grad = input_without_grad.clone().requires_grad_(True)


class ModuleWithBackwardHook(nn.Module):
    def __init__(self):
        super().__init__()
        # No-op hook: returns None, so gradients are left untouched.
        self.register_backward_hook(lambda *args: None)

    def forward(self, input):
        return input


module_with_backward_hook = ModuleWithBackwardHook()
assert (
    module_with_backward_hook(input_without_grad) is input_without_grad
), "Module with backward hook modifies a input without grad"
assert (
    module_with_backward_hook(input_with_grad) is input_with_grad
), "Module with backward hook modifies a input with grad"
```

With the same no-op hook registered via `register_full_backward_hook`, the input is modified when it requires a gradient:
```python
class ModuleWithFullBackwardHook(nn.Module):
    def __init__(self):
        super().__init__()
        # Same no-op hook, registered with the non-deprecated API.
        self.register_full_backward_hook(lambda *args: None)

    def forward(self, input):
        return input


module_with_full_backward_hook = ModuleWithFullBackwardHook()
assert (
    module_with_full_backward_hook(input_without_grad) is input_without_grad
), "Module with full backward hook modifies a input without grad"
assert (
    module_with_full_backward_hook(input_with_grad) is input_with_grad
), "Module with full backward hook modifies a input with grad"
```

```
AssertionError: Module with full backward hook modifies a input with grad
```
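The failing check is about object identity (`is`), not values. A minimal sketch, assuming PyTorch 1.8 or later and using `nn.Identity` with an illustrative `run` helper in place of the report's classes, compares a module with and without a no-op full backward hook. In both cases the output holds the same value and the input receives the same gradient; only the identity of the returned tensor differs.

```python
import torch
from torch import nn


def run(with_hook):
    # Hypothetical helper: pass a grad-requiring scalar through nn.Identity,
    # optionally with the same no-op full backward hook as in the report.
    module = nn.Identity()
    if with_hook:
        module.register_full_backward_hook(lambda *args: None)
    x = torch.rand((), requires_grad=True)
    y = module(x)
    same_object = y is x
    same_value = torch.equal(y.detach(), x.detach())
    y.backward()  # scalar output, so the upstream gradient is 1
    return same_object, same_value, x.grad


print(run(with_hook=False))  # identity preserved
print(run(with_hook=True))   # new tensor object, same value and gradient
```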
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @brianjo @mruberry @jbschlosser