- 
                Notifications
    You must be signed in to change notification settings 
- Fork 25.7k
Description
🐛 Describe the bug
Compiling with `torch.compile` fails even for very simple functions when the tensors are complex-valued. For example:
import torch
@torch.compile
def foo(X, Y):
    """Return the elementwise sum of ``X`` and ``Y``.

    Deliberately trivial: this is the minimal function needed to reproduce
    the torch.compile failure on complex-valued inputs.
    """
    return X + Y


# Two independent zero-filled complex128 vectors of length 10; per the
# traceback below, passing these to the compiled ``foo`` raised
# ``KeyError: torch.complex128`` under torch 2.0.0.
X, Y = (torch.zeros(10, dtype=torch.complex128) for _ in range(2))
foo(X, Y)
Error logs
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File [~/torch/_dynamo/output_graph.py:670](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:670), in OutputGraph.call_user_compiler(self, gm)
    [669](/torch/_dynamo/output_graph.py?line=668) else:
--> [670](/torch/_dynamo/output_graph.py?line=669)     compiled_fn = compiler_fn(gm, self.fake_example_inputs())
    [671](/torch/_dynamo/output_graph.py?line=670) _step_logger()(logging.INFO, f"done compiler function {name}")
File [~/torch/_dynamo/debug_utils.py:1055](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/debug_utils.py:1055), in wrap_backend_debug..debug_wrapper(gm, example_inputs, **kwargs)
   [1054](/torch/_dynamo/debug_utils.py?line=1053) else:
-> [1055](/torch/_dynamo/debug_utils.py?line=1054)     compiled_gm = compiler_fn(gm, example_inputs)
   [1057](/torch/_dynamo/debug_utils.py?line=1056) return compiled_gm
File [~/torch/__init__.py:1390](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/__init__.py:1390), in _TorchCompileInductorWrapper.__call__(self, model_, inputs_)
   [1388](/torch/__init__.py?line=1387) from torch._inductor.compile_fx import compile_fx
-> [1390](/torch/__init__.py?line=1389) return compile_fx(model_, inputs_, config_patches=self.config)
File [~/torch/_inductor/compile_fx.py:455](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/compile_fx.py:455), in compile_fx(model_, example_inputs_, inner_compile, config_patches)
    [450](/torch/_inductor/compile_fx.py?line=449) with overrides.patch_functions():
    [451](/torch/_inductor/compile_fx.py?line=450) 
    [452](/torch/_inductor/compile_fx.py?line=451)     # TODO: can add logging before[/after](https://file+.vscode-resource.vscode-cdn.net/after) the call to create_aot_dispatcher_function
    [453](/torch/_inductor/compile_fx.py?line=452)     # in torch._functorch[/aot_autograd.py](https://file+.vscode-resource.vscode-cdn.net/aot_autograd.py)::aot_module_simplified::aot_function_simplified::new_func
    [454](/torch/_inductor/compile_fx.py?line=453)     # once torchdynamo is merged into pytorch
--> [455](/torch/_inductor/compile_fx.py?line=454)     return aot_autograd(
    [456](/torch/_inductor/compile_fx.py?line=455)         fw_compiler=fw_compiler,
    [457](/torch/_inductor/compile_fx.py?line=456)         bw_compiler=bw_compiler,
    [458](/torch/_inductor/compile_fx.py?line=457)         decompositions=select_decomp_table(),
    [459](/torch/_inductor/compile_fx.py?line=458)         partition_fn=functools.partial(
    [460](/torch/_inductor/compile_fx.py?line=459)             min_cut_rematerialization_partition, compiler="inductor"
    [461](/torch/_inductor/compile_fx.py?line=460)         ),
    [462](/torch/_inductor/compile_fx.py?line=461)         keep_inference_input_mutations=True,
    [463](/torch/_inductor/compile_fx.py?line=462)     )(model_, example_inputs_)
