-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Open
Labels
actionable, triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Description
Repro using triton's bleeding edge main:
#!/usr/bin/env python3
import torch
import torch.nn.attention.flex_attention
# Make CUDA the default device for tensors created by the rest of the script.
torch.set_default_device("cuda")
# Sequence length; used as both Q_LEN and KV_LEN of the block mask and as the
# context length of the random q/k/v tensors.
N_CTX = 4096
# Window width: a query position is compared against key positions fewer than
# SLIDING_WINDOW steps behind it (see sliding_window_causal).
SLIDING_WINDOW = 128
def sliding_window_causal(b, h, q_idx, kv_idx):
    """Mask predicate combining a causal rule with a fixed-width window.

    A (query, key) pair is kept when the key index does not exceed the query
    index and the distance ``q_idx - kv_idx`` is strictly smaller than
    ``SLIDING_WINDOW``. ``b`` and ``h`` are accepted but not used.
    """
    # Key lies close enough behind the query to fall inside the window.
    within_window = q_idx - kv_idx < SLIDING_WINDOW
    # Key is not ahead of the query.
    not_future = q_idx >= kv_idx
    return not_future & within_window
def rand_qkv(n_batch: int, n_head: int, n_ctx: int, d_qk: int, d_v: int):
    """Return random ``(query, key, value)`` tensors for an attention call.

    Args:
        n_batch: Size of the leading (batch) dimension.
        n_head: Size of the second (head) dimension.
        n_ctx: Size of the third (sequence) dimension.
        d_qk: Last-dimension size of the query and key tensors.
        d_v: Last-dimension size of the value tensor.

    Returns:
        A 3-tuple of ``torch.randn`` tensors: query and key of shape
        ``(n_batch, n_head, n_ctx, d_qk)`` and value of shape
        ``(n_batch, n_head, n_ctx, d_v)``.
    """
    qk_shape = (n_batch, n_head, n_ctx, d_qk)
    # The value shape previously reused d_qk, which silently ignored d_v.
    # Callers passing d_qk == d_v get exactly the same tensors as before.
    v_shape = (n_batch, n_head, n_ctx, d_v)
    return (torch.randn(qk_shape), torch.randn(qk_shape), torch.randn(v_shape))
