
variable type mismatch when trying to compile #110696

Description

@johnnv1

🐛 Describe the bug

While updating PyTorch in kornia, we found a weird issue when trying to compile the GuidedBlur module. The error trace looks like a data type issue, or something related to aten.reflection_pad2d, but I don't understand what is going on.

Locally, I also see it working for some parameter combinations and failing for others:
Working with:

  • kernel_size=5; subsample=1;

Not working with:

  • kernel_size=5; subsample=2;
  • kernel_size=(5, 7); subsample=1;
  • kernel_size=(5, 7); subsample=2;

Compiling this module worked before the update, but I wasn't able to reduce the failure to a minimal reproducible example. I also didn't find any related issues. A best-effort sketch of the trigger follows.
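
For reference, a hypothetical repro sketch (not the exact kornia test code): the tensor shapes are only chosen to match the ranges visible in the lowering error below, and dynamic=True reflects the symbolic s0 that appears in the trace.

    import torch
    import kornia

    # Hypothetical sketch; shapes are illustrative only.
    guidance = torch.rand(2, 3, 4, 4)
    inp = torch.rand(2, 3, 4, 4)

    # One of the failing configurations listed above.
    blur = kornia.filters.GuidedBlur(kernel_size=5, eps=0.1, subsample=2)
    compiled = torch.compile(blur, dynamic=True)

    out = compiled(guidance, inp)  # raises BackendCompilerFailed on torch 2.1.0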

Error logs

venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:328: in _fn
    return fn(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
kornia/filters/guided.py:213: in forward
    return guided_blur(guidance, input, self.kernel_size, self.eps, self.border_type, self.subsample)
kornia/filters/guided.py:161: in guided_blur
    return _guided_blur_multichannel_guidance(guidance, input, kernel_size, eps, border_type, subsample)
kornia/filters/guided.py:75: in _guided_blur_multichannel_guidance
    guidance_sub, input_sub, kernel_size = _preprocess_fast_guided_blur(guidance, input, kernel_size, subsample)
venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:490: in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:641: in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:133: in _fn
    return fn(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:389: in _convert_frame_assert
    return _compile(
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:569: in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:189: in time_wrapper
    r = func(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:491: in compile_inner
    out_code = transform_code_object(code, transform)
venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:1028: in transform_code_object
    transformations(instructions, code_options)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:458: in transform
    tracer.run()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2074: in run
    super().run()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:724: in run
    and self.step()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:702: in step
    self.output.compile_subgraph(
venv/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:857: in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
venv/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
venv/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:957: in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:189: in time_wrapper
    r = func(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:1024: in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
venv/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:1009: in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
venv/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py:117: in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
venv/lib/python3.10/site-packages/torch/__init__.py:1568: in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:1150: in compile_fx
    return aot_autograd(
venv/lib/python3.10/site-packages/torch/_dynamo/backends/common.py:55: in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:3891: in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:189: in time_wrapper
    r = func(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:3429: in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2212: in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2392: in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1573: in aot_dispatch_base
    compiled_fw = compiler(fw_module, flat_args)
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:189: in time_wrapper
    r = func(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:1092: in fw_compiler_base
    return inner_compile(
venv/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py:80: in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
venv/lib/python3.10/site-packages/torch/_inductor/debug.py:228: in inner
    return fn(*args, **kwargs)
venv/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:54: in newFunction
    return old_func(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:341: in compile_fx_inner
    compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:550: in fx_codegen_and_compile
    graph.run(*example_inputs)
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:189: in time_wrapper
    r = func(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_inductor/graph.py:456: in run
    return super().run(*args)
venv/lib/python3.10/site-packages/torch/fx/interpreter.py:138: in run
    self.env[node] = self.run_node(node)
venv/lib/python3.10/site-packages/torch/_inductor/graph.py:722: in run_node
    result = super().run_node(n)
venv/lib/python3.10/site-packages/torch/fx/interpreter.py:195: in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
venv/lib/python3.10/site-packages/torch/_inductor/graph.py:613: in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
venv/lib/python3.10/site-packages/torch/_inductor/graph.py:610: in call_function
    out = lowerings[target](*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_inductor/lowering.py:279: in wrapped
    out = decomp_fn(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_inductor/lowering.py:3115: in reflection_pad2d
    ranges=[*batch, sympy.Integer(h + top + bot), sympy.Integer(w + left + right)],
venv/lib/python3.10/site-packages/sympy/core/cache.py:70: in wrapper
    retval = cfunc(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

cls = <class 'sympy.core.numbers.Integer'>, i = s0 + 3

    @cacheit
    def __new__(cls, i):
        if isinstance(i, str):
            i = i.replace(' ', '')
        # whereas we cannot, in general, make a Rational from an
        # arbitrary expression, we can make an Integer unambiguously
        # (except when a non-integer expression happens to round to
        # an integer). So we proceed by taking int() of the input and
        # let the int routines determine whether the expression can
        # be made into an int or whether an error should be raised.
        try:
            ival = int(i)
        except TypeError:
>           raise TypeError(
                "Argument of Integer should be of numeric type, got %s." % i)
E           torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
E           LoweringException: TypeError: Argument of Integer should be of numeric type, got s0 + 3.
E             target: aten.reflection_pad2d.default
E             args[0]: TensorBox(
E               View(
E                 StorageBox(
E                   Pointwise(
E                     'cpu',
E                     torch.float32,
E                     def inner_fn(index):
E                         i0, i1, i2, i3, i4 = index
E                         tmp0 = ops.load(arg0_1, i4 + 4 * i3 + 16 * i2 + 48 * i0)
E                         tmp1 = ops.load(arg0_1, i4 + 4 * i3 + 16 * i1 + 48 * i0)
E                         tmp2 = tmp0 * tmp1
E                         return tmp2
E                     ,
E                     ranges=[2, 3, 3, 4, 4],
E                     origin_node=mul,
E                     origins={mul}
E                   )
E                 ),
E                 size=[2, 9, 4, 4],
E                 reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 3, 3), ModularIndexing(i1, 1, 3), i2, i3],
E                 origins={mul, view_3}
E               )
E             )
E             args[1]: [((s0 - 1)//2), s0 - ((s0 - 1)//2) - 1, 1, 1]
E           
E           
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True
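
For what it's worth, the failure point can be demonstrated in isolation. The inductor lowering for reflection_pad2d (lowering.py:3115 in the trace above) wraps the padded height and width in sympy.Integer, and sympy.Integer cannot coerce a symbolic expression such as s0 + 3. A minimal sketch, independent of PyTorch:

    import sympy

    s0 = sympy.Symbol("s0", integer=True)

    expr = s0 + 3        # fine as a symbolic size expression
    sympy.Integer(expr)  # TypeError: Argument of Integer should be of
                         # numeric type, got s0 + 3.

So whenever the input height or width is symbolic (dynamic shapes), h + top + bot is a sympy expression rather than a plain int, and the coercion fails exactly as shown in the traceback.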

Minified repro

No response

Versions

python 3.10.13
torch 2.1.0

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305


    Labels

    module: dynamic shapes · oncall: pt2 · triaged
