-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Closed
Labels
oncall: cpu inductor — CPU Inductor issues for Intel team to triage; oncall: pt2; triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Description
🐛 Describe the bug
According to the 2023-11-08 nightly release test, a lot of CV models crash when run with --freezing.
For example, RN50 passes without freezing but crashes when freezing is enabled, with the following info:
cpu eval resnet50
ERROR:common:Backend dynamo failed in warmup()
Traceback (most recent call last):
File "/workspace/pytorch/benchmarks/dynamo/common.py", line 2604, in warmup
fn(model, example_inputs)
File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 409, in _fn
return fn(*args, **kwargs)
File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 570, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 670, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 377, in _convert_frame_assert
return _compile(
File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 594, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/workspace/pytorch/torch/_dynamo/utils.py", line 222, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 511, in compile_inner
out_code = transform_code_object(code, transform)
File "/workspace/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 150, in _fn
return fn(*args, **kwargs)
File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 476, in transform
tracer.run()
File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2120, in run
super().run()
File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 815, in run
and self.step()
File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 778, in step
getattr(self, inst.opname)(inst)
File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2235, in RETURN_VALUE
self.output.compile_subgraph(
File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 893, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/opt/conda/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 1038, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/workspace/pytorch/torch/_dynamo/utils.py", line 222, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 1109, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 1090, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/workspace/pytorch/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/workspace/pytorch/torch/_dynamo/backends/inductor.py", line 9, in inductor
return compile_fx(*args, **kwargs)
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1226, in compile_fx
return aot_autograd(
File "/workspace/pytorch/torch/_dynamo/backends/common.py", line 55, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 4837, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/workspace/pytorch/torch/_dynamo/utils.py", line 222, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 4376, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 2724, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 2911, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 2073, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 975, in fw_compiler_freezing
optimized_function = inner_compile(
File "/workspace/pytorch/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/workspace/pytorch/torch/_inductor/debug.py", line 303, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 390, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 600, in fx_codegen_and_compile
graph.run(*example_inputs)
File "/workspace/pytorch/torch/_dynamo/utils.py", line 222, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/pytorch/torch/_inductor/graph.py", line 444, in run
return super().run(*args)
File "/workspace/pytorch/torch/fx/interpreter.py", line 138, in run
self.env[node] = self.run_node(node)
File "/workspace/pytorch/torch/_inductor/graph.py", line 754, in run_node
result = super().run_node(n)
File "/workspace/pytorch/torch/fx/interpreter.py", line 195, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/workspace/pytorch/torch/_inductor/graph.py", line 584, in call_function
return target(*args, **kwargs)
File "/workspace/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py", line 234, in fn
return L[computation_op](*computation_args)
File "/workspace/pytorch/torch/_inductor/lowering.py", line 288, in wrapped
out = decomp_fn(*args, **kwargs)
File "/workspace/pytorch/torch/_inductor/lowering.py", line 1266, in convolution_unary
ir.ConvolutionUnary.create(
File "/workspace/pytorch/torch/_inductor/ir.py", line 4897, in create
(inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
File "/workspace/pytorch/torch/_inductor/ir.py", line 4717, in _prepare_convolution_fusion_create
assert 0 < len(padding) <= dims
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Versions
SW information:
| SW | Nightly commit | Main commit |
|---|---|---|
| Pytorch | 4060c20 | aa376e3 |
| Torchbench | / | 7617d3f5 |
| torchaudio | 475b6ae | ede4309 |
| torchtext | 142d029 | 45e4b8c |
| torchvision | 37ceb68 | 15c166a |
| torchdata | eb9bf61 | d76d92c |
| dynamo_benchmarks | nightly | / |
Repro:
bash inductor_single_run.sh multiple inference performance torchbench resnet50 float32 first static default
Suspected guilty commit: 611a745
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov
Metadata
Metadata
Assignees
Labels
oncall: cpu inductor — CPU Inductor issues for Intel team to triage; oncall: pt2; triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module