Failing to compile function with matmul · Issue #110680 · pytorch/pytorch · GitHub


@johnnv1

Description


🐛 Describe the bug

While updating PyTorch in Kornia, we found an issue with @ / matmul when trying to compile the center_crop3d function:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function matmul>(*(FakeTensor(..., size=(4, 4, 4)), FakeTensor(..., size=(1, 4, 4))), **{}):

Compiling this function worked before, but I wasn't able to reduce it to a minimal reproducible example. I also didn't find any related issues.
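For reference, the two operands in the error, (4, 4, 4) and (1, 4, 4), look shape-compatible under the batched matmul rules documented for torch.matmul: the trailing dims multiply as 4x4 @ 4x4 and the leading batch dim 1 broadcasts to 4. The pure-Python sketch below (no PyTorch needed; matmul_shape and broadcast_batch are hypothetical helpers written only for this note, and they only cover inputs with at least 2 dims) computes the expected result shape, which suggests the failure is not a plain shape mismatch.

```python
from typing import Sequence, Tuple


def broadcast_batch(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Broadcast two batch shapes with right-aligned broadcasting rules."""
    out = []
    # Walk both shapes from the right, treating missing dims as size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dims {da} and {db} cannot broadcast")
        # A size-1 dim takes the other size; otherwise both sizes agree.
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def matmul_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Result shape of a @ b for two inputs that each have >= 2 dims (batched case only)."""
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only handles inputs with >= 2 dims")
    n, k1 = a[-2], a[-1]
    k2, m = b[-2], b[-1]
    if k1 != k2:
        raise ValueError(f"inner dims differ: {k1} vs {k2}")
    return broadcast_batch(tuple(a[:-2]), tuple(b[:-2])) + (n, m)


# Shapes taken from the FakeTensor error above.
print(matmul_shape((4, 4, 4), (1, 4, 4)))  # (4, 4, 4)
```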

Error logs

kornia/geometry/transform/crop3d.py:223: in center_crop3d
    return crop_by_boxes3d(
kornia/geometry/transform/crop3d.py:301: in crop_by_boxes3d
    validate_bbox3d(src_box)
kornia/geometry/transform/crop3d.py:302: in <resume in crop_by_boxes3d>
    validate_bbox3d(dst_box)
kornia/geometry/transform/crop3d.py:313: in <resume in crop_by_boxes3d>
    bbox = infer_bbox_shape3d(dst_box)
kornia/geometry/transform/crop3d.py:313: in <resume in crop_by_boxes3d>
    bbox = infer_bbox_shape3d(dst_box)
kornia/geometry/transform/crop3d.py:314: in <resume in crop_by_boxes3d>
    if not ((bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all() and (bbox[2] == bbox[2][0]).all()):
kornia/geometry/transform/crop3d.py:314: in <resume in crop_by_boxes3d>
    if not ((bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all() and (bbox[2] == bbox[2][0]).all()):
kornia/geometry/transform/crop3d.py:323: in <resume in crop_by_boxes3d>
    (int(bbox[0][0].item()), int(bbox[1][0].item()), int(bbox[2][0].item())),
kornia/geometry/transform/crop3d.py:323: in <resume in crop_by_boxes3d>
    (int(bbox[0][0].item()), int(bbox[1][0].item()), int(bbox[2][0].item())),
kornia/geometry/transform/crop3d.py:323: in <resume in crop_by_boxes3d>
    (int(bbox[0][0].item()), int(bbox[1][0].item()), int(bbox[2][0].item())),
venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:490: in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:641: in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:133: in _fn
    return fn(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:389: in _convert_frame_assert
    return _compile(
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:569: in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:189: in time_wrapper
    r = func(*args, **kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:491: in compile_inner
    out_code = transform_code_object(code, transform)
venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:1028: in transform_code_object
    transformations(instructions, code_options)
venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:458: in transform
    tracer.run()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2074: in run
    super().run()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:724: in run
    and self.step()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:688: in step
    getattr(self, inst.opname)(inst)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:392: in wrapper
    return inner_fn(self, inst)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1167: in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:562: in call_function
    self.push(fn.call_function(self, args, kwargs))
venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:261: in call_function
    return super().call_function(tx, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:90: in call_function
    return tx.inline_user_function_return(
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:598: in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2179: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2286: in inline_call_
    tracer.run()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:724: in run
    and self.step()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:688: in step
    getattr(self, inst.opname)(inst)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:392: in wrapper
    return inner_fn(self, inst)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1167: in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:562: in call_function
    self.push(fn.call_function(self, args, kwargs))
venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:261: in call_function
    return super().call_function(tx, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:90: in call_function
    return tx.inline_user_function_return(
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:598: in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2179: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2286: in inline_call_
    tracer.run()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:724: in run
    and self.step()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:688: in step
    getattr(self, inst.opname)(inst)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:392: in wrapper
    return inner_fn(self, inst)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1115: in CALL_FUNCTION
    self.call_function(fn, args, {})
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:562: in call_function
    self.push(fn.call_function(self, args, kwargs))
venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:261: in call_function
    return super().call_function(tx, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:90: in call_function
    return tx.inline_user_function_return(
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:598: in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2179: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2286: in inline_call_
    tracer.run()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:724: in run
    and self.step()
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:688: in step
    getattr(self, inst.opname)(inst)
venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:168: in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py:570: in call_function
    return wrap_fx_proxy(tx, proxy, **options)
venv/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:1187: in wrap_fx_proxy
    return wrap_fx_proxy_cls(
venv/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:1274: in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1376: in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1337: in get_fake_value
    return wrap_fake_exception(
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:916: in wrap_fake_exception
    return fn()
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1338: in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1410: in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tracer = <torch._dynamo.output_graph.OutputGraph object at 0x7fa8a002bc10>
node = matmul
args = (FakeTensor(..., size=(4, 4, 4)), FakeTensor(..., size=(1, 4, 4)))
kwargs = {}, nnmodule = None

    def run_node(tracer, node, args, kwargs, nnmodule):
        """
        Runs a given node, with the given args and kwargs.
    
        Behavior is dicatated by a node's op.
    
        run_node is useful for extracting real values out of nodes.
        See get_real_value for more info on common usage.
    
        Note: The tracer arg is only used for 'get_attr' ops
        Note: The nnmodule arg is only used for 'call_module' ops
    
        Nodes that are not call_function, call_method, call_module, or get_attr will
        raise an AssertionError.
        """
        op = node.op
        try:
            if op == "call_function":
>               return node.target(*args, **kwargs)
E               torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function matmul>(*(FakeTensor(..., size=(4, 4, 4)), FakeTensor(..., size=(1, 4, 4))), **{}):
E               unsupported operand type(s) for @: 'FakeTensor' and 'FakeTensor'
E               
E               from user code:
E                  File "/tmp/kornia/kornia/geometry/transform/crop3d.py", line 320, in <resume in crop_by_boxes3d>
E                   patches: torch.Tensor = crop_by_transform_mat3d(
E                 File "/tmp/kornia/kornia/geometry/transform/crop3d.py", line 357, in crop_by_transform_mat3d
E                   patches: torch.Tensor = warp_affine3d(
E                 File "/tmp/kornia/kornia/geometry/transform/imgwarp.py", line 867, in warp_affine3d
E                   dst_norm_trans_src_norm: Tensor = normalize_homography3d(M_4x4, size_src, size_out)  # Bx4x4
E                 File "/tmp/kornia/kornia/geometry/conversions.py", line 1149, in normalize_homography3d
E                   dst_norm_trans_src_norm: Tensor = dst_norm_trans_dst_pix @ (dst_pix_trans_src_pix @ src_pix_trans_src_norm)
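The final message, "unsupported operand type(s) for @", is the generic TypeError Python raises when the @ operator finds no implementation that returns a result (both __matmul__ and, where applicable, __rmatmul__ return NotImplemented). If, as assumed here and not confirmed in this issue, the tensor operator methods convert an internal TypeError into NotImplemented, then the real error raised inside the FakeTensor matmul would be hidden behind this generic message. The sketch below reproduces that masking pattern with a hypothetical FakeArray class and a hypothetical wrap_type_error_to_not_implemented decorator; it does not use PyTorch and is not PyTorch's actual code.

```python
import functools


def wrap_type_error_to_not_implemented(fn):
    """Hypothetical decorator: turn a TypeError raised inside a binary operator into NotImplemented."""
    @functools.wraps(fn)
    def wrapped(self, other):
        try:
            return fn(self, other)
        except TypeError:
            # The original error message is discarded here; Python then falls
            # back to the reflected operator (if applicable) and, failing that,
            # raises its own generic "unsupported operand type(s)" TypeError.
            return NotImplemented
    return wrapped


class FakeArray:
    """Hypothetical stand-in for a tensor whose matmul fails internally."""

    @wrap_type_error_to_not_implemented
    def __matmul__(self, other):
        # Simulate a bug deep inside the operator implementation.
        raise TypeError("some internal argument had the wrong type")


def describe_failure() -> str:
    """Run a @ b and return the message of the TypeError the user actually sees."""
    try:
        FakeArray() @ FakeArray()
    except TypeError as exc:
        return str(exc)
    return "no error"


print(describe_failure())
# unsupported operand type(s) for @: 'FakeArray' and 'FakeArray'
```

Under this assumption, calling torch.matmul(a, b) directly instead of a @ b might surface the underlying error rather than the generic operator message.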

Minified repro

No response

Versions

python 3.10.13
torch 2.1.0 (git version: 7bcf7da)
ubuntu 22.04

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @wconstab @bdhirsh @anijain2305
