Status: Closed
Labels: module: complex, tensor subclass, triaged
Description
```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

torch.manual_seed(123)

class Log(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        # Defer to TwoTensor's own __torch_dispatch__, which then runs
        # the op on each of its inner plain tensors.
        if TwoTensor in types:
            return NotImplemented
        res = func(*args, **kwargs)
        # Only log the multiplications.
        if func is torch.ops.aten.mul.Tensor:
            print(f"\n{func}(*{args}, **{kwargs})")
            print(res)
        return res

# Plain tensors
a = torch.tensor([1+1j], dtype=torch.complex64, device="cuda", requires_grad=True)
b = torch.randn((1,), dtype=torch.complex64, device="cuda")
with Log():
    a * b.conj()

# Same computation with both operands wrapped in TwoTensor
a = TwoTensor(a, a.clone())
b = TwoTensor(b, b.clone())
with Log():
    a * b.conj()
```

Leads to:
```

aten.mul.Tensor(*(tensor([1.+1.j], device='cuda:0', requires_grad=True), tensor([0.9469-0.1451j], device='cuda:0')), **{})
tensor([1.0920+0.8018j], device='cuda:0')

aten.mul.Tensor(*(tensor([1.+1.j], device='cuda:0', requires_grad=True), tensor([0.9469-0.1451j], device='cuda:0')), **{})
tensor([0.8018+1.0920j], device='cuda:0')

aten.mul.Tensor(*(tensor([1.+1.j], device='cuda:0', grad_fn=<CloneBackward0>), tensor([0.9469-0.1451j], device='cuda:0')), **{})
tensor([0.8018+1.0920j], device='cuda:0')
```
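The first result comes from the plain Tensors; the next two are the two `mul` calls that TwoTensor makes on its inner tensors.
As you can see, even though the printed argument values are identical in all three calls, the `mul` inside TwoTensor returns the result you would get if the second argument were conjugated once more: 0.8018+1.0920j is (1+1j) * (0.9469+0.1451j), i.e. `a * b` instead of `a * b.conj()`. So the conj bit is not being handled properly somewhere.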
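A quick check with plain Python complex numbers, using the rounded values from the log output above, shows which product each result corresponds to. This only redoes the arithmetic; it does not touch PyTorch.

```python
# Rounded values taken from the log output above.
a = 1 + 1j
b_conj_printed = 0.9469 - 0.1451j   # second argument as printed by the mode
b = b_conj_printed.conjugate()      # 0.9469 + 0.1451j, i.e. the original b

expected = a * b_conj_printed   # a * b.conj(): what the plain Tensor path returns
observed = a * b                # conjugation dropped: what the TwoTensor path returns

print(f"{expected.real:.4f}{expected.imag:+.4f}j")  # 1.0920+0.8018j
print(f"{observed.real:.4f}{observed.imag:+.4f}j")  # 0.8018+1.0920j
```

Swapping the real and imaginary parts here is specific to `a = 1+1j`: with `b = x + iy`, `a * conj(b) = (x+y) + i(x-y)` while `a * b = (x-y) + i(x+y)`.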
The current workaround for the user is to force the conj to be resolved, so that lazy conj handling cannot break things, by adding the following to the logging mode above (after `res` has been computed):

```python
if func is torch.ops.aten._conj.default:
    res = res.resolve_conj()
```
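To make the "lazy conj" part concrete, here is a minimal CPU-only sketch with plain tensors. It does not reproduce the TwoTensor bug; it only shows what the conj bit is and what `resolve_conj()` changes, which is what the workaround applies to every `aten._conj` output. The tensor values are copied from the log above for illustration.

```python
import torch

a = torch.tensor([1 + 1j], dtype=torch.complex64)
b = torch.tensor([0.9469 + 0.1451j], dtype=torch.complex64)

# conj() on a complex tensor is lazy: it returns a view that shares b's
# storage and only sets the conj bit.
b_conj = b.conj()
print(b_conj.is_conj())      # True

# resolve_conj() materializes the conjugation into a new tensor
# without the conj bit.
b_resolved = b_conj.resolve_conj()
print(b_resolved.is_conj())  # False

# For plain tensors both forms must give the same product.
lazy = a * b_conj
eager = a * b_resolved
print(torch.allclose(lazy, eager))  # True
```

With the workaround in place, kernels that run after `aten._conj` only ever see tensors like `b_resolved`, so there is no conj bit left for the subclass path to lose.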
cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @amjames