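This behaviour is very sneaky, so I will call it a bug, as it can easily break your code.
Here is a self-contained example. Even though `axes1` and `axes2` should be equivalent (they differ only in that the zero columns of `axes2` are created with shape `(3, 1)` and the stacked result is then squeezed), `axes2` gives wrong gradients.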
from torch.autograd import Variable
from torch import autograd
import torch
import numpy as np
def norm_vec(v):
    # Euclidean norm of v
    return torch.sqrt(torch.pow(v, 2).sum())
coo = Variable(torch.from_numpy(np.array([[3, 4, 2],])).float(), requires_grad=True)
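# axes1: every stacked tensor has shape (3,), so the result has shape (3, 3)
axes1 = torch.stack([coo[0] / norm_vec(coo[0]), Variable(torch.zeros(3), requires_grad=True), Variable(torch.zeros(3), requires_grad=True)], dim=1)
# axes2: the first tensor has shape (3,) but the zero columns have shape (3, 1);
# the stacked result is squeezed afterwards. This is the version with wrong gradients.
axes2 = torch.stack([coo[0] / norm_vec(coo[0]), Variable(torch.zeros(3, 1), requires_grad=True), Variable(torch.zeros(3, 1), requires_grad=True)], dim=1).squeeze()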
new1 = torch.matmul(coo, axes1)
new2 = torch.matmul(coo, axes2)
print(autograd.grad(new1[0, 0], coo, create_graph=True), autograd.grad(new2[0, 0], coo, create_graph=True))
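For reference, the correct value is easy to derive by hand. With c = `coo[0]` = (3, 4, 2), the first column of both matrices is c/‖c‖, so `new[0, 0]` = c·c/‖c‖ = ‖c‖ and its gradient with respect to `coo` is c/‖c‖ = (3, 4, 2)/√29. The sketch below checks this using a shape-consistent construction (every stacked tensor has shape `(3,)`, so no `squeeze()` is needed). It assumes a recent PyTorch release, where plain tensors replace `Variable` and, as far as I know, `torch.stack` raises a `RuntimeError` for inputs of different shapes instead of silently accepting them; `build_axes` is a hypothetical helper name, not part of any API.

```python
import torch

# Same input as in the report: a single row vector c = [3, 4, 2].
coo = torch.tensor([[3.0, 4.0, 2.0]], requires_grad=True)

def norm_vec(v):
    # Euclidean norm, as in the report.
    return torch.sqrt(torch.pow(v, 2).sum())

def build_axes(c):
    # Hypothetical helper: every tensor passed to torch.stack has shape (3,),
    # so the result is (3, 3) and no squeeze() is required.
    first = c[0] / norm_vec(c[0])
    zeros = torch.zeros(3)
    return torch.stack([first, zeros, zeros], dim=1)

axes = build_axes(coo)
new = torch.matmul(coo, axes)  # shape (1, 3); new[0, 0] equals ||c||
(grad,) = torch.autograd.grad(new[0, 0], coo)

# Analytical gradient of ||c|| with respect to c is c / ||c||.
expected = coo.detach() / coo.detach().norm()
print(grad)
print(torch.allclose(grad, expected))

# Stacking tensors of different shapes, as done for axes2, is expected to fail
# loudly in recent releases rather than produce wrong gradients.
try:
    torch.stack([torch.zeros(3), torch.zeros(3, 1)], dim=1)
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
print(mismatch_rejected)
```

Keeping all inputs to `torch.stack` the same shape avoids relying on how the mismatched case is handled, and comparing against the closed-form c/‖c‖ gives an independent check on whatever autograd returns.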