🐛 Describe the bug
Below is a list of operations whose out= variants on meta inputs are not consistent with real tensors (e.g. CPU or CUDA). Specifically, I changed the data type of the output tensor from float32 to float64 and checked whether eager mode behaves the same with meta tensors as with real ones.
According to the out= specification, some operations run dtype promotion on the output tensors (as long as they are of the same kind), while others require them to be exactly the expected dtype. Therefore, using CPU/CUDA inputs as the ground truth, if the behaviors differ, it likely means that the meta implementation has a bug.
As an example, aminmax expects the out dtypes to match exactly. However, its decomposition (which is used as the meta implementation, and is decorated with type_casts) uses the @out_wrapper(...) decorator without specifying exact_dtype=True.
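For illustration, here is a minimal sketch of the inconsistency. The torch.add line is only an illustrative assumption of the promotion case; the aminmax behavior (CPU raises, meta does not) is taken from the lists below:

import torch

x = torch.randn(4)

# Some ops run dtype promotion on out=: a float64 output of the same kind is accepted.
torch.add(x, 1.0, out=torch.empty(4, dtype=torch.float64))

# Others require the exact dtype. aminmax rejects float64 outputs on CPU...
min_out = torch.empty(0, dtype=torch.float64)
max_out = torch.empty(0, dtype=torch.float64)
try:
    torch.aminmax(x, out=(min_out, max_out))
except RuntimeError as e:
    print("cpu raised:", e)

# ...but the decomposition-based meta kernel accepts them silently (the inconsistency).
x_meta = torch.empty(4, device="meta")
min_meta = torch.empty(0, dtype=torch.float64, device="meta")
max_meta = torch.empty(0, dtype=torch.float64, device="meta")
torch.aminmax(x_meta, out=(min_meta, max_meta))  # no error, per the list below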
Failed with real inputs, but didn't fail with meta inputs
- abs: Fix unary references' out dtype check. #140288
- addbmm
- addmm: error on output dtype mismatch. #138520
- addmv
- alias_copy
- all
- amax
- amin
- aminmax
- any
- as_strided_copy
- baddbmm
- bucketize
- ceil: Fix unary references' out dtype check. #140288
- conj_physical
- cross
- cummax
- cummin
- diag
- diagonal_copy
- dot
- expand_copy
- fft_ihfft2
- fft_ihfftn
- floor: Fix unary references' out dtype check. #140288
- frac: Fix unary references' out dtype check. #140288
- frexp
- heaviside
- index_add
- index_copy
- index_select
- isin
- isneginf: Fix unary references' out dtype check. #140288
- isposinf: Fix unary references' out dtype check. #140288
- kthvalue
- lerp
- linalg_cross
- linalg_eigh
- linalg_eigvalsh
- linalg_ldl_factor
- linalg_ldl_factor_ex
- linalg_ldl_solve
- linalg_lu
- linalg_lu_factor
- linalg_lu_factor_ex
- linalg_lu_solve
- linalg_matrix_power
- linalg_qr
- linalg_slogdet
- linalg_solve
- linalg_solve_ex
- linalg_solve_triangular
- log_softmax: fix meta function output argument dtype check. #140289
- logcumsumexp
- lu_solve
- lu_unpack
- matmul
- max_reduction_no_dim
- min_reduction_no_dim
- mm
- mode
- msort
- multinomial
- mv
- nan_to_num
- narrow_copy
- native_batch_norm
- neg
- nn_functional_avg_pool3d
- nn_functional_gelu
- nn_functional_hardshrink
- nn_functional_linear
- nn_functional_logsigmoid
- nn_functional_softplus
- nn_functional_softshrink
- ormqr
- pow: fix meta function output argument dtype check. #140287
- qr
- renorm
- round
- round_decimals_0
- scatter_reduce_amax
- scatter_reduce_amin
- scatter_reduce_mean
- scatter_reduce_prod
- scatter_reduce_sum
- searchsorted
- sgn: Fix unary references' out dtype check. #140288
- sign: Fix unary references' out dtype check. #140288
- signbit: Fix unary references' out dtype check. #140288
- slice_scatter
- softmax
- sort
- sparse_sampled_addmm
- square: pow: fix meta function output argument dtype check. #140287
- squeeze_copy
- t_copy
- take
- transpose_copy
- tril
- triangular_solve: fix meta function output argument dtype check. #140286
- triu
- trunc: Fix unary references' out dtype check. #140288
- unfold_copy
- unsqueeze_copy
- vdot
- view_copy
- where
Didn't fail with real inputs, but failed with meta inputs
Except for mean (which is an actual bug; see the sketch after this list), all the other operations exhibit the same behavior identified in #138396.
- geqrf
- mean
- nanmean
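A sketch of the mean case, based on the classification above (the real kernel accepts a float64 output for a float32 input, while the meta kernel raises); the exact error text is an assumption:

import torch

x = torch.randn(4, 4)

# CPU: float32 input with a float64 output tensor; per the classification above this succeeds.
out_cpu = torch.empty(0, dtype=torch.float64)
torch.mean(x, dim=0, out=out_cpu)

# Meta: the same call is reported to raise a dtype-mismatch error (the actual bug).
x_meta = torch.empty(4, 4, device="meta")
out_meta = torch.empty(0, dtype=torch.float64, device="meta")
try:
    torch.mean(x_meta, dim=0, out=out_meta)
except RuntimeError as e:
    print("meta raised:", e)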
Dynamic Shape Outputs
Similar to #138396, this operation outputs tensors of dynamic shape. Thus, there's no way to implement a meta function for it.
- linalg_lstsq
Test Setup
OpInfo Test
import torch
import torch.utils._pytree as pytree

from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests, OpDTypes, onlyCUDA, onlyCPU
from torch.testing._internal.common_utils import TestCase, run_tests


class TestCommon(TestCase):
    @ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,))
    def test_meta_dtype_error_out(self, device, dtype, op):
        samples = list(op.sample_inputs(device, dtype))

        for i, sample in enumerate(samples):
            torch._dynamo.reset()
            input, args, kwargs = (sample.input, sample.args, sample.kwargs)

            # Run the functional version of the operation, using eager.
            try:
                expected = op(input, *args, **kwargs)
                if isinstance(expected, tuple):
                    expected = tuple(expected)
            except:
                # If that doesn't work out, go to the next sample.
                continue

            def run(f, dev):
                # Create new outputs in the desired device.
                out = pytree.tree_map_only(torch.Tensor, lambda t: torch.empty_like(t, device=dev, dtype=torch.float64), expected)
                # Move inputs to the desired device.
                stuff = (input, args, kwargs)
                stuff = pytree.tree_map_only(torch.Tensor, lambda t: t.to(dev), stuff)
                stuff = pytree.tree_map_only(torch.device, lambda d: torch.device(dev), stuff)
                stuff = pytree.tree_map_only(str, lambda v: dev if v == device else v, stuff)
                input_, args_, kwargs_ = stuff

                # Try running the operation, and return the raised error, if any.
                try:
                    f(input_, *args_, **kwargs_, out=out)
                except Exception as e:
                    return e

            eager_err = run(op, device)
            meta_err = run(op, "meta")

            if eager_err is None and meta_err is not None:
                raise RuntimeError("eager didn't fail, but meta did.") from meta_err
            elif eager_err is not None and meta_err is None:
                raise RuntimeError("eager failed, but meta didn't.") from eager_err


instantiate_device_type_tests(TestCommon, globals())

if __name__ == "__main__":
    run_tests()
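Assuming the snippet is saved as, say, test_meta_out_dtype.py, it can be run directly with python, or through pytest to target a single operation, e.g. pytest test_meta_out_dtype.py -k aminmax.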
Versions
PyTorch version: 2.5.0a0+git7128504
Is debug build: True
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A