nn.DataParallel ignores requires_grad setting when running · Issue #5041 · pytorch/pytorch · GitHub

nn.DataParallel ignores requires_grad setting when running #5041

@jhlee525

Description

  • PyTorch version: 0.3.0
  • Python version: 3.5
  • GPU models and configuration: 8x NVIDIA Titan

I found that the requires_grad setting on a module's parameters is ignored when the module is wrapped with nn.DataParallel. For example:

import torch
from torch.nn import DataParallel
from torchvision.models.resnet import resnet50

module_ = resnet50()
for name, param in module_.named_parameters():
    if name.startswith('conv1') or name.startswith('bn1'):
        param.requires_grad = False
    if name.startswith('layer1') or name.startswith('layer2'):
        param.requires_grad = False
    if name.startswith('layer3') or name.startswith('layer4'):
        param.requires_grad = False

module_ = DataParallel(module_).cuda()

x = torch.rand(32, 3, 1600, 1600)
x = torch.autograd.Variable(x)
x = module_(x)
print(x)

This code should run, because only a very small part of the network's activations (just the fc layer) needs to be stored for the backward computation. Instead, it fails with a run-time error:

RuntimeError: cuda runtime error (2) : out of memory at /tmp/pip-8pfswvat-build/torch/lib/THC/generic/THCStorage.cu:58

Internally, I traced the problem to the replicate function in torch.nn.parallel.replicate. replicate copies every parameter in the module once per replica using Broadcast.apply. The broadcasting code then wraps each copy in a new torch.nn.Parameter using the constructor's default requires_grad argument, which is always True, so the original flag is lost.
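To make the mechanism concrete, here is a minimal CPU-only sketch of the effect described above. It does not call the real replicate or Broadcast code; `naive_copy` and `flag_preserving_copy` are hypothetical helpers that only imitate "wrap the copy in a new Parameter", with and without forwarding the flag.

```python
import torch

def naive_copy(param):
    # Hypothetical helper: rebuilds the parameter with the constructor's
    # default requires_grad (True), ignoring the source parameter's flag.
    return torch.nn.Parameter(param.data.clone())

def flag_preserving_copy(param):
    # Hypothetical helper: forwards the source parameter's flag explicitly.
    return torch.nn.Parameter(param.data.clone(),
                              requires_grad=param.requires_grad)

# A "frozen" parameter, as in the transfer-learning setup above.
frozen = torch.nn.Parameter(torch.zeros(3), requires_grad=False)

naive = naive_copy(frozen)
kept = flag_preserving_copy(frozen)

print(frozen.requires_grad, naive.requires_grad, kept.requires_grad)
```

The naive copy comes back with requires_grad set to True even though the source was frozen, which would be consistent with the replicas storing activations for every layer.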

I think there are a few options for fixing this issue.

  1. This is the intended behavior of DataParallel, so we should use volatile (or torch.no_grad) to implement transfer learning with DataParallel.
  2. Broadcast should be fixed to preserve requires_grad when copying parameters.
  3. torch.nn.parallel.replicate should synchronize requires_grad from the original parameters to the replicas.

I don't think any of these solutions would take much effort to implement.
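As a rough illustration of option 1, the frozen part of the network can be run under torch.no_grad so that no activations are recorded for it, regardless of what requires_grad says on the replicated parameters. This is only a sketch: it assumes a PyTorch version where torch.no_grad is available, `FrozenBackbone` is a hypothetical wrapper name, and the tiny Linear layers stand in for the resnet50 trunk and its fc layer so that it runs on CPU.

```python
import torch
from torch import nn

class FrozenBackbone(nn.Module):
    """Hypothetical wrapper: runs `backbone` without autograd history,
    then applies a trainable `head`."""

    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # No graph is built for the backbone, so its activations are not
        # kept for the backward pass.
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)

# Small stand-ins for the frozen trunk and the trainable fc layer.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 4)
model = FrozenBackbone(backbone, head)

out = model(torch.rand(2, 8))
out.sum().backward()

# Only the head receives gradients.
backbone_grads = [p.grad for p in backbone.parameters()]
head_grads = [p.grad for p in head.parameters()]
```

Because the no_grad context is entered inside forward, the same wrapper should in principle behave the same way once it is passed to DataParallel, though I have not verified that on the multi-GPU setup described above.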