# Batch size and head count forwarded to rand_qkv for the q/k/v tensors.
n_batch = 1
n_head = 1
# Build a block mask from the sliding-window causal predicate over an
# N_CTX x N_CTX query/key grid. B=None and H=None are passed as-is;
# presumably this makes the mask independent of batch and head — TODO confirm
# against the create_block_mask documentation.
local_bm = torch.nn.attention.flex_attention.create_block_mask(
sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
)
# Wrap flex_attention with torch.compile; the name now refers to the compiled callable.
flex_attention = torch.compile(torch.nn.attention.flex_attention.flex_attention)
# Invoke the compiled function on random inputs with the block mask. This is
# the statement the traceback in this report points at (test.py line 28).
flex_attention(*rand_qkv(n_batch, n_head, N_CTX, d_qk=16, d_v=16), return_lse=True, block_mask=local_bm)
Here is the error we get:
E0211 21:13:34.994000 1581518 subproc_pool.py:321] Error in subprocess
E0211 21:13:34.994000 1581518 subproc_pool.py:321] concurrent.futures.process._RemoteTraceback:
E0211 21:13:34.994000 1581518 subproc_pool.py:321] """
E0211 21:13:34.994000 1581518 subproc_pool.py:321] Traceback (most recent call last):
E0211 21:13:34.994000 1581518 subproc_pool.py:321] File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
E0211 21:13:34.994000 1581518 subproc_pool.py:321] r = call_item.fn(*call_item.args, **call_item.kwargs)
E0211 21:13:34.994000 1581518 subproc_pool.py:321] File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 340, in do_job
E0211 21:13:34.994000 1581518 subproc_pool.py:321] return pickler.dumps(result)
E0211 21:13:34.994000 1581518 subproc_pool.py:321] File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 100, in dumps
E0211 21:13:34.994000 1581518 subproc_pool.py:321] return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
E0211 21:13:34.994000 1581518 subproc_pool.py:321] AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'
E0211 21:13:34.994000 1581518 subproc_pool.py:321] """
E0211 21:13:34.994000 1581518 subproc_pool.py:321]
E0211 21:13:34.994000 1581518 subproc_pool.py:321] The above exception was the direct cause of the following exception:
E0211 21:13:34.994000 1581518 subproc_pool.py:321]
E0211 21:13:34.994000 1581518 subproc_pool.py:321] Traceback (most recent call last):
E0211 21:13:34.994000 1581518 subproc_pool.py:321] File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 319, in callback
E0211 21:13:34.994000 1581518 subproc_pool.py:321] result = future.result()
E0211 21:13:34.994000 1581518 subproc_pool.py:321] File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
E0211 21:13:34.994000 1581518 subproc_pool.py:321] return self.__get_result()
E0211 21:13:34.994000 1581518 subproc_pool.py:321] File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
E0211 21:13:34.994000 1581518 subproc_pool.py:321] raise self._exception
E0211 21:13:34.994000 1581518 subproc_pool.py:321] AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'
W0211 21:13:34.996000 1581373 pytorch/torch/_inductor/utils.py:875] [0/0] on error, temporary cache dir kept at /tmp/torchinductor_ubuntu/tmpkwuio_wu
Traceback (most recent call last):
File "/home/ubuntu/./test.py", line 28, in <module>
flex_attention(*rand_qkv(n_batch, n_head, N_CTX, d_qk=16, d_v=16), return_lse=True, block_mask=local_bm)
File "/home/ubuntu/pytorch/torch/_dynamo/eval_frame.py", line 574, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/home/ubuntu/pytorch/torch/_dynamo/output_graph.py", line 1487, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/ubuntu/pytorch/torch/_dynamo/output_graph.py", line 1466, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/ubuntu/pytorch/torch/_dynamo/repro/after_dynamo.py", line 131, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/ubuntu/pytorch/torch/__init__.py", line 2339, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 2163, in compile_fx
return aot_autograd(
File "/home/ubuntu/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 1168, in aot_module_simplified
compiled_fn = dispatch_and_compile()
File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 1143, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/home/ubuntu/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 205, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 479, in __call__
return self.compiler_fn(gm, example_inputs)
File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 2038, in fw_compiler_base
return inner_compile(
File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 623, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
File "/home/ubuntu/pytorch/torch/_dynamo/repro/after_aot.py", line 104, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 727, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 1402, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 1122, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
File "/home/ubuntu/pytorch/torch/_inductor/graph.py", line 1990, in compile_to_module
return self._compile_to_module()
File "/home/ubuntu/pytorch/torch/_inductor/graph.py", line 2032, in _compile_to_module
mod = PyCodeCache.load_by_key_path(
File "/home/ubuntu/pytorch/torch/_inductor/codecache.py", line 2758, in load_by_key_path
mod = _reload_python_module(key, path)
File "/home/ubuntu/pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_ubuntu/tmpkwuio_wu/2c/c2cwsb3k4rlb6akooercw4u4bjrnkofn6xx5cavzkj2swf2iyiii.py", line 552, in <module>
async_compile.wait(globals())
File "/home/ubuntu/pytorch/torch/_inductor/async_compile.py", line 421, in wait
scope[key] = result.result()
File "/home/ubuntu/pytorch/torch/_inductor/codecache.py", line 3237, in result
return self.result_fn()
File "/home/ubuntu/pytorch/torch/_inductor/async_compile.py", line 311, in get_result
kernel = task.result()
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
We did find that the compiled function is sometimes cached, and once it is, the bug no longer appears, so you may want to run the reproducer with TORCHINDUCTOR_FORCE_DISABLE_CACHES=1.
Originally posted by @saagarjha in #146417 (comment)
Metadata
Metadata
Assignees
Labels
actionable, triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module