
No-op full backward hooks modify the input #61446

@pmeier

🐛 Bug

As of torch==1.8.0, torch.nn.Module.register_backward_hook is deprecated in favor of torch.nn.Module.register_full_backward_hook. However, even a no-op hook, i.e. one that does not modify the gradients in any way, still causes a full backward hook to modify the module's input whenever that input requires a gradient.
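
For reference, a minimal sketch of what "no-op" means here (not part of the original report): a hook with the standard full-backward-hook signature that returns None, which tells autograd to leave grad_input unchanged.

import torch
from torch import nn


def noop_hook(module, grad_input, grad_output):
    # Returning None (or nothing at all) means "do not replace grad_input".
    return None


module = nn.Linear(2, 2)
handle = module.register_full_backward_hook(noop_hook)
# handle.remove() detaches the hook again.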

To Reproduce

Registering a no-op backward hook on an nn.Module leaves the input as is:

import torch
from torch import nn

input_without_grad = torch.rand(())
input_with_grad = input_without_grad.clone().requires_grad_(True)


class ModuleWithBackwardHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_backward_hook(lambda *args: None)

    def forward(self, input):
        return input


module_with_backward_hook = ModuleWithBackwardHook()
assert (
    module_with_backward_hook(input_without_grad) is input_without_grad
), "Module with backward hook modifies a input without grad"
assert (
    module_with_backward_hook(input_with_grad) is input_with_grad
), "Module with backward hook modifies a input with grad"

With a no-op full backward hook registered instead, the input is modified if it requires a gradient:

class ModuleWithFullBackwardHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_full_backward_hook(lambda *args: None)

    def forward(self, input):
        return input


module_with_full_backward_hook = ModuleWithFullBackwardHook()
assert (
    module_with_full_backward_hook(input_without_grad) is input_without_grad
), "Module with full backward hook modifies a input without grad"
assert (
    module_with_full_backward_hook(input_with_grad) is input_with_grad
), "Module with full backward hook modifies a input with grad"
AssertionError: Module with full backward hook modifies a input with grad
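
Run in place of the failing assertion, a quick check (a sketch, not part of the original report) suggests the returned tensor is a distinct object but numerically identical, and gradients still flow through it; only identity checks via is break:

output = module_with_full_backward_hook(input_with_grad)
print(output is input_with_grad)             # False: a new tensor object
print(torch.equal(output, input_with_grad))  # True: same value
output.backward()
print(input_with_grad.grad)                  # tensor(1.): gradient still flows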

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @brianjo @mruberry @jbschlosser


Labels

module: autograd (Related to torch.autograd, and the autograd engine in general)
module: docs (Related to our documentation, both in docs/ and docblocks)
module: nn (Related to torch.nn)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
