[export] while_loop issues for loop rewrite #158366

@pianpwk

🐛 Describe the bug

The first issue shows up when we run decomp on just the while_loop implementation:

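import torch
from torch.export import export
from torch._higher_order_ops import while_loop

# Allow .item() on tensors to be captured as unbacked symbols when tracing.
torch._dynamo.config.capture_scalar_outputs = True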

def while_loop_decomp(x, y0):
    out = torch.zeros_like(x)
    def cond_fn(idx, out, y0):
        return idx < out.size(0)

    def body_fn(idx, out, y0):
        i = idx.item()
        torch._check_is_size(i, max=out.size(0))
        y0 = x[i] + y0
        out = out.clone()
        out[i] = y0
        return idx + 1, out, y0

    cnt = torch.full((), 0, dtype=torch.int64)
    _, out, _ = while_loop(cond_fn, body_fn, [cnt, out, y0])
    return out

class TestModel(torch.nn.Module):
    def forward(self, x, y0):
        return while_loop_decomp(x, y0)

x, y0 = torch.randn(16, 8), torch.randn(8)
out = TestModel()(x, y0)
ep = export(TestModel(), (x, y0))
ep = ep.run_decompositions()

ERROR (this looks like a HOP submodule retracing issue: are there no runtime asserts in the while_loop subgraph to inform the compiler?):

======================================================================
ERROR: test_while_loop_custom_op_strict (caffe2.test.export.test_export_strict.StrictExportTestExport)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/testing.py", line 243, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py", line 14070, in test_while_loop_custom_op
    ep = ep.run_decompositions()  # {torch.ops.mylib.loop.default: while_loop_decomp})
         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/export/exported_program.py", line 1459, in run_decompositions
    return _decompose_exported_program(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/export/exported_program.py", line 955, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/export/exported_program.py", line 473, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
                           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/export/_trace.py", line 868, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_functorch/aot_autograd.py", line 1387, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_functorch/aot_autograd.py", line 1627, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_functorch/aot_autograd.py", line 575, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_functorch/aot_autograd.py", line 687, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 198, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 1134, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/interpreter.py", line 173, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7872, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/interpreter.py", line 242, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/interpreter.py", line 322, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/while_loop.py", line 48, in __call__
    return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 524, in __call__
    return wrapper()
           ^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 520, in wrapper
    return self.dispatch(
           ^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 380, in dispatch
    return kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 331, in maybe_run_autograd
    return self(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/while_loop.py", line 48, in __call__
    return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 524, in __call__
    return wrapper()
           ^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 520, in wrapper
    return self.dispatch(
           ^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 416, in dispatch
    result = handler(mode, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 190, in functionalize_dispatch_mode_fn
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/while_loop.py", line 413, in while_loop_func
    _check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/utils.py", line 393, in _check_alias_and_mutation
    aliases, inp_mutation = has_potential_input_alias_or_mutation(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/utils.py", line 341, in has_potential_input_alias_or_mutation
    ) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/utils.py", line 291, in potential_input_alias_or_mutation
    raise e
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/utils.py", line 285, in potential_input_alias_or_mutation
    gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/utils.py", line 270, in _maybe_fake_tracing
    gm = make_fx(
         ^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 2351, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 2283, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 2254, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_dynamo/eval_frame.py", line 975, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 1283, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_dynamo/eval_frame.py", line 975, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 1341, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 1107, in call_module
    return forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.25 from /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py:1330 in wrapped", line 6, in forward
    select = torch.ops.aten.select.int(arg3_1, 0, item);  arg3_1 = None
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 1389, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 1491, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/proxy_tensor.py", line 974, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 2068, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 2723, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_meta_registrations.py", line 5560, in meta_select
    guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 476, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 596, in guard_size_oblivious
    r = self.evaluate(size_oblivious=True)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7237, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7337, in evaluate_expr
    return self._inner_evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7360, in _inner_evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7584, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u3 > 16 (unhinted: -u3 > 16).  (Size-like symbols: none)

Caused by: (_meta_registrations.py:5560 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u3"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

While executing %while_loop : [num_users=3] = call_function[target=torch.ops.higher_order.while_loop](args = (%while_loop_cond_graph_0, %while_loop_body_graph_0, (%full, %empty_like, %y0), (%x,)), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x, y0):
        x: "f32[16, 8][8, 1]"; y0: "f32[8][1]"; 
    
        x, y0, = fx_pytree.tree_flatten_spec(([x, y0], {}), self._in_spec)
         # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14046 in while_loop_decomp, code: out = torch.empty_like(x)
        empty_like: "f32[16, 8][8, 1]" = torch.ops.aten.empty_like.default(x, pin_memory = False)
        
         # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14058 in while_loop_decomp, code: cnt = torch.full((), 0, dtype=torch.int64)
        full: "i64[][]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
        
         # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/while_loop.py:144 in while_loop, code: return while_loop_op(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
        while_loop_cond_graph_0 = self.while_loop_cond_graph_0
        while_loop_body_graph_0 = self.while_loop_body_graph_0
        while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (full, empty_like, y0), (x,));  while_loop_cond_graph_0 = while_loop_body_graph_0 = full = empty_like = y0 = x = None
        getitem: "i64[][]" = while_loop[0];  getitem = None
        getitem_1: "f32[16, 8][8, 1]" = while_loop[1]
        getitem_2: "f32[8][1]" = while_loop[2];  while_loop = getitem_2 = None
        return pytree.tree_unflatten((getitem_1,), self._out_spec)
        
    class while_loop_cond_graph_0(torch.nn.Module):
        def forward(self, arg0_1: "i64[][]", arg1_1: "f32[16, 8][8, 1]", arg2_1: "f32[8][1]", arg3_1: "f32[16, 8][8, 1]"):
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14048 in cond_fn, code: return idx < out.size(0)
            lt: "b8[][]" = torch.ops.aten.lt.Scalar(arg0_1, 16);  arg0_1 = None
            return lt
            
    class while_loop_body_graph_0(torch.nn.Module):
        def forward(self, arg0_1: "i64[][]", arg1_1: "f32[16, 8][8, 1]", arg2_1: "f32[8][1]", arg3_1: "f32[16, 8][8, 1]"):
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14051 in body_fn, code: i = idx.item()
            item: "Sym(u2)" = torch.ops.aten.item.default(arg0_1)
            
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14053 in body_fn, code: y0 = x[i] + y0
            select: "f32[8][1]" = torch.ops.aten.select.int(arg3_1, 0, item);  arg3_1 = None
            add: "f32[8][1]" = torch.ops.aten.add.Tensor(select, arg2_1);  select = arg2_1 = None
            
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14054 in body_fn, code: out = out.clone()
            clone: "f32[16, 8][8, 1]" = torch.ops.aten.clone.default(arg1_1);  arg1_1 = None
            
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14055 in body_fn, code: out[i] = y0
            select_1: "f32[8][1]" = torch.ops.aten.select.int(clone, 0, item);  item = None
            copy_: "f32[8][1]" = torch.ops.aten.copy_.default(select_1, add);  select_1 = copy_ = None
            
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14056 in body_fn, code: return idx + 1, out, y0
            add_1: "i64[][]" = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
            return (add_1, clone, add)
            

Original traceback:
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py", line 14065, in forward
    return while_loop_decomp(x, y0)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py", line 14059, in while_loop_decomp
    _, out, _ = while_loop(cond_fn, body_fn, [cnt, out, y0])
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/torch/_higher_order_ops/while_loop.py", line 144, in while_loop
    return while_loop_op(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
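
To make the failing guard concrete: the traceback ends in meta_select, which evaluates -index > size and index >= size on the unbacked index. Below is a toy model in plain Python of why those two comparisons can only be decided ahead of time when a value range for the index is known, which is roughly what runtime asserts in the subgraph would supply on retrace. This is not PyTorch's actual ShapeEnv logic; decide_out_of_range and the (lo, hi) interval representation are made up for illustration.

```python
import math


def _decide(always_true, always_false):
    """Return True/False when the comparison is settled, else None."""
    if always_true:
        return True
    if always_false:
        return False
    return None


def decide_out_of_range(lo, hi, size):
    """Toy version of the two index checks seen in meta_select.

    [lo, hi] is the inclusive range known for the index. Returns True if
    the index is certainly out of range, False if it is certainly in
    range, and None if the answer depends on the runtime value (the
    situation that surfaces as a data-dependent guard error).
    """
    # `-index > size` is the same as `index < -size`.
    too_small = _decide(always_true=hi < -size, always_false=lo >= -size)
    # `index >= size`
    too_large = _decide(always_true=lo >= size, always_false=hi < size)

    if too_small is True or too_large is True:
        return True
    if too_small is False and too_large is False:
        return False
    return None


# No range information, as when the subgraph is retraced without asserts.
unknown = decide_out_of_range(-math.inf, math.inf, 16)
# Range 0 <= index <= 15, as a bounds check on the index would provide.
bounded = decide_out_of_range(0, 15, 16)
print(unknown, bounded)
```

In this model the unbounded case is undecidable while the bounded case resolves to "in range", which matches the hypothesis that the bounds established in body_fn are not visible when the while_loop subgraph is traced again.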

The second issue shows up when we change the counter tensor from torch.full((), 0) to torch.tensor(0) and run inference with the exported program (no decomposition involved). The exported program appears to hard-code 0 as the index, so every iteration reads from and writes to index 0, and the output is all zeros except for row 0. Possibly a constant-propagation issue?

import torch
from torch.export import export
from torch._higher_order_ops import while_loop

def while_loop_decomp(x, y0):
    out = torch.zeros_like(x)
    def cond_fn(idx, out, y0):
        return idx < out.size(0)

    def body_fn(idx, out, y0):
        i = idx.item()
        torch._check_is_size(i, max=x.size(0) - 1)
        y0 = x[i] + y0
        out = out.clone()
        out[i] = y0
        return idx + 1, out, y0

    cnt = torch.tensor(0)
    _, out, _ = while_loop(cond_fn, body_fn, [cnt, out, y0])
    return out

class TestModel(torch.nn.Module):
    def forward(self, x, y0):
        return while_loop_decomp(x, y0)

x, y0 = torch.randn(16, 8), torch.randn(8)
out = TestModel()(x, y0)
ep = export(TestModel(), (x, y0), strict=False)
out = ep.module()(x, y0)
print(ep)
print(out)

OUTPUT:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c_lifted_tensor_0: "i64[]", x: "f32[16, 8]", y0: "f32[8]"):
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/5a50b0f3789073c2/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:14064 in forward, code: return while_loop_decomp(x, y0)
            zeros_like: "f32[16, 8]" = torch.ops.aten.zeros_like.default(x, pin_memory = False)
            lift_fresh_copy: "i64[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            detach_: "i64[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None
            
             # File: <eval_with_key>.7:11 in forward, code: while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_args_2_0_, l_args_2_1_, l_args_2_2_), (l_args_1_closure_0_cell_contents_closure_0_cell_contents,));  cond_fn_0 = body_fn_0 = l_args_2_0_ = l_args_2_1_ = l_args_2_2_ = l_args_1_closure_0_cell_contents_closure_0_cell_contents = None
            while_loop_cond_graph_0 = self.while_loop_cond_graph_0
            while_loop_body_graph_0 = self.while_loop_body_graph_0
            while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (detach_, zeros_like, y0), (x,));  while_loop_cond_graph_0 = while_loop_body_graph_0 = detach_ = zeros_like = y0 = x = None
            getitem_2: "i64[]" = while_loop[0];  getitem_2 = None
            getitem_3: "f32[16, 8]" = while_loop[1]
            getitem_4: "f32[8]" = while_loop[2];  while_loop = getitem_4 = None
            return (getitem_3,)
            
        class while_loop_cond_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "i64[]", arg1_1: "f32[16, 8]", arg2_1: "f32[8]", arg3_1: "f32[16, 8]"):
                 # File: <eval_with_key>.4:6 in forward, code: lt = l_args_2_0_.lt(size);  l_args_2_0_ = size = None
                lt: "b8[]" = torch.ops.aten.lt.Scalar(arg0_1, 16);  arg0_1 = None
                return lt
                
        class while_loop_body_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "i64[]", arg1_1: "f32[16, 8]", arg2_1: "f32[8]", arg3_1: "f32[16, 8]"):
                 # File: <eval_with_key>.5:9 in forward, code: select = torch.select(l_args_1_closure_0_cell_contents_closure_0_cell_contents_body_fn, 0, item);  l_args_1_closure_0_cell_contents_closure_0_cell_contents_body_fn = None
                select: "f32[8]" = torch.ops.aten.select.int(arg3_1, 0, 0);  arg3_1 = None
                
                 # File: <eval_with_key>.5:10 in forward, code: y0 = select.add(l_args_2_2_);  select = l_args_2_2_ = None
                add: "f32[8]" = torch.ops.aten.add.Tensor(select, arg2_1);  select = arg2_1 = None
                
                 # File: <eval_with_key>.5:11 in forward, code: out = l_args_2_1_.clone();  l_args_2_1_ = None
                clone: "f32[16, 8]" = torch.ops.aten.clone.default(arg1_1);  arg1_1 = None
                
                 # File: <eval_with_key>.5:12 in forward, code: out[item] = y0;  setitem = out;  item = setitem = None
                select_1: "f32[8]" = torch.ops.aten.select.int(clone, 0, 0)
                copy_: "f32[8]" = torch.ops.aten.copy_.default(select_1, add);  select_1 = copy_ = None
                
                 # File: <eval_with_key>.5:13 in forward, code: child = l_args_2_0_.add(1);  l_args_2_0_ = None
                add_1: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
                return (add_1, clone, add)
                
Graph signature: 
    # inputs
    c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
    x: USER_INPUT
    y0: USER_INPUT
    
    # outputs
    getitem_3: USER_OUTPUT
    
Range constraints: {}

tensor([[ -0.8904,  -8.0238,   3.1640, -14.1915,   1.4527,   9.6908,  -2.7418,
          -7.4262],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000]])
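
For reference, here is a plain-Python sketch of what the loop is expected to compute versus what the exported body graph appears to compute once the index is baked in as 0. Lists stand in for tensors, and while_loop_reference, run, and baked_index are hypothetical names used only for this illustration. The assumed eager semantics are simply "apply body_fn to the carried values while cond_fn holds".

```python
def while_loop_reference(cond_fn, body_fn, carried):
    """Plain-Python stand-in for the loop: iterate body_fn while cond_fn is true."""
    carried = tuple(carried)
    while cond_fn(*carried):
        carried = tuple(body_fn(*carried))
    return carried


def run(x, y0, baked_index=None):
    """Mimic while_loop_decomp on nested lists.

    If baked_index is given, it replaces idx inside the body, modelling a
    graph in which the index was specialized to a constant.
    """
    def cond_fn(idx, out, y):
        return idx < len(out)

    def body_fn(idx, out, y):
        i = idx if baked_index is None else baked_index
        y = [a + b for a, b in zip(x[i], y)]   # y0 = x[i] + y0
        out = [row[:] for row in out]          # out = out.clone()
        out[i] = y                             # out[i] = y0
        return idx + 1, out, y

    zeros = [[0.0] * len(y0) for _ in x]
    _, out, _ = while_loop_reference(cond_fn, body_fn, (0, zeros, y0))
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y0 = [0.5, 0.5]

expected = run(x, y0)               # running sum of the rows of x, offset by y0
baked = run(x, y0, baked_index=0)   # every iteration touches row 0 only
print(expected)
print(baked)
```

With the index fixed at 0, row 0 accumulates x[0] once per iteration on top of y0 and every other row stays zero. That would be consistent with the output above: a single row of fairly large values followed by rows of zeros.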

Versions

latest in fbcode

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Labels: export-triaged, oncall: export, oncall: pt2