File [~/torch/_dynamo/backends/common.py:48](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/backends/common.py:48), in aot_autograd..compiler_fn(gm, example_inputs)
     [47](/torch/_dynamo/backends/common.py?line=46) with enable_aot_logging():
---> [48](/torch/_dynamo/backends/common.py?line=47)     cg = aot_module_simplified(gm, example_inputs, **kwargs)
     [49](/torch/_dynamo/backends/common.py?line=48)     counters["aot_autograd"]["ok"] += 1
File [~/torch/_functorch/aot_autograd.py:2805](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:2805), in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, hasher_type, static_argnums, keep_inference_input_mutations)
   [2803](/torch/_functorch/aot_autograd.py?line=2802) full_args.extend(args)
-> [2805](/torch/_functorch/aot_autograd.py?line=2804) compiled_fn = create_aot_dispatcher_function(
   [2806](/torch/_functorch/aot_autograd.py?line=2805)     functional_call,
   [2807](/torch/_functorch/aot_autograd.py?line=2806)     full_args,
   [2808](/torch/_functorch/aot_autograd.py?line=2807)     aot_config,
   [2809](/torch/_functorch/aot_autograd.py?line=2808) )
   [2811](/torch/_functorch/aot_autograd.py?line=2810) # TODO: There is something deeply wrong here; compiled_fn running with
   [2812](/torch/_functorch/aot_autograd.py?line=2811) # the boxed calling convention, but aot_module_simplified somehow
   [2813](/torch/_functorch/aot_autograd.py?line=2812) # historically returned a function that was not the boxed calling
   [2814](/torch/_functorch/aot_autograd.py?line=2813) # convention.  This should get fixed...
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
    [162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
    [164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_functorch/aot_autograd.py:2498](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:2498), in create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
   [2496](/torch/_functorch/aot_autograd.py?line=2495) # You can put more passes here
-> [2498](/torch/_functorch/aot_autograd.py?line=2497) compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
   [2500](/torch/_functorch/aot_autograd.py?line=2499) if not hasattr(compiled_fn, "_boxed_call"):
File [~/torch/_functorch/aot_autograd.py:1713](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:1713), in aot_wrapper_dedupe(flat_fn, flat_args, aot_config, compiler_fn)
   [1712](/torch/_functorch/aot_autograd.py?line=1711)     if ok:
-> [1713](/torch/_functorch/aot_autograd.py?line=1712)         return compiler_fn(flat_fn, leaf_flat_args, aot_config)
   [1715](/torch/_functorch/aot_autograd.py?line=1714) # Strategy 2: Duplicate specialize.
   [1716](/torch/_functorch/aot_autograd.py?line=1715) #
   [1717](/torch/_functorch/aot_autograd.py?line=1716) # In Haskell types, suppose you have:
   (...)
   [1749](/torch/_functorch/aot_autograd.py?line=1748) #   }
   [1750](/torch/_functorch/aot_autograd.py?line=1749) #   keep_arg_mask = [True, True, False, True]
File [~/torch/_functorch/aot_autograd.py:1326](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_functorch/aot_autograd.py:1326), in aot_dispatch_base(flat_fn, flat_args, aot_config)
   [1325](/torch/_functorch/aot_autograd.py?line=1324) with context(), track_graph_compiling(aot_config, "inference"):
-> [1326](/torch/_functorch/aot_autograd.py?line=1325)     compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
   [1328](/torch/_functorch/aot_autograd.py?line=1327) compiled_fn = create_runtime_wrapper(
   [1329](/torch/_functorch/aot_autograd.py?line=1328)     compiled_fw,
   [1330](/torch/_functorch/aot_autograd.py?line=1329)     runtime_metadata=metadata_,
   [1331](/torch/_functorch/aot_autograd.py?line=1330)     trace_joint=False,
   [1332](/torch/_functorch/aot_autograd.py?line=1331)     keep_input_mutations=aot_config.keep_inference_input_mutations
   [1333](/torch/_functorch/aot_autograd.py?line=1332) )
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
    [162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
    [164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_inductor/compile_fx.py:430](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/compile_fx.py:430), in compile_fx..fw_compiler(model, example_inputs)
    [429](/torch/_inductor/compile_fx.py?line=428) model = convert_outplace_to_inplace(model)
--> [430](/torch/_inductor/compile_fx.py?line=429) return inner_compile(
    [431](/torch/_inductor/compile_fx.py?line=430)     model,
    [432](/torch/_inductor/compile_fx.py?line=431)     example_inputs,
    [433](/torch/_inductor/compile_fx.py?line=432)     num_fixed=fixed,
    [434](/torch/_inductor/compile_fx.py?line=433)     cudagraphs=cudagraphs,
    [435](/torch/_inductor/compile_fx.py?line=434)     graph_id=graph_id,
    [436](/torch/_inductor/compile_fx.py?line=435) )
File [~/torch/_dynamo/debug_utils.py:595](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/debug_utils.py:595), in wrap_compiler_debug..debug_wrapper(gm, example_inputs, **kwargs)
    [594](/torch/_dynamo/debug_utils.py?line=593) else:
--> [595](/torch/_dynamo/debug_utils.py?line=594)     compiled_fn = compiler_fn(gm, example_inputs)
    [597](/torch/_dynamo/debug_utils.py?line=596) return compiled_fn
File [~/torch/_inductor/debug.py:239](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/debug.py:239), in DebugContext.wrap..inner(*args, **kwargs)
    [238](/torch/_inductor/debug.py?line=237) with DebugContext():
--> [239](/torch/_inductor/debug.py?line=238)     return fn(*args, **kwargs)
File [~/miniconda3/lib/python3.9/contextlib.py:79](https://file+.vscode-resource.vscode-cdn.net/Users/~/miniconda3/lib/python3.9/contextlib.py:79), in ContextDecorator.__call__..inner(*args, **kwds)
     [78](/miniconda3/lib/python3.9/contextlib.py?line=77) with self._recreate_cm():
---> [79](/miniconda3/lib/python3.9/contextlib.py?line=78)     return func(*args, **kwds)
File [~/torch/_inductor/compile_fx.py:177](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/compile_fx.py:177), in compile_fx_inner(gm, example_inputs, cudagraphs, num_fixed, is_backward, graph_id)
    [176](/torch/_inductor/compile_fx.py?line=175)         graph.run(*example_inputs)
--> [177](/torch/_inductor/compile_fx.py?line=176)         compiled_fn = graph.compile_to_fn()
    [179](/torch/_inductor/compile_fx.py?line=178) if cudagraphs:
File [~/torch/_inductor/graph.py:586](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/graph.py:586), in GraphLowering.compile_to_fn(self)
    [585](/torch/_inductor/graph.py?line=584) def compile_to_fn(self):
--> [586](/torch/_inductor/graph.py?line=585)     return self.compile_to_module().call
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
    [162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
    [164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_inductor/graph.py:571](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/graph.py:571), in GraphLowering.compile_to_module(self)
    [569](/torch/_inductor/graph.py?line=568) from .codecache import PyCodeCache
--> [571](/torch/_inductor/graph.py?line=570) code = self.codegen()
    [572](/torch/_inductor/graph.py?line=571) if config.debug:
File [~/torch/_inductor/graph.py:522](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/graph.py:522), in GraphLowering.codegen(self)
    [521](/torch/_inductor/graph.py?line=520) assert self.scheduler is not None  # mypy can't figure this out
--> [522](/torch/_inductor/graph.py?line=521) self.scheduler.codegen()
    [523](/torch/_inductor/graph.py?line=522) assert self.wrapper_code is not None
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
    [162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
    [164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
File [~/torch/_inductor/scheduler.py:1177](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/scheduler.py:1177), in Scheduler.codegen(self)
   [1175](/torch/_inductor/scheduler.py?line=1174)     self.available_buffer_names.update(node.get_names())
-> [1177](/torch/_inductor/scheduler.py?line=1176) self.flush()
File [~/torch/_inductor/scheduler.py:1095](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/scheduler.py:1095), in Scheduler.flush(self)
   [1094](/torch/_inductor/scheduler.py?line=1093) for backend in self.backends.values():
-> [1095](/torch/_inductor/scheduler.py?line=1094)     backend.flush()
   [1096](/torch/_inductor/scheduler.py?line=1095) self.free_buffers()
File [~/torch/_inductor/codegen/cpp.py:1975](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/codegen/cpp.py:1975), in CppScheduling.flush(self)
   [1974](/torch/_inductor/codegen/cpp.py?line=1973) def flush(self):
-> [1975](/torch/_inductor/codegen/cpp.py?line=1974)     self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
   [1976](/torch/_inductor/codegen/cpp.py?line=1975)     self.get_kernel_group()
File [~/torch/_inductor/codegen/cpp.py:2004](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/codegen/cpp.py:2004), in KernelGroup.codegen_define_and_call(self, wrapper)
   [2003](/torch/_inductor/codegen/cpp.py?line=2002) kernel_name = "kernel_cpp_" + wrapper.next_kernel_suffix()
-> [2004](/torch/_inductor/codegen/cpp.py?line=2003) arg_defs, call_args, arg_types = self.args.cpp_argdefs()
   [2005](/torch/_inductor/codegen/cpp.py?line=2004) arg_defs = ",\n".ljust(25).join(arg_defs)
File [~/torch/_inductor/codegen/common.py:330](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_inductor/codegen/common.py:330), in KernelArgs.cpp_argdefs(self)
    [329](/torch/_inductor/codegen/common.py?line=328) dtype = buffer_types[outer]
--> [330](/torch/_inductor/codegen/common.py?line=329) cpp_dtype = DTYPE_TO_CPP[dtype]
    [331](/torch/_inductor/codegen/common.py?line=330) arg_defs.append(f"const {cpp_dtype}* __restrict__ {inner}")
KeyError: torch.complex128
The above exception was the direct cause of the following exception:
BackendCompilerFailed                     Traceback (most recent call last)
[/Users/notebooks/test_compile.py](https://file+.vscode-resource.vscode-cdn.net/Users/notebooks/test_compile.py) in line 9
      [7](/notebooks/test_compile.py?line=6) X = torch.zeros(10, dtype=torch.complex128)
      [8](/notebooks/test_compile.py?line=7) Y = torch.zeros(10, dtype=torch.complex128)
----> [9](/notebooks/test_compile.py?line=8) foo(X, Y)
File [~/torch/_dynamo/eval_frame.py:209](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/eval_frame.py:209), in _TorchDynamoContext.__call__.._fn(*args, **kwargs)
    [207](/torch/_dynamo/eval_frame.py?line=206) dynamic_ctx.__enter__()
    [208](/torch/_dynamo/eval_frame.py?line=207) try:
--> [209](/torch/_dynamo/eval_frame.py?line=208)     return fn(*args, **kwargs)
    [210](/torch/_dynamo/eval_frame.py?line=209) finally:
    [211](/torch/_dynamo/eval_frame.py?line=210)     set_eval_frame(prior)
File [~/torch/_dynamo/eval_frame.py:337](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/eval_frame.py:337), in catch_errors_wrapper..catch_errors(frame, cache_size)
    [334](/torch/_dynamo/eval_frame.py?line=333)             return hijacked_callback(frame, cache_size, hooks)
    [336](/torch/_dynamo/eval_frame.py?line=335) with compile_lock:
--> [337](/torch/_dynamo/eval_frame.py?line=336)     return callback(frame, cache_size, hooks)
File [~/torch/_dynamo/convert_frame.py:404](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:404), in convert_frame.._convert_frame(frame, cache_size, hooks)
    [402](/torch/_dynamo/convert_frame.py?line=401) counters["frames"]["total"] += 1
    [403](/torch/_dynamo/convert_frame.py?line=402) try:
--> [404](/torch/_dynamo/convert_frame.py?line=403)     result = inner_convert(frame, cache_size, hooks)
    [405](/torch/_dynamo/convert_frame.py?line=404)     counters["frames"]["ok"] += 1
    [406](/torch/_dynamo/convert_frame.py?line=405)     return result
File [~/torch/_dynamo/convert_frame.py:104](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:104), in wrap_convert_context.._fn(*args, **kwargs)
    [102](/torch/_dynamo/convert_frame.py?line=101) torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    [103](/torch/_dynamo/convert_frame.py?line=102) try:
--> [104](/torch/_dynamo/convert_frame.py?line=103)     return fn(*args, **kwargs)
    [105](/torch/_dynamo/convert_frame.py?line=104) finally:
    [106](/torch/_dynamo/convert_frame.py?line=105)     torch._C._set_grad_enabled(prior_grad_mode)
File [~/torch/_dynamo/convert_frame.py:262](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:262), in convert_frame_assert.._convert_frame_assert(frame, cache_size, hooks)
    [259](/torch/_dynamo/convert_frame.py?line=258) global initial_grad_state
    [260](/torch/_dynamo/convert_frame.py?line=259) initial_grad_state = torch.is_grad_enabled()
--> [262](/torch/_dynamo/convert_frame.py?line=261) return _compile(
    [263](/torch/_dynamo/convert_frame.py?line=262)     frame.f_code,
    [264](/torch/_dynamo/convert_frame.py?line=263)     frame.f_globals,
    [265](/torch/_dynamo/convert_frame.py?line=264)     frame.f_locals,
    [266](/torch/_dynamo/convert_frame.py?line=265)     frame.f_builtins,
    [267](/torch/_dynamo/convert_frame.py?line=266)     compiler_fn,
    [268](/torch/_dynamo/convert_frame.py?line=267)     one_graph,
    [269](/torch/_dynamo/convert_frame.py?line=268)     export,
    [270](/torch/_dynamo/convert_frame.py?line=269)     hooks,
    [271](/torch/_dynamo/convert_frame.py?line=270)     frame,
    [272](/torch/_dynamo/convert_frame.py?line=271) )
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
    [161](/torch/_dynamo/utils.py?line=160)     compilation_metrics[key] = []
    [162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
    [164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
    [165](/torch/_dynamo/utils.py?line=164) # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File [~/torch/_dynamo/convert_frame.py:324](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:324), in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    [322](/torch/_dynamo/convert_frame.py?line=321) for attempt in itertools.count():
    [323](/torch/_dynamo/convert_frame.py?line=322)     try:
--> [324](/torch/_dynamo/convert_frame.py?line=323)         out_code = transform_code_object(code, transform)
    [325](/torch/_dynamo/convert_frame.py?line=324)         orig_code_map[out_code] = code
    [326](/torch/_dynamo/convert_frame.py?line=325)         break
File [~/torch/_dynamo/bytecode_transformation.py:445](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/bytecode_transformation.py:445), in transform_code_object(code, transformations, safe)
    [442](/torch/_dynamo/bytecode_transformation.py?line=441) instructions = cleaned_instructions(code, safe)
    [443](/torch/_dynamo/bytecode_transformation.py?line=442) propagate_line_nums(instructions)
--> [445](/torch/_dynamo/bytecode_transformation.py?line=444) transformations(instructions, code_options)
    [446](/torch/_dynamo/bytecode_transformation.py?line=445) return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File [~/torch/_dynamo/convert_frame.py:311](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/convert_frame.py:311), in _compile..transform(instructions, code_options)
    [298](/torch/_dynamo/convert_frame.py?line=297) nonlocal output
    [299](/torch/_dynamo/convert_frame.py?line=298) tracer = InstructionTranslator(
    [300](/torch/_dynamo/convert_frame.py?line=299)     instructions,
    [301](/torch/_dynamo/convert_frame.py?line=300)     code,
   (...)
    [309](/torch/_dynamo/convert_frame.py?line=308)     mutated_closure_cell_contents,
    [310](/torch/_dynamo/convert_frame.py?line=309) )
--> [311](/torch/_dynamo/convert_frame.py?line=310) tracer.run()
    [312](/torch/_dynamo/convert_frame.py?line=311) output = tracer.output
    [313](/torch/_dynamo/convert_frame.py?line=312) assert output is not None
File [~/torch/_dynamo/symbolic_convert.py:1726](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:1726), in InstructionTranslator.run(self)
   [1724](/torch/_dynamo/symbolic_convert.py?line=1723) def run(self):
   [1725](/torch/_dynamo/symbolic_convert.py?line=1724)     _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> [1726](/torch/_dynamo/symbolic_convert.py?line=1725)     super().run()
File [~/torch/_dynamo/symbolic_convert.py:576](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:576), in InstructionTranslatorBase.run(self)
    [571](/torch/_dynamo/symbolic_convert.py?line=570) try:
    [572](/torch/_dynamo/symbolic_convert.py?line=571)     self.output.push_tx(self)
    [573](/torch/_dynamo/symbolic_convert.py?line=572)     while (
    [574](/torch/_dynamo/symbolic_convert.py?line=573)         self.instruction_pointer is not None
    [575](/torch/_dynamo/symbolic_convert.py?line=574)         and not self.output.should_exit
--> [576](/torch/_dynamo/symbolic_convert.py?line=575)         and self.step()
    [577](/torch/_dynamo/symbolic_convert.py?line=576)     ):
    [578](/torch/_dynamo/symbolic_convert.py?line=577)         pass
    [579](/torch/_dynamo/symbolic_convert.py?line=578) except BackendCompilerFailed:
File [~/torch/_dynamo/symbolic_convert.py:540](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:540), in InstructionTranslatorBase.step(self)
    [538](/torch/_dynamo/symbolic_convert.py?line=537)     if not hasattr(self, inst.opname):
    [539](/torch/_dynamo/symbolic_convert.py?line=538)         unimplemented(f"missing: {inst.opname}")
--> [540](/torch/_dynamo/symbolic_convert.py?line=539)     getattr(self, inst.opname)(inst)
    [542](/torch/_dynamo/symbolic_convert.py?line=541)     return inst.opname != "RETURN_VALUE"
    [543](/torch/_dynamo/symbolic_convert.py?line=542) except BackendCompilerFailed:
File [~/torch/_dynamo/symbolic_convert.py:1792](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/symbolic_convert.py:1792), in InstructionTranslator.RETURN_VALUE(self, inst)
   [1787](/torch/_dynamo/symbolic_convert.py?line=1786) _step_logger()(
   [1788](/torch/_dynamo/symbolic_convert.py?line=1787)     logging.INFO,
   [1789](/torch/_dynamo/symbolic_convert.py?line=1788)     f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
   [1790](/torch/_dynamo/symbolic_convert.py?line=1789) )
   [1791](/torch/_dynamo/symbolic_convert.py?line=1790) log.debug("RETURN_VALUE triggered compile")
-> [1792](/torch/_dynamo/symbolic_convert.py?line=1791) self.output.compile_subgraph(
   [1793](/torch/_dynamo/symbolic_convert.py?line=1792)     self, reason=GraphCompileReason("return_value", [self.frame_summary()])
   [1794](/torch/_dynamo/symbolic_convert.py?line=1793) )
   [1795](/torch/_dynamo/symbolic_convert.py?line=1794) self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
File [~/torch/_dynamo/output_graph.py:517](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:517), in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
    [503](/torch/_dynamo/output_graph.py?line=502)     self.add_output_instructions(random_calls_instructions)
    [505](/torch/_dynamo/output_graph.py?line=504) if (
    [506](/torch/_dynamo/output_graph.py?line=505)     stack_values
    [507](/torch/_dynamo/output_graph.py?line=506)     and all(
   (...)
    [514](/torch/_dynamo/output_graph.py?line=513) 
    [515](/torch/_dynamo/output_graph.py?line=514)     # optimization to generate better code in a common case
    [516](/torch/_dynamo/output_graph.py?line=515)     self.add_output_instructions(
--> [517](/torch/_dynamo/output_graph.py?line=516)         self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
    [518](/torch/_dynamo/output_graph.py?line=517)         + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
    [519](/torch/_dynamo/output_graph.py?line=518)     )
    [520](/torch/_dynamo/output_graph.py?line=519) else:
    [521](/torch/_dynamo/output_graph.py?line=520)     graph_output_var = self.new_var("graph_out")
File [~/torch/_dynamo/output_graph.py:588](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:588), in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root)
    [586](/torch/_dynamo/output_graph.py?line=585) assert_no_fake_params_or_buffers(gm)
    [587](/torch/_dynamo/output_graph.py?line=586) with tracing(self.tracing_context):
--> [588](/torch/_dynamo/output_graph.py?line=587)     compiled_fn = self.call_user_compiler(gm)
    [589](/torch/_dynamo/output_graph.py?line=588) compiled_fn = disable(compiled_fn)
    [591](/torch/_dynamo/output_graph.py?line=590) counters["stats"]["unique_graphs"] += 1
File [~/torch/_dynamo/utils.py:163](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/utils.py:163), in dynamo_timed..dynamo_timed_inner..time_wrapper(*args, **kwargs)
    [161](/torch/_dynamo/utils.py?line=160)     compilation_metrics[key] = []
    [162](/torch/_dynamo/utils.py?line=161) t0 = time.time()
--> [163](/torch/_dynamo/utils.py?line=162) r = func(*args, **kwargs)
    [164](/torch/_dynamo/utils.py?line=163) time_spent = time.time() - t0
    [165](/torch/_dynamo/utils.py?line=164) # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File [~/torch/_dynamo/output_graph.py:675](https://file+.vscode-resource.vscode-cdn.net/Users/~/torch/_dynamo/output_graph.py:675), in OutputGraph.call_user_compiler(self, gm)
    [673](/torch/_dynamo/output_graph.py?line=672) except Exception as e:
    [674](/torch/_dynamo/output_graph.py?line=673)     compiled_fn = gm.forward
--> [675](/torch/_dynamo/output_graph.py?line=674)     raise BackendCompilerFailed(self.compiler_fn, e) from e
    [676](/torch/_dynamo/output_graph.py?line=675) return compiled_fn
BackendCompilerFailed: debug_wrapper raised KeyError: torch.complex128
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
Minified repro
No response
Versions
PyTorch version: 2.0.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.2.1 (x86_64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: Could not collect
Libc version: N/A
Python version: 3.9.16 (main, Jan 11 2023, 10:02:19)  [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.2
[pip3] pytorch-lightning==2.0.0
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1
[pip3] torchcde==0.2.5
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.11.4
[pip3] torchqdynamics==0.1.0
[pip3] torchsde==0.2.5
[pip3] torchvision==0.15.1
[conda] blas                      1.0                         mkl
[conda] mkl                       2021.4.0           hecd8cb5_637
[conda] mkl-service               2.4.0            py39h9ed2024_0
[conda] mkl_fft                   1.3.1            py39h4ab4a9b_0
[conda] mkl_random                1.2.2            py39hb2f4e1b_0
[conda] numpy                     1.24.2                   pypi_0    pypi
[conda] pytorch-lightning         2.0.0                    pypi_0    pypi
[conda] torch                     2.0.0                    pypi_0    pypi
[conda] torchaudio                2.0.1                    pypi_0    pypi
[conda] torchcde                  0.2.5                    pypi_0    pypi
[conda] torchdiffeq               0.2.3                    pypi_0    pypi
[conda] torchmetrics              0.11.4                   pypi_0    pypi
[conda] torchqdynamics            0.1.0                    pypi_0    pypi
[conda] torchsde                  0.2.5                    pypi_0    pypi
[conda] torchvision               0.15.1                   pypi_0    pypi
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire