-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Description
🐛 Describe the bug
Repro:
import torch
# Minimal repro for a torch.compile (Dynamo) internal error.
# We must wrap the whole thing into a function to reproduce the error:
# nesting everything inside test() makes `x` and `fn` local variables of an
# enclosing function, so `inner` and `start` reach them through closure
# cells instead of module globals.
def test():
    x = torch.ones(1)

    # `fn` captures `x` from test()'s scope and returns the nested closure
    # `inner`, which reads `x` through that same captured cell.
    def fn():
        def inner():
            return x + 2

        return inner

    # Only `start` is compiled; `fn` and `x` live in the uncompiled test()
    # frame. Calling fn() here therefore creates the closure `inner` inside
    # the compiled region while it captures a cell owned by a frame outside
    # that region.
    @torch.compile
    def start():
        fn_inner = fn()
        # The reported failure (AttributeError: 'InstructionTranslator'
        # object has no attribute 'parent', raised from Dynamo's bind_args)
        # is attributed to this call in the "from user code" section.
        res = fn_inner()
        # NOTE(review): unclear whether returning the closure itself is
        # required to trigger the error; kept to match the original repro.
        return res, fn_inner

    start()
test()

Error log:
Traceback (most recent call last):
File "/Users/ryanguo99/Documents/work/scratch/test.py", line 21, in <module>
test()
File "/Users/ryanguo99/Documents/work/scratch/test.py", line 19, in test
start()
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 1294, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 1089, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 526, in __call__
return _compile(
^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 966, in _compile
raise InternalTorchDynamoError(
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 929, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 671, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 704, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
transformations(instructions, code_options)
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/convert_frame.py", line 639, in transform
tracer.run()
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2766, in run
super().run()
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 973, in run
while self.step():
^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 885, in step
self.dispatch_table[inst.opcode](self, inst)
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2269, in CALL
self._call(inst)
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2263, in _call
self.call_function(fn, args, kwargs)
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 820, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 111, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 826, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2981, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 3028, in inline_call_
sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 571, in bind_args
cand = cand.parent
^^^^^^^^^^^
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'InstructionTranslator' object has no attribute 'parent'
from user code:
File "/Users/ryanguo99/Documents/work/scratch/test.py", line 16, in start
res = fn_inner()
Versions
Collecting environment information...
PyTorch version: 2.6.0a0+git7bbdf87
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.3)
CMake version: version 3.30.3
Libc version: N/A
Python version: 3.12.5 | packaged by Anaconda, Inc. | (main, Sep 12 2024, 13:22:57) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==2.1.1
[pip3] optree==0.12.1
[pip3] torch==2.6.0a0+git851b973
[conda] numpy 2.1.1 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] torch 2.6.0a0+git851b973 dev_0
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec