Description
🐛 Describe the bug
import torch
import numpy
from contextlib import AbstractContextManager

c = torch.compile
dev = 'cuda:0'


class expected(AbstractContextManager):
    # Asserts that the body raises the expected exception and swallows it.
    def __init__(self, expected_exception_cls=None, subclass=False):
        self.expected = expected_exception_cls
        self.accept_subclass = subclass

    def __exit__(self, exc_type, exc_value, traceback):
        if self.expected is not None:
            assert exc_type is not None, 'Expected exception not raised'
            if issubclass(exc_type, self.expected) if self.accept_subclass else exc_type == self.expected:
                return True
        return False


def foo(a: numpy.ndarray, b: torch.Tensor):
    a = b.new_tensor(a)
    return torch.cat([a, b], dim=-1)


# 1. eager, ndarray: works
foo(
    numpy.array([ 1 ]),
    torch.randint(0, 10, [1], device=dev),
)

# 2. dynamo, ndarray: raises
with expected(torch._dynamo.exc.TorchRuntimeError):
    c(foo)(
        numpy.array([ 1 ]),
        torch.randint(0, 10, [1], device=dev),
    )

# 3. eager, tensor: raises
with expected(RuntimeError):
    foo(
        torch.randint(0, 10, [1]),
        torch.randint(0, 10, [1], device=dev),
    )

# 4. dynamo, tensor: raises
with expected(torch._dynamo.exc.TorchRuntimeError):
    c(foo)(
        torch.randint(0, 10, [1]),
        torch.randint(0, 10, [1], device=dev),
    )

There are 4 calls here: {tensor.new_tensor(ndarray), tensor.new_tensor(tensor)} with {eager, dynamo}. Only the first one (eager with an ndarray) works without raising an exception.
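The `expected` helper in the repro compares exception types exactly unless `subclass=True` is passed, which is why each call names the precise exception class it expects. The following is a minimal sketch of that matching behaviour using only built-in exceptions; the `outcome` function is a hypothetical helper added for illustration and is not part of the original repro.

```python
from contextlib import AbstractContextManager


class expected(AbstractContextManager):
    """Same helper as in the repro, restated so this snippet runs on its own."""

    def __init__(self, expected_exception_cls=None, subclass=False):
        self.expected = expected_exception_cls
        self.accept_subclass = subclass

    def __exit__(self, exc_type, exc_value, traceback):
        if self.expected is not None:
            assert exc_type is not None, 'Expected exception not raised'
            if self.accept_subclass:
                matched = issubclass(exc_type, self.expected)
            else:
                matched = exc_type == self.expected
            if matched:
                return True  # swallow the expected exception
        return False  # let anything else propagate


def outcome(ctx, exc):
    """Hypothetical helper: raise `exc` inside `ctx` and report what happened."""
    try:
        with ctx:
            raise exc
    except type(exc):
        return 'propagated'
    return 'swallowed'


# Exact type match: the exception is swallowed.
exact = outcome(expected(KeyError), KeyError('k'))
# KeyError is a subclass of LookupError, but exact matching rejects it.
mismatch = outcome(expected(LookupError), KeyError('k'))
# With subclass=True the same pair is accepted.
relaxed = outcome(expected(LookupError, subclass=True), KeyError('k'))

print(exact, mismatch, relaxed)
```

If the body raises nothing at all, `__exit__` fails the `assert`, so a call that unexpectedly succeeds is reported as an `AssertionError` rather than passing silently.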
#73838 said that tensor_a.new_tensor(tensor_b) returning a tensor on tensor_b.device is a side effect, not an intentional change.
With torch.compile, ndarrays are converted to FakeTensors and treated like tensors, not data. So the device is the default, cpu, rather than tensor_a.device, which causes the inconsistency here.
This caused an issue when we tried to compile an existing model that relies on the documented behaviour, i.e. tensor_a.new_tensor(tensor_b) should return a tensor on tensor_a.device.
While there are other possible fixes, such as handling this case in dynamo, I'd prefer to go back to the documented behaviour: #144958
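Independently of which fix lands, one possible user-side workaround is to stop relying on the device that new_tensor picks and pass the dtype and device explicitly. The sketch below is an assumption on my part, not something proposed in #73838 or #144958: `foo_explicit` is a hypothetical name, it uses `torch.as_tensor` (which, unlike new_tensor, may avoid a copy), and it has only been reasoned about for eager mode, not checked under torch.compile. It falls back to cpu when CUDA is unavailable so that it runs anywhere.

```python
import numpy
import torch


def foo_explicit(a, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical variant of `foo`: place `a` on b's device with b's dtype
    # explicitly, whether `a` is an ndarray or a tensor.
    a = torch.as_tensor(a, dtype=b.dtype, device=b.device)
    return torch.cat([a, b], dim=-1)


# Use the GPU when present; otherwise run the same code on cpu.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
b = torch.zeros(1, dtype=torch.int64, device=dev)

from_ndarray = foo_explicit(numpy.array([1]), b)
from_tensor = foo_explicit(torch.tensor([1]), b)

print(from_ndarray.device, from_tensor.device)
```

Both results end up on b's device regardless of where `a` started, which is the behaviour the documentation describes for tensor_a.new_tensor(tensor_b).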
Versions
Verified on main; unrelated to the environment.
cc @mruberry @rgommers @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames