torch.cat doesn't do proper shape checking, which can lead to incorrect gradients · Issue #4071 · pytorch/pytorch

Description

@stefdoerr

This is very sneaky behaviour, so I will call it a bug: it can silently break your code.
Below is a self-contained example. Even though axes1 and axes2 are element-wise equal, axes2 gives wrong gradients.

from torch.autograd import Variable
from torch import autograd
import torch
import numpy as np

def norm_vec(v):
    # Euclidean norm of a vector
    return torch.sqrt(torch.pow(v, 2).sum())

coo = Variable(torch.from_numpy(np.array([[3, 4, 2]])).float(), requires_grad=True)

# axes1 stacks three (3,)-shaped tensors along dim=1 -> shape (3, 3)
axes1 = torch.stack([coo[0] / norm_vec(coo[0]),
                     Variable(torch.zeros(3), requires_grad=True),
                     Variable(torch.zeros(3), requires_grad=True)], dim=1)

# axes2 mixes a (3,)-shaped tensor with (3, 1)-shaped tensors; the mismatch
# is silently accepted, and after .squeeze() the result looks identical to axes1
axes2 = torch.stack([coo[0] / norm_vec(coo[0]),
                     Variable(torch.zeros(3, 1), requires_grad=True),
                     Variable(torch.zeros(3, 1), requires_grad=True)], dim=1).squeeze()

new1 = torch.matmul(coo, axes1)
new2 = torch.matmul(coo, axes2)

# the two gradients should match, but they do not
print(autograd.grad(new1[0, 0], coo, create_graph=True),
      autograd.grad(new2[0, 0], coo, create_graph=True))
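For reference, a sketch of how the same construction behaves on current PyTorch releases, where Variable has been merged into Tensor and torch.stack/torch.cat validate input shapes. The assumption here is that the mismatched stack now raises a RuntimeError rather than producing wrong gradients:

```python
import torch

coo = torch.tensor([[3.0, 4.0, 2.0]], requires_grad=True)

def norm_vec(v):
    # Euclidean norm of a vector
    return torch.sqrt(torch.pow(v, 2).sum())

# Stacking a (3,)-shaped tensor with (3, 1)-shaped tensors now fails loudly
# instead of silently building a graph with wrong gradients.
try:
    torch.stack([coo[0] / norm_vec(coo[0]),
                 torch.zeros(3, 1, requires_grad=True),
                 torch.zeros(3, 1, requires_grad=True)], dim=1)
    raised = False
except RuntimeError:
    raised = True
print("mismatched stack raised:", raised)

# With consistently (3,)-shaped inputs the gradient is well defined:
# new[0, 0] == ||coo||, so d new[0, 0] / d coo == coo / ||coo||.
axes = torch.stack([coo[0] / norm_vec(coo[0]),
                    torch.zeros(3, requires_grad=True),
                    torch.zeros(3, requires_grad=True)], dim=1)
new = torch.matmul(coo, axes)
grad, = torch.autograd.grad(new[0, 0], coo, create_graph=True)
print(grad)
```

The key difference from the original report is that the shape mismatch is rejected at stack time, so the silent-wrong-gradient path can no longer be reached.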
