-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Description
🐛 Describe the bug
Compiling fails even for very simple functions when the tensors are complex-valued. For example:
import torch
@torch.compile
def foo(X, Y):
    """Return the elementwise sum ``X + Y``, compiled with ``torch.compile``.

    Minimal repro for the reported failure. When ``X`` and ``Y`` are
    ``torch.complex128`` tensors, compilation under the default inductor
    backend (PyTorch 2.0.0) fails inside C++ code generation. The lookup
    ``DTYPE_TO_CPP[dtype]`` raises ``KeyError: torch.complex128``, which
    Dynamo re-raises as ``BackendCompilerFailed``.

    Args:
        X: first operand (complex128 in the repro).
        Y: second operand, same dtype/shape as ``X`` in the repro.

    Returns:
        The tensor ``X + Y``.

    Note:
        The error message itself suggests falling back to eager execution with
        ``torch._dynamo.config.suppress_errors = True``.
    """
    # Per the report, even this trivial op is enough to trigger the failure
    # for complex-valued inputs.
    Z = X + Y
    return Z
# Two independent all-zero complex128 input vectors of length 10.
X = torch.zeros(10, dtype=torch.complex128)
Y = torch.zeros_like(X)  # same shape/dtype/device as X, separate storage
foo(X, Y)

Error logs
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File [~/torch/_dynamo/output_graph.py:670](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:670), in OutputGraph.call_user_compiler(self, gm)
[669](/torch/_dynamo/output_graph.py?line=668) else:
--> [670](/torch/_dynamo/output_graph.py?line=669) compiled_fn = compiler_fn(gm, self.fake_example_inputs())
[671](/torch/_dynamo/output_graph.py?line=670) _step_logger()(logging.INFO, f"done compiler function {name}")
File [~/torch/_dynamo/debug_utils.py:1055](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/debug_utils.py:1055), in wrap_backend_debug..debug_wrapper(gm, example_inputs, **kwargs)
[1054](/torch/_dynamo/debug_utils.py?line=1053) else:
-> [1055](/torch/_dynamo/debug_utils.py?line=1054) compiled_gm = compiler_fn(gm, example_inputs)
[1057](/torch/_dynamo/debug_utils.py?line=1056) return compiled_gm
File [~/torch/__init__.py:1390](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/__init__.py:1390), in _TorchCompileInductorWrapper.__call__(self, model_, inputs_)
[1388](/torch/__init__.py?line=1387) from torch._inductor.compile_fx import compile_fx
-> [1390](/torch/__init__.py?line=1389) return compile_fx(model_, inputs_, config_patches=self.config)
File [~/torch/_inductor/compile_fx.py:455](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/compile_fx.py:455), in compile_fx(model_, example_inputs_, inner_compile, config_patches)
[450](/torch/_inductor/compile_fx.py?line=449) with overrides.patch_functions():
[451](/torch/_inductor/compile_fx.py?line=450)
[452](/torch/_inductor/compile_fx.py?line=451) # TODO: can add logging before/after the call to create_aot_dispatcher_function
[453](/torch/_inductor/compile_fx.py?line=452) # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
[454](/torch/_inductor/compile_fx.py?line=453) # once torchdynamo is merged into pytorch
--> [455](/torch/_inductor/compile_fx.py?line=454) return aot_autograd(
[456](/torch/_inductor/compile_fx.py?line=455) fw_compiler=fw_compiler,
[457](/torch/_inductor/compile_fx.py?line=456) bw_compiler=bw_compiler,
[458](/torch/_inductor/compile_fx.py?line=457) decompositions=select_decomp_table(),
[459](/torch/_inductor/compile_fx.py?line=458) partition_fn=functools.partial(
[460](/torch/_inductor/compile_fx.py?line=459) min_cut_rematerialization_partition, compiler="inductor"
[461](/torch/_inductor/compile_fx.py?line=460) ),
[462](/torch/_inductor/compile_fx.py?line=461) keep_inference_input_mutations=True,
[463](/torch/_inductor/compile_fx.py?line=462) )(model_, example_inputs_)
File [~/torch/_dynamo/backends/common.py:48](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/backends/common.py:48), in aot_autograd..compiler_fn(gm, example_inputs)
[47](/torch/_dynamo/backends/common.py?line=46) with enable_aot_logging():
---> [48](/torch/_dynamo/backends/common.py?line=47) cg = aot_module_simplified(gm, example_inputs, **kwargs)
[49](/torch/_dynamo/backends/common.py?line=48) counters["aot_autograd"]["ok"] += 1
File [~/torch/_functorch/aot_autograd.py:2805](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:2805), in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, hasher_type, static_argnums, keep_inference_input_mutations)
[2803](/torch/_functorch/aot_autograd.py?line=2802) full_args.extend(args)
-> [2805](/torch/_functorch/aot_autograd.py?line=2804) compiled_fn = create_aot_dispatcher_function(
[2806](/torch/_functorch/aot_autograd.py?line=2805) functional_call,
[2807](/torch/_functorch/aot_autograd.py?line=2806) full_args,
[2808](/torch/_functorch/aot_autograd.py?line=2807) aot_config,
[2809](/torch/_functorch/aot_autograd.py?line=2808) )
[2811](/torch/_functorch/aot_autograd.py?line=2810) # TODO: There is something deeply wrong here; compiled_fn running with
[2812](/torch/_functorch/aot_autograd.py?line=2811) # the boxed calling convention, but aot_module_simplified somehow
[2813](/torch/_functorch/aot_autograd.py?line=2812) # historically returned a function that was not the boxed calling
[2814](/torch/_functorch/aot_autograd.py?line=2813) # convention. This should get fixed...
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
[162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
[164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_functorch/aot_autograd.py:2498](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:2498), in create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
[2496](/torch/_functorch/aot_autograd.py?line=2495) # You can put more passes here
-> [2498](/torch/_functorch/aot_autograd.py?line=2497) compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
[2500](/torch/_functorch/aot_autograd.py?line=2499) if not hasattr(compiled_fn, "_boxed_call"):
File [~/torch/_functorch/aot_autograd.py:1713](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:1713), in aot_wrapper_dedupe(flat_fn, flat_args, aot_config, compiler_fn)
[1712](/torch/_functorch/aot_autograd.py?line=1711) if ok:
-> [1713](/torch/_functorch/aot_autograd.py?line=1712) return compiler_fn(flat_fn, leaf_flat_args, aot_config)
[1715](/torch/_functorch/aot_autograd.py?line=1714) # Strategy 2: Duplicate specialize.
[1716](/torch/_functorch/aot_autograd.py?line=1715) #
[1717](/torch/_functorch/aot_autograd.py?line=1716) # In Haskell types, suppose you have:
(...)
[1749](/torch/_functorch/aot_autograd.py?line=1748) # }
[1750](/torch/_functorch/aot_autograd.py?line=1749) # keep_arg_mask = [True, True, False, True]
File [~/torch/_functorch/aot_autograd.py:1326](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:1326), in aot_dispatch_base(flat_fn, flat_args, aot_config)
[1325](/torch/_functorch/aot_autograd.py?line=1324) with context(), track_graph_compiling(aot_config, "inference"):
-> [1326](/torch/_functorch/aot_autograd.py?line=1325) compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
[1328](/torch/_functorch/aot_autograd.py?line=1327) compiled_fn = create_runtime_wrapper(
[1329](/torch/_functorch/aot_autograd.py?line=1328) compiled_fw,
[1330](/torch/_functorch/aot_autograd.py?line=1329) runtime_metadata=metadata_,
[1331](/torch/_functorch/aot_autograd.py?line=1330) trace_joint=False,
[1332](/torch/_functorch/aot_autograd.py?line=1331) keep_input_mutations=aot_config.keep_inference_input_mutations
[1333](/torch/_functorch/aot_autograd.py?line=1332) )
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
[162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
[164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_inductor/compile_fx.py:430](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/compile_fx.py:430), in compile_fx..fw_compiler(model, example_inputs)
[429](/torch/_inductor/compile_fx.py?line=428) model = convert_outplace_to_inplace(model)
--> [430](/torch/_inductor/compile_fx.py?line=429) return inner_compile(
[431](/torch/_inductor/compile_fx.py?line=430) model,
[432](/torch/_inductor/compile_fx.py?line=431) example_inputs,
[433](/torch/_inductor/compile_fx.py?line=432) num_fixed=fixed,
[434](/torch/_inductor/compile_fx.py?line=433) cudagraphs=cudagraphs,
[435](/torch/_inductor/compile_fx.py?line=434) graph_id=graph_id,
[436](/torch/_inductor/compile_fx.py?line=435) )
File [~/torch/_dynamo/debug_utils.py:595](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/debug_utils.py:595), in wrap_compiler_debug..debug_wrapper(gm, example_inputs, **kwargs)
[594](/torch/_dynamo/debug_utils.py?line=593) else:
--> [595](/torch/_dynamo/debug_utils.py?line=594) compiled_fn = compiler_fn(gm, example_inputs)
[597](/torch/_dynamo/debug_utils.py?line=596) return compiled_fn
File [~/torch/_inductor/debug.py:239](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/debug.py:239), in DebugContext.wrap..inner(*args, **kwargs)
[238](/torch/_inductor/debug.py?line=237) with DebugContext():
--> [239](/torch/_inductor/debug.py?line=238) return fn(*args, **kwargs)
File [~/miniconda3/lib/python3.9/contextlib.py:79](https://file+.vscode-resource.vscode-cdn.net/Users/~/miniconda3/lib/python3.9/contextlib.py:79), in ContextDecorator.__call__..inner(*args, **kwds)
[78](/miniconda3/lib/python3.9/contextlib.py?line=77) with self._recreate_cm():
---> [79](/miniconda3/lib/python3.9/contextlib.py?line=78) return func(*args, **kwds)
File [~/torch/_inductor/compile_fx.py:177](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/compile_fx.py:177), in compile_fx_inner(gm, example_inputs, cudagraphs, num_fixed, is_backward, graph_id)
[176](/torch/_inductor/compile_fx.py?line=175) graph.run(*example_inputs)
--> [177](/torch/_inductor/compile_fx.py?line=176) compiled_fn = graph.compile_to_fn()
[179](/torch/_inductor/compile_fx.py?line=178) if cudagraphs:
File [~/torch/_inductor/graph.py:586](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/graph.py:586), in GraphLowering.compile_to_fn(self)
[585](/torch/_inductor/graph.py?line=584) def compile_to_fn(self):
--> [586](/torch/_inductor/graph.py?line=585) return self.compile_to_module().call
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
[162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
[164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_inductor/graph.py:571](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/graph.py:571), in GraphLowering.compile_to_module(self)
[569](/torch/_inductor/graph.py?line=568) from .codecache import PyCodeCache
--> [571](/torch/_inductor/graph.py?line=570) code = self.codegen()
[572](/torch/_inductor/graph.py?line=571) if config.debug:
File [~/torch/_inductor/graph.py:522](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/graph.py:522), in GraphLowering.codegen(self)
[521](/torch/_inductor/graph.py?line=520) assert self.scheduler is not None # mypy can't figure this out
--> [522](/torch/_inductor/graph.py?line=521) self.scheduler.codegen()
[523](/torch/_inductor/graph.py?line=522) assert self.wrapper_code is not None
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
[162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
[164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_inductor/scheduler.py:1177](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/scheduler.py:1177), in Scheduler.codegen(self)
[1175](/torch/_inductor/scheduler.py?line=1174) self.available_buffer_names.update(node.get_names())
-> [1177](/torch/_inductor/scheduler.py?line=1176) self.flush()
File [~/torch/_inductor/scheduler.py:1095](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/scheduler.py:1095), in Scheduler.flush(self)
[1094](/torch/_inductor/scheduler.py?line=1093) for backend in self.backends.values():
-> [1095](/torch/_inductor/scheduler.py?line=1094) backend.flush()
[1096](/torch/_inductor/scheduler.py?line=1095) self.free_buffers()
File [~/torch/_inductor/codegen/cpp.py:1975](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/codegen/cpp.py:1975), in CppScheduling.flush(self)
[1974](/torch/_inductor/codegen/cpp.py?line=1973) def flush(self):
-> [1975](/torch/_inductor/codegen/cpp.py?line=1974) self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
[1976](/torch/_inductor/codegen/cpp.py?line=1975) self.get_kernel_group()
File [~/torch/_inductor/codegen/cpp.py:2004](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/codegen/cpp.py:2004), in KernelGroup.codegen_define_and_call(self, wrapper)
[2003](/torch/_inductor/codegen/cpp.py?line=2002) kernel_name = "kernel_cpp_" + wrapper.next_kernel_suffix()
-> [2004](/torch/_inductor/codegen/cpp.py?line=2003) arg_defs, call_args, arg_types = self.args.cpp_argdefs()
[2005](/torch/_inductor/codegen/cpp.py?line=2004) arg_defs = ",\n".ljust(25).join(arg_defs)
File [~/torch/_inductor/codegen/common.py:330](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/codegen/common.py:330), in KernelArgs.cpp_argdefs(self)
[329](/torch/_inductor/codegen/common.py?line=328) dtype = buffer_types[outer]
--> [330](/torch/_inductor/codegen/common.py?line=329) cpp_dtype = DTYPE_TO_CPP[dtype]
[331](/torch/_inductor/codegen/common.py?line=330) arg_defs.append(f"const {cpp_dtype}* __restrict__ {inner}")
KeyError: torch.complex128
The above exception was the direct cause of the following exception:
BackendCompilerFailed Traceback (most recent call last)
[/Users/notebooks/test_compile.py](https://file+.vscode-resource.vscode-cdn.net/Users/notebooks/test_compile.py) in line 9
[7](/notebooks/test_compile.py?line=6) X = torch.zeros(10, dtype=torch.complex128)
[8](/notebooks/test_compile.py?line=7) Y = torch.zeros(10, dtype=torch.complex128)
----> [9](/notebooks/test_compile.py?line=8) foo(X, Y)
File [~/torch/_dynamo/eval_frame.py:209](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/eval_frame.py:209), in _TorchDynamoContext.__call__.._fn(*args, **kwargs)
[207](/torch/_dynamo/eval_frame.py?line=206) dynamic_ctx.__enter__()
[208](/torch/_dynamo/eval_frame.py?line=207) try:
--> [209](/torch/_dynamo/eval_frame.py?line=208) return fn(*args, **kwargs)
[210](/torch/_dynamo/eval_frame.py?line=209) finally:
[211](/torch/_dynamo/eval_frame.py?line=210) set_eval_frame(prior)
File [~/torch/_dynamo/eval_frame.py:337](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/eval_frame.py:337), in catch_errors_wrapper..catch_errors(frame, cache_size)
[334](/torch/_dynamo/eval_frame.py?line=333) return hijacked_callback(frame, cache_size, hooks)
[336](/torch/_dynamo/eval_frame.py?line=335) with compile_lock:
--> [337](/torch/_dynamo/eval_frame.py?line=336) return callback(frame, cache_size, hooks)
File [~/torch/_dynamo/convert_frame.py:404](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:404), in convert_frame.._convert_frame(frame, cache_size, hooks)
[402](/torch/_dynamo/convert_frame.py?line=401) counters["frames"]["total"] += 1
[403](/torch/_dynamo/convert_frame.py?line=402) try:
--> [404](/torch/_dynamo/convert_frame.py?line=403) result = inner_convert(frame, cache_size, hooks)
[405](/torch/_dynamo/convert_frame.py?line=404) counters["frames"]["ok"] += 1
[406](/torch/_dynamo/convert_frame.py?line=405) return result
File [~/torch/_dynamo/convert_frame.py:104](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:104), in wrap_convert_context.._fn(*args, **kwargs)
[102](/torch/_dynamo/convert_frame.py?line=101) torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
[103](/torch/_dynamo/convert_frame.py?line=102) try:
--> [104](/torch/_dynamo/convert_frame.py?line=103) return fn(*args, **kwargs)
[105](/torch/_dynamo/convert_frame.py?line=104) finally:
[106](/torch/_dynamo/convert_frame.py?line=105) torch._C._set_grad_enabled(prior_grad_mode)
File [~/torch/_dynamo/convert_frame.py:262](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:262), in convert_frame_assert.._convert_frame_assert(frame, cache_size, hooks)
[259](/torch/_dynamo/convert_frame.py?line=258) global initial_grad_state
[260](/torch/_dynamo/convert_frame.py?line=259) initial_grad_state = torch.is_grad_enabled()
--> [262](/torch/_dynamo/convert_frame.py?line=261) return _compile(
[263](/torch/_dynamo/convert_frame.py?line=262) frame.f_code,
[264](/torch/_dynamo/convert_frame.py?line=263) frame.f_globals,
[265](/torch/_dynamo/convert_frame.py?line=264) frame.f_locals,
[266](/torch/_dynamo/convert_frame.py?line=265) frame.f_builtins,
[267](/torch/_dynamo/convert_frame.py?line=266) compiler_fn,
[268](/torch/_dynamo/convert_frame.py?line=267) one_graph,
[269](/torch/_dynamo/convert_frame.py?line=268) export,
[270](/torch/_dynamo/convert_frame.py?line=269) hooks,
[271](/torch/_dynamo/convert_frame.py?line=270) frame,
[272](/torch/_dynamo/convert_frame.py?line=271) )
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
[161](/torch/_dynamo/utils.py?line=160) compilation_metrics[key] = []
[162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
[164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
[165](/torch/_dynamo/utils.py?line=164) # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File [~/torch/_dynamo/convert_frame.py:324](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:324), in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
[322](/torch/_dynamo/convert_frame.py?line=321) for attempt in itertools.count():
[323](/torch/_dynamo/convert_frame.py?line=322) try:
--> [324](/torch/_dynamo/convert_frame.py?line=323) out_code = transform_code_object(code, transform)
[325](/torch/_dynamo/convert_frame.py?line=324) orig_code_map[out_code] = code
[326](/torch/_dynamo/convert_frame.py?line=325) break
File [~/torch/_dynamo/bytecode_transformation.py:445](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/bytecode_transformation.py:445), in transform_code_object(code, transformations, safe)
[442](/torch/_dynamo/bytecode_transformation.py?line=441) instructions = cleaned_instructions(code, safe)
[443](/torch/_dynamo/bytecode_transformation.py?line=442) propagate_line_nums(instructions)
--> [445](/torch/_dynamo/bytecode_transformation.py?line=444) transformations(instructions, code_options)
[446](/torch/_dynamo/bytecode_transformation.py?line=445) return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File [~/torch/_dynamo/convert_frame.py:311](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:311), in _compile..transform(instructions, code_options)
[298](/torch/_dynamo/convert_frame.py?line=297) nonlocal output
[299](/torch/_dynamo/convert_frame.py?line=298) tracer = InstructionTranslator(
[300](/torch/_dynamo/convert_frame.py?line=299) instructions,
[301](/torch/_dynamo/convert_frame.py?line=300) code,
(...)
[309](/torch/_dynamo/convert_frame.py?line=308) mutated_closure_cell_contents,
[310](/torch/_dynamo/convert_frame.py?line=309) )
--> [311](/torch/_dynamo/convert_frame.py?line=310) tracer.run()
[312](/torch/_dynamo/convert_frame.py?line=311) output = tracer.output
[313](/torch/_dynamo/convert_frame.py?line=312) assert output is not None
File [~/torch/_dynamo/symbolic_convert.py:1726](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:1726), in InstructionTranslator.run(self)
[1724](/torch/_dynamo/symbolic_convert.py?line=1723) def run(self):
[1725](/torch/_dynamo/symbolic_convert.py?line=1724) _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> [1726](/torch/_dynamo/symbolic_convert.py?line=1725) super().run()
File [~/torch/_dynamo/symbolic_convert.py:576](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:576), in InstructionTranslatorBase.run(self)
[571](/torch/_dynamo/symbolic_convert.py?line=570) try:
[572](/torch/_dynamo/symbolic_convert.py?line=571) self.output.push_tx(self)
[573](/torch/_dynamo/symbolic_convert.py?line=572) while (
[574](/torch/_dynamo/symbolic_convert.py?line=573) self.instruction_pointer is not None
[575](/torch/_dynamo/symbolic_convert.py?line=574) and not self.output.should_exit
--> [576](/torch/_dynamo/symbolic_convert.py?line=575) and self.step()
[577](/torch/_dynamo/symbolic_convert.py?line=576) ):
[578](/torch/_dynamo/symbolic_convert.py?line=577) pass
[579](/torch/_dynamo/symbolic_convert.py?line=578) except BackendCompilerFailed:
File [~/torch/_dynamo/symbolic_convert.py:540](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:540), in InstructionTranslatorBase.step(self)
[538](/torch/_dynamo/symbolic_convert.py?line=537) if not hasattr(self, inst.opname):
[539](/torch/_dynamo/symbolic_convert.py?line=538) unimplemented(f"missing: {inst.opname}")
--> [540](/torch/_dynamo/symbolic_convert.py?line=539) getattr(self, inst.opname)(inst)
[542](/torch/_dynamo/symbolic_convert.py?line=541) return inst.opname != "RETURN_VALUE"
[543](/torch/_dynamo/symbolic_convert.py?line=542) except BackendCompilerFailed:
File [~/torch/_dynamo/symbolic_convert.py:1792](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:1792), in InstructionTranslator.RETURN_VALUE(self, inst)
[1787](/torch/_dynamo/symbolic_convert.py?line=1786) _step_logger()(
[1788](/torch/_dynamo/symbolic_convert.py?line=1787) logging.INFO,
[1789](/torch/_dynamo/symbolic_convert.py?line=1788) f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
[1790](/torch/_dynamo/symbolic_convert.py?line=1789) )
[1791](/torch/_dynamo/symbolic_convert.py?line=1790) log.debug("RETURN_VALUE triggered compile")
-> [1792](/torch/_dynamo/symbolic_convert.py?line=1791) self.output.compile_subgraph(
[1793](/torch/_dynamo/symbolic_convert.py?line=1792) self, reason=GraphCompileReason("return_value", [self.frame_summary()])
[1794](/torch/_dynamo/symbolic_convert.py?line=1793) )
[1795](/torch/_dynamo/symbolic_convert.py?line=1794) self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
File [~/torch/_dynamo/output_graph.py:517](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:517), in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
[503](/torch/_dynamo/output_graph.py?line=502) self.add_output_instructions(random_calls_instructions)
[505](/torch/_dynamo/output_graph.py?line=504) if (
[506](/torch/_dynamo/output_graph.py?line=505) stack_values
[507](/torch/_dynamo/output_graph.py?line=506) and all(
(...)
[514](/torch/_dynamo/output_graph.py?line=513)
[515](/torch/_dynamo/output_graph.py?line=514) # optimization to generate better code in a common case
[516](/torch/_dynamo/output_graph.py?line=515) self.add_output_instructions(
--> [517](/torch/_dynamo/output_graph.py?line=516) self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
[518](/torch/_dynamo/output_graph.py?line=517) + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
[519](/torch/_dynamo/output_graph.py?line=518) )
[520](/torch/_dynamo/output_graph.py?line=519) else:
[521](/torch/_dynamo/output_graph.py?line=520) graph_output_var = self.new_var("graph_out")
File [~/torch/_dynamo/output_graph.py:588](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:588), in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root)
[586](/torch/_dynamo/output_graph.py?line=585) assert_no_fake_params_or_buffers(gm)
[587](/torch/_dynamo/output_graph.py?line=586) with tracing(self.tracing_context):
--> [588](/torch/_dynamo/output_graph.py?line=587) compiled_fn = self.call_user_compiler(gm)
[589](/torch/_dynamo/output_graph.py?line=588) compiled_fn = disable(compiled_fn)
[591](/torch/_dynamo/output_graph.py?line=590) counters["stats"]["unique_graphs"] += 1
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
[161](/torch/_dynamo/utils.py?line=160) compilation_metrics[key] = []
[162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
[164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
[165](/torch/_dynamo/utils.py?line=164) # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File [~/torch/_dynamo/output_graph.py:675](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:675), in OutputGraph.call_user_compiler(self, gm)
[673](/torch/_dynamo/output_graph.py?line=672) except Exception as e:
[674](/torch/_dynamo/output_graph.py?line=673) compiled_fn = gm.forward
--> [675](/torch/_dynamo/output_graph.py?line=674) raise BackendCompilerFailed(self.compiler_fn, e) from e
[676](/torch/_dynamo/output_graph.py?line=675) return compiled_fn
BackendCompilerFailed: debug_wrapper raised KeyError: torch.complex128
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True

Minified repro
No response
Versions
PyTorch version: 2.0.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.2.1 (x86_64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: Could not collect
Libc version: N/A
Python version: 3.9.16 (main, Jan 11 2023, 10:02:19) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.2
[pip3] pytorch-lightning==2.0.0
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1
[pip3] torchcde==0.2.5
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.11.4
[pip3] torchqdynamics==0.1.0
[pip3] torchsde==0.2.5
[pip3] torchvision==0.15.1
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mkl-service 2.4.0 py39h9ed2024_0
[conda] mkl_fft 1.3.1 py39h4ab4a9b_0
[conda] mkl_random 1.2.2 py39hb2f4e1b_0
[conda] numpy 1.24.2 pypi_0 pypi
[conda] pytorch-lightning 2.0.0 pypi_0 pypi
[conda] torch 2.0.0 pypi_0 pypi
[conda] torchaudio 2.0.1 pypi_0 pypi
[conda] torchcde 0.2.5 pypi_0 pypi
[conda] torchdiffeq 0.2.3 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchqdynamics 0.1.0 pypi_0 pypi
[conda] torchsde 0.2.5 pypi_0 pypi
[conda] torchvision 0.15.1 pypi_0 pypi
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire