KEMBAR78
Better error handling for cond by ydwu4 · Pull Request #108817 · pytorch/pytorch · GitHub
Skip to content

Conversation

@ydwu4
Copy link
Contributor

@ydwu4 ydwu4 commented Sep 8, 2023

Stack from ghstack (oldest at bottom):

Exception in cond:

For code below:

import torch
import functorch.experimental.control_flow as control_flow
def true_fn(x):
    return x.sin()

def false_fn(x):
    return x, x

def f(x, y):
    return control_flow.cond(y, true_fn, false_fn, [x])

f(torch.ones(3, 4), torch.tensor(False))

The original exception stack trace is:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
    f(torch.ones(3, 4), torch.tensor(False))
  File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
    return control_flow.cond(y, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1164, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 418, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
    raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Expected branch to return a single tensor

from user code:
   File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

After this PR we get:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 429, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 421, in speculate_branch
    unimplemented(
  File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 187, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Expected branch to return a single tensor

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
    f(torch.ones(3, 4), torch.tensor(False))
  File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
    return control_flow.cond(y, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1159, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
    raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Exception during speculating branches

The example code below has a inplace-buffer mutation error,

import torch
import functorch.experimental.control_flow as control_flow

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.ones(6, 4))

    def forward(self, x):
        def true_fn(x):
            self.buffer += 1
            return self.buffer.sum() + x.sum()

        def false_fn(x):
            return (x - 1).sum()

        return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])

mod_for_compile = torch.compile(Foo(), backend="eager", dynamic=True)
mod_for_compile(torch.ones(3, 4))

Before this PR the exception looks like:

[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 163, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1219, in STORE_ATTR
    .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
    raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 394, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 222, in speculate_subgraph
    raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
    mod_for_compile(torch.ones(3, 4))
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 632, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 415, in call_function
    (true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 405, in speculate_branch
    raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile

from user code:
   File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
    return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

after this PR, the only difference is the error message of UncapturedHigherOrderOpError changes from Cond doesn't work unless it is captured completely with torch.compile to Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 177, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1214, in STORE_ATTR
    .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
    raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 426, in call_function
    (true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 236, in speculate_subgraph
    raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
    mod_for_compile(torch.ones(3, 4))
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 634, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
    raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
    return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108817

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cbef357 with merge base c657d9e (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ydwu4 added a commit that referenced this pull request Sep 8, 2023
ghstack-source-id: 780acd8
Pull Request resolved: #108817
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Sep 8, 2023
ghstack-source-id: 97b035a
Pull Request resolved: #108817
@ydwu4 ydwu4 requested a review from zou3519 September 8, 2023 22:24
@ydwu4
Copy link
Contributor Author

ydwu4 commented Sep 12, 2023

@pytorchbot merge

1 similar comment
@ydwu4
Copy link
Contributor Author

ydwu4 commented Sep 12, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2023
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@ydwu4
Copy link
Contributor Author

ydwu4 commented Sep 12, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

## Exception in cond:
For code below:
```python
import torch
import functorch.experimental.control_flow as control_flow
def true_fn(x):
    return x.sin()

def false_fn(x):
    return x, x

def f(x, y):
    return control_flow.cond(y, true_fn, false_fn, [x])

f(torch.ones(3, 4), torch.tensor(False))
```
The original exception stack trace is:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
    f(torch.ones(3, 4), torch.tensor(False))
  File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
    return control_flow.cond(y, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1164, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 418, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
    raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Expected branch to return a single tensor

from user code:
   File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
After this PR we get:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 429, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 421, in speculate_branch
    unimplemented(
  File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 187, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Expected branch to return a single tensor

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
    f(torch.ones(3, 4), torch.tensor(False))
  File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
    return control_flow.cond(y, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
    return torch.compile(cond_op, backend="eager", fullgraph=True)(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1159, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
    raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
## Exception during speculating branches
The example code below has a inplace-buffer mutation error,
```python    
import torch
import functorch.experimental.control_flow as control_flow

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.ones(6, 4))

    def forward(self, x):
        def true_fn(x):
            self.buffer += 1
            return self.buffer.sum() + x.sum()

        def false_fn(x):
            return (x - 1).sum()

        return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])

mod_for_compile = torch.compile(Foo(), backend="eager", dynamic=True)
mod_for_compile(torch.ones(3, 4))
```

Before this PR the exception looks like:
```python
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 163, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1219, in STORE_ATTR
    .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
    raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 394, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 222, in speculate_subgraph
    raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
    mod_for_compile(torch.ones(3, 4))
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 632, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 415, in call_function
    (true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 405, in speculate_branch
    raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile

from user code:
   File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
    return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```


after this PR, the only difference is the error message of UncapturedHigherOrderOpError changes from `Cond doesn't work unless it is captured completely with torch.compile` to `Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break`.

```python
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 177, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1214, in STORE_ATTR
    .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
    raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 426, in call_function
    (true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 236, in speculate_subgraph
    raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
    mod_for_compile(torch.ones(3, 4))
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 634, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
    raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
   File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
    return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
    return cond_op(pred, true_fn, false_fn, operands)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Sep 13, 2023
ghstack-source-id: 34ef143
Pull Request resolved: #108817
@ydwu4
Copy link
Contributor Author

ydwu4 commented Sep 13, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/ydwu4/24/head branch September 17, 2023 14:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants