
dtype promotion of out= functions on meta inputs not consistent. #138399

@ysiraichi

Description

🐛 Describe the bug

Below is a list of operations whose out= functions behave inconsistently between meta inputs and real tensors (e.g. CPU or CUDA). Specifically, I changed the dtype of the output tensor from float32 to float64, and checked whether eager behaves the same with meta and real tensors.

According to the out= specification, some operations run dtype promotion on the output tensors (as long as they are of the same kind), while others require them to have exactly the expected dtype. Therefore, using CPU/CUDA inputs as the ground truth, if the behaviors differ, it likely means that the meta implementation has a bug.
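For context, here is a minimal illustration (not from the original report; the op choices are just examples) of the first behavior, where a higher-precision out tensor is accepted because the result can be safely cast into it:

import torch

a = torch.randn(4, dtype=torch.float32)
b = torch.randn(4, dtype=torch.float32)

# add follows the "safe cast" rule: a float64 out tensor is accepted for
# float32 inputs, and the float32 result is cast into it.
out = torch.empty(4, dtype=torch.float64)
torch.add(a, b, out=out)  # works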

As an example, the aminmax decomposition, decorated with type_casts, expects the dtypes to be exact. However, the decomposition (which is used as the meta implementation) applies the @out_wrapper(...) decorator without specifying exact_dtype=True.
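A hedged repro sketch of the aminmax discrepancy (the exact error message is not claimed here; per this issue, the real-tensor run should fail while the meta run may not):

import torch

def try_aminmax(device):
    x = torch.randn(4, device=device)
    # float64 out tensors for a float32 input, mirroring the test below.
    out = (torch.empty(0, dtype=torch.float64, device=device),
           torch.empty(0, dtype=torch.float64, device=device))
    try:
        torch.aminmax(x, out=out)
        return None
    except Exception as e:
        return e

print(try_aminmax("cpu"))   # expected: an error about the out dtype
print(try_aminmax("meta"))  # per this issue: None (no error), which is the inconsistency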

Failed with real inputs, but didn't fail with meta inputs

Didn't fail with real inputs, but failed with meta inputs

Except for mean (which is an actual bug), all of the other operations show the same behavior as those identified in #138396.

  • geqrf
  • mean
  • nanmean

Dynamic Shape Outputs

Similar to #138396, this operation outputs tensors of dynamic shape. Thus, there's no way to implement a meta function for it.

  • linalg_lstsq

Test Setup

OpInfo Test
import torch
import torch.utils._pytree as pytree
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class TestCommon(TestCase):

    @ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,))
    def test_meta_dtype_error_out(self, device, dtype, op):
        samples = list(op.sample_inputs(device, dtype))

        for i, sample in enumerate(samples):
            torch._dynamo.reset()
            input, args, kwargs = (sample.input, sample.args, sample.kwargs)

            # Run the functional version of the operation, using eager.
            try:
                expected = op(input, *args, **kwargs)

                # Normalize structured return types (e.g. named tuples) to plain
                # tuples so the pytree mapping below produces plain tuples too.
                if isinstance(expected, tuple):
                    expected = tuple(expected)
            except Exception:
                # If that doesn't work out, go to the next sample.
                continue

            def run(f, dev):
                # Create new outputs in the desired device.
                out = pytree.tree_map_only(torch.Tensor, lambda t: torch.empty_like(t, device=dev, dtype=torch.float64), expected)

                # Move inputs to the desired device
                stuff = (input, args, kwargs)
                stuff = pytree.tree_map_only(torch.Tensor, lambda t: t.to(dev), stuff)
                stuff = pytree.tree_map_only(torch.device, lambda d: torch.device(dev), stuff)
                stuff = pytree.tree_map_only(str, lambda v: dev if v == device else v, stuff)
                input_, args_, kwargs_ = stuff

                # Try running the operation, and return the raised error, if any.
                try:
                    f(input_, *args_, **kwargs_, out=out)
                except Exception as e:
                    return e

            eager_err = run(op, device)
            meta_err = run(op, "meta")

            if eager_err is None and meta_err is not None:
                raise RuntimeError("eager didn't fail, but meta did.") from meta_err
            elif eager_err is not None and meta_err is None:
                raise RuntimeError("eager failed, but meta didn't.") from eager_err

instantiate_device_type_tests(TestCommon, globals())

if __name__ == "__main__":
    run_tests()
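To run the test, save the snippet above to a file (the name below is illustrative) and execute it directly; instantiate_device_type_tests generates per-device/per-dtype variants such as TestCommonCPU.test_meta_dtype_error_out_cpu_float32:

python test_meta_out_dtype.py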

Versions

PyTorch version: 2.5.0a0+git7128504
Is debug build: True
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

cc @nairbv @mruberry @ezyang @eellison @bdhirsh


    Labels

  • module: meta tensors
  • module: type promotion (Related to semantics of type promotion)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
