Conjugate bit not handled properly in wrapped subclasses · Issue #130646 · pytorch/pytorch · GitHub


@albanD

Description

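import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import TorchDispatchMode

torch.manual_seed(123)

class Log(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Defer to TwoTensor's own __torch_dispatch__, which runs the op on each
        # of its two inner tensors; those inner calls come back through this mode.
        if TwoTensor in types:
            return NotImplemented
        res = func(*args, **kwargs)
        # Only log the multiplication.
        if func is torch.ops.aten.mul.Tensor:
            print(f"\n{func}(*{args}, **{kwargs})")
            print(res)
        return res

a = torch.tensor([1+1j], dtype=torch.complex64, device="cuda", requires_grad=True)
b = torch.randn((1,), dtype=torch.complex64, device="cuda")

# Plain tensors: the expected result is a * conj(b).
with Log():
    a * b.conj()

# Same computation with both operands wrapped in TwoTensor.
a = TwoTensor(a, a.clone())
b = TwoTensor(b, b.clone())

with Log():
    a * b.conj()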

Leads to:

aten.mul.Tensor(*(tensor([1.+1.j], device='cuda:0', requires_grad=True), tensor([0.9469-0.1451j], device='cuda:0')), **{})
tensor([1.0920+0.8018j], device='cuda:0')

aten.mul.Tensor(*(tensor([1.+1.j], device='cuda:0', requires_grad=True), tensor([0.9469-0.1451j], device='cuda:0')), **{})
tensor([0.8018+1.0920j], device='cuda:0')

aten.mul.Tensor(*(tensor([1.+1.j], device='cuda:0', grad_fn=<CloneBackward0>), tensor([0.9469-0.1451j], device='cuda:0')), **{})
tensor([0.8018+1.0920j], device='cuda:0')

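The first result comes from the plain tensors; the next two are the two inner mul calls made by TwoTensor, one per wrapped component.

As you can see, both inner muls return 0.8018+1.0920j, which is a * b rather than a * b.conj() (1.0920+0.8018j), even though the second argument is logged as the conjugated value 0.9469-0.1451j. So the conj bit has not been handled properly somewhere along the subclass path.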
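To double-check which product the TwoTensor path computed, the multiplication can be redone by hand with the logged values. This is a small sketch using plain Python complex numbers; b is reconstructed from the logged b.conj() (0.9469-0.1451j) and is rounded to four decimals, so the comparisons use a loose tolerance. The helper `close` is illustrative only.

```python
# Values taken from the log above (rounded to 4 decimals).
a = 1 + 1j
b = 0.9469 + 0.1451j           # the log shows b.conj() == 0.9469-0.1451j

expected = a * b.conjugate()   # what a * b.conj() should return
unconjugated = a * b           # what you get if the conjugation is lost

observed_plain = 1.0920 + 0.8018j       # first logged result (plain tensors)
observed_two_tensor = 0.8018 + 1.0920j  # both inner results (TwoTensor)


def close(x, y, tol=1e-3):
    """Illustrative helper: complex values equal up to the log's rounding."""
    return abs(x - y) < tol


print("plain matches a*conj(b):", close(observed_plain, expected))            # True
print("TwoTensor matches a*conj(b):", close(observed_two_tensor, expected))   # False
print("TwoTensor matches a*b:", close(observed_two_tensor, unconjugated))     # True
```

The plain-tensor result agrees with a * conj(b), while the TwoTensor results agree with a * b.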

The current workaround for users is to force the conjugation to be resolved eagerly, so that lazy conj handling cannot break anything, by adding the following to the logging mode above (right after `res = func(*args, **kwargs)`):

        if func is torch.ops.aten._conj.default:
            res = res.resolve_conj()
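Here is a self-contained sketch of that workaround as a standalone mode. It runs on CPU with plain complex tensors and does not use TwoTensor or CUDA. The class name `ResolveConjMode` is hypothetical. It shows that outside the mode `conj()` is lazy (a view with the conj bit set), while inside the mode the output of `aten._conj.default` is materialized, so later kernels never see a conj bit. That this also fixes the TwoTensor case is the issue's claim; the sketch only checks plain-tensor behavior.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class ResolveConjMode(TorchDispatchMode):
    """Hypothetical mode: eagerly materialize every lazy conjugation."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        res = func(*args, **kwargs)
        # aten._conj only flips the conj bit on a view; resolve_conj() copies
        # the data with the conjugation applied and clears the bit.
        if func is torch.ops.aten._conj.default:
            res = res.resolve_conj()
        return res


x = torch.tensor([1 + 1j], dtype=torch.complex64)
y = torch.tensor([0.9469 + 0.1451j], dtype=torch.complex64)

# Outside the mode, conj() is lazy: same storage, conj bit set.
lazy = y.conj()
print(lazy.is_conj())

with ResolveConjMode():
    eager = y.conj()       # materialized copy, conj bit cleared
    out = x * y.conj()

print(eager.is_conj())
print(torch.allclose(out, x * y.conj()))
```

The design choice is to resolve at the single point where the lazy bit is created (`aten._conj`), instead of checking `is_conj()` on the inputs of every op.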

cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @amjames

Metadata

Assignees: No one assigned

Labels: module: complex, tensor subclass, triaged
