
Batched torch.matmul fails assert for differentiable inputs #116099

@lw


🐛 Describe the bug

Batched torch.matmul with an out= tensor succeeds when the inputs are detached, but trips an internal assert as soon as the inputs require grad:

In [1]: import torch

In [2]: a = torch.empty((256, 512), requires_grad=True).unsqueeze(0)  # shape (1, 256, 512)

In [3]: b = torch.empty((4, 128, 512), requires_grad=True).transpose(-1, -2)  # shape (4, 512, 128)

In [4]: c = torch.empty((256, 4, 128)).movedim(1, 0)  # shape (4, 256, 128)

In [5]: _ = torch.matmul(a.detach(), b.detach(), out=c)

In [6]: _ = torch.matmul(a, b, out=c)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 _ = torch.matmul(a, b, out=c)

RuntimeError: !(transpose && t2_is_matrix) INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1695392035629/work/aten/src/ATen/native/LinearAlgebra.cpp":2026, please report a bug to PyTorch.
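
A possible workaround sketch until this is fixed, assuming the goal is only to fill the preallocated buffer c and that no gradient needs to flow into c itself (out= variants of differentiable ops don't support autograd, so presumably this should have surfaced as a clean RuntimeError rather than an internal assert):

import torch

a = torch.empty((256, 512), requires_grad=True).unsqueeze(0)           # (1, 256, 512)
b = torch.empty((4, 128, 512), requires_grad=True).transpose(-1, -2)   # (4, 512, 128)
c = torch.empty((256, 4, 128)).movedim(1, 0)                           # (4, 256, 128)

# Compute the batched matmul without out=, keeping autograd on `result`,
# then copy into the preallocated buffer outside of graph recording.
result = torch.matmul(a, b)   # broadcasts to (4, 256, 128), differentiable
with torch.no_grad():
    c.copy_(result)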

Versions

PyTorch 2.1.0 py3.9_cuda12.1_cudnn8.9.2_0 from conda

cc @ezyang @gchanan @zou3519 @kadeng @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7

Labels

actionable, high priority, module: assert failure, module: autograd, triage review, triaged
