Reparameterize your PyTorch modules
- PyTorch >= 1.2.0
- Python 3
import torch
import torch.nn.functional as F
import torchvision
from torchreparam import ReparamModule
device = torch.device('cuda')
# A regular network
net = torchvision.models.resnet18().to(device)
# Reparametrize it!
reparam_net = ReparamModule(net)
print(f"reparam_net has {reparam_net.param_numel} parameters")
assert tuple(reparam_net.parameters()) == (reparam_net.flat_param,)
print(f"reparam_net now has **only one** vector parameter of shape {reparam_net.flat_param.shape}")
# The reparametrized module is equivalent to the original one.
# In fact, the weights share storage.
dummy_input_image = torch.randn(1, 3, 224, 224, device=device)
print(f'original net output for class 746: {net(dummy_input_image)[0, 746]}')
print(f'reparam_net output for class 746: {reparam_net(dummy_input_image)[0, 746]}')
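# (Added sanity check, just a sketch.) Since the two modules are equivalent and
# share parameter storage, their full outputs should match.
with torch.no_grad():
    assert torch.allclose(net(dummy_input_image), reparam_net(dummy_input_image))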
# We can optionally trace the forward method with PyTorch JIT so it runs faster.
# To do so, we can call `.trace` on the reparametrized module with dummy inputs
# expected by the module.
# Comment out the following line if you do not want to trace.
reparam_net = reparam_net.trace(dummy_input_image)
# Example of a MAML loss that
# 1. Trains `theta_0` on `inner_train` for `num_gd_steps` gradient descent steps with learning rate `lr`.
# 2. Computes the loss of the updated parameters on `inner_val`.
#
# This assumes classification with cross entropy loss, but can be easily adapted
# to other loss functions.
def maml_loss(reparam_net, theta_0, inner_train, inner_val, num_gd_steps=5, lr=0.01):
    # train stage
    train_data, train_label = inner_train
    theta = theta_0
    for _ in range(num_gd_steps):
        # perform GD update on (train_data, train_label) w.r.t. theta
        loss = F.cross_entropy(reparam_net(train_data, flat_param=theta), train_label)
        # create_graph=True so we can backprop through this update
        gtheta, = torch.autograd.grad(loss, theta, create_graph=True)
        # update
        theta = theta - lr * gtheta
    # val stage
    # theta is now the final updated set of parameters
    val_data, val_label = inner_val
    return F.cross_entropy(reparam_net(val_data, flat_param=theta), val_label)
# Let's use the above function:
# Initialize our theta_0 that we want to train
theta_0 = torch.randn_like(reparam_net.flat_param).mul_(0.001).requires_grad_()
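# (Alternative, as a sketch) theta_0 could instead start from the wrapped network's
# own flattened parameters:
#   theta_0 = reparam_net.flat_param.clone().detach().requires_grad_()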
# Make dummy data
inner_train = (
    torch.randn(2, 3, 224, 224, device=device),                 # input
    torch.randint(low=0, high=1000, size=(2,), device=device),  # label
)
inner_val = (
    torch.randn(5, 3, 224, 224, device=device),                 # input
    torch.randint(low=0, high=1000, size=(5,), device=device),  # label
)
l = maml_loss(reparam_net, theta_0, inner_train, inner_val)
l.backward()
# Here, the MAML loss is easily backprop-able to theta_0 by autograd.
print(f'MAML loss gradient for theta_0:\n{theta_0.grad}')

To install, run:

python setup.py install

For a `ReparamModule`, the following fields are available:
- `.flat_param`: a flattened parameter vector representing all parameters of the wrapped module.
- `.param_numel`: the total number of parameters, i.e., the size of `.flat_param`.
A `ReparamModule` can be called with the following signature:
reparam_module(self, *inputs, flat_param=None, buffers=None)

where

- `inputs` will be passed over as inputs to the inner module.
- `flat_param` will be used as the parameters of this forward pass, if specified. Note that this allows you to easily run a network on an entirely different set of parameters, and backprop to them.
- `buffers` will be used as the buffers for this forward pass, if specified (experimental).
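For example, the following sketch (reusing `reparam_net` and `dummy_input_image` from the example above; the `perturbed` name is just illustrative) runs the same architecture on a different parameter vector and backprops to it:

# Evaluate the wrapped network on a perturbed copy of its flattened parameters.
perturbed = reparam_net.flat_param.detach() + 0.01 * torch.randn_like(reparam_net.flat_param)
perturbed.requires_grad_()
out = reparam_net(dummy_input_image, flat_param=perturbed)
# Gradients flow back to the supplied parameter vector.
out.sum().backward()
print(perturbed.grad.shape)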
Note
`ReparamModule` currently does not work properly with Batch Normalization layers with the default `track_running_stats=True`.
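One possible workaround (a sketch, not something the library itself provides): construct the network so that its Batch Normalization layers do not track running statistics. Assuming a torchvision version whose ResNet constructors accept a `norm_layer` argument, this could look like:

import functools
import torch.nn as nn

# Batch Norm without running statistics (always normalizes with batch statistics).
no_stats_bn = functools.partial(nn.BatchNorm2d, track_running_stats=False)
net = torchvision.models.resnet18(norm_layer=no_stats_bn).to(device)
reparam_net = ReparamModule(net)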