
New CachingAutotuner pickling logic may be brittle to triton upgrades #146945


Description

@jamesjwu

Repro using triton's bleeding-edge main branch:

#!/usr/bin/env python3

import torch
import torch.nn.attention.flex_attention

torch.set_default_device("cuda")

N_CTX = 4096
SLIDING_WINDOW = 128

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx < SLIDING_WINDOW
    return causal_mask & window_mask

def rand_qkv(n_batch: int, n_head: int, n_ctx: int, d_qk: int, d_v: int):
    qk_shape = (n_batch, n_head, n_ctx, d_qk)
    v_shape = (n_batch, n_head, n_ctx, d_v)
    return (torch.randn(qk_shape), torch.randn(qk_shape), torch.randn(v_shape))

n_batch = 1
n_head = 1
local_bm = torch.nn.attention.flex_attention.create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
)

flex_attention = torch.compile(torch.nn.attention.flex_attention.flex_attention)
flex_attention(*rand_qkv(n_batch, n_head, N_CTX, d_qk=16, d_v=16), return_lse=True, block_mask=local_bm)

Here is the error we get:

E0211 21:13:34.994000 1581518 subproc_pool.py:321] Error in subprocess
E0211 21:13:34.994000 1581518 subproc_pool.py:321] concurrent.futures.process._RemoteTraceback:
E0211 21:13:34.994000 1581518 subproc_pool.py:321] """
E0211 21:13:34.994000 1581518 subproc_pool.py:321] Traceback (most recent call last):
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     r = call_item.fn(*call_item.args, **call_item.kwargs)
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 340, in do_job
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     return pickler.dumps(result)
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 100, in dumps
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
E0211 21:13:34.994000 1581518 subproc_pool.py:321] AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'
E0211 21:13:34.994000 1581518 subproc_pool.py:321] """
E0211 21:13:34.994000 1581518 subproc_pool.py:321]
E0211 21:13:34.994000 1581518 subproc_pool.py:321] The above exception was the direct cause of the following exception:
E0211 21:13:34.994000 1581518 subproc_pool.py:321]
E0211 21:13:34.994000 1581518 subproc_pool.py:321] Traceback (most recent call last):
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 319, in callback
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     result = future.result()
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     return self.__get_result()
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     raise self._exception
E0211 21:13:34.994000 1581518 subproc_pool.py:321] AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'
W0211 21:13:34.996000 1581373 pytorch/torch/_inductor/utils.py:875] [0/0] on error, temporary cache dir kept at /tmp/torchinductor_ubuntu/tmpkwuio_wu
Traceback (most recent call last):
  File "/home/ubuntu/./test.py", line 28, in <module>
    flex_attention(*rand_qkv(n_batch, n_head, N_CTX, d_qk=16, d_v=16), return_lse=True, block_mask=local_bm)
  File "/home/ubuntu/pytorch/torch/_dynamo/eval_frame.py", line 574, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/ubuntu/pytorch/torch/_dynamo/output_graph.py", line 1487, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/ubuntu/pytorch/torch/_dynamo/output_graph.py", line 1466, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/ubuntu/pytorch/torch/_dynamo/repro/after_dynamo.py", line 131, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/ubuntu/pytorch/torch/__init__.py", line 2339, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 2163, in compile_fx
    return aot_autograd(
  File "/home/ubuntu/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 1168, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 1143, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/ubuntu/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 205, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 479, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 2038, in fw_compiler_base
    return inner_compile(
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 623, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/ubuntu/pytorch/torch/_dynamo/repro/after_aot.py", line 104, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 727, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 1402, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 1122, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
  File "/home/ubuntu/pytorch/torch/_inductor/graph.py", line 1990, in compile_to_module
    return self._compile_to_module()
  File "/home/ubuntu/pytorch/torch/_inductor/graph.py", line 2032, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/home/ubuntu/pytorch/torch/_inductor/codecache.py", line 2758, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/home/ubuntu/pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_ubuntu/tmpkwuio_wu/2c/c2cwsb3k4rlb6akooercw4u4bjrnkofn6xx5cavzkj2swf2iyiii.py", line 552, in <module>
    async_compile.wait(globals())
  File "/home/ubuntu/pytorch/torch/_inductor/async_compile.py", line 421, in wait
    scope[key] = result.result()
  File "/home/ubuntu/pytorch/torch/_inductor/codecache.py", line 3237, in result
    return self.result_fn()
  File "/home/ubuntu/pytorch/torch/_inductor/async_compile.py", line 311, in get_result
    kernel = task.result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
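Per the traceback, the object sent back from the compile worker holds a reference to a lambda defined inside JITFunction.__init__ on triton's main branch, and CPython's pickle cannot resolve locally defined functions by qualified name. A minimal standalone sketch of that failure mode (JITFunctionLike and the hook attribute are illustrative stand-ins, not triton's actual names):

import pickle

class JITFunctionLike:
    def __init__(self):
        # A lambda created here gets the qualified name
        # 'JITFunctionLike.__init__.<locals>.<lambda>'; pickle serializes
        # functions by reference and cannot look local objects up by name.
        self.hook = lambda *args, **kwargs: None

try:
    pickle.dumps(JITFunctionLike(), pickle.HIGHEST_PROTOCOL)
except AttributeError as e:
    # AttributeError: Can't pickle local object
    # 'JITFunctionLike.__init__.<locals>.<lambda>'
    print(e)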

We did find that the function sometimes ends up cached, after which the bug no longer reproduces, so you may want to run the reproducer with TORCHINDUCTOR_FORCE_DISABLE_CACHES=1.
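To make the repro fail deterministically, the caches can be disabled from the shell (TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python test.py) or from inside the script; a small sketch, noting that the variable must be set before torch/inductor reads it:

import os

# Force-disable inductor caches so the subprocess pickling path is
# exercised on every run; set this before importing torch.
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"

import torch  # ... rest of the repro above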

Originally posted by @saagarjha in #146417 (comment)
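Not necessarily how the CachingAutotuner should handle this, but one generic pattern for keeping a wrapper picklable when an upstream upgrade introduces locally defined callables is to drop and re-create them around pickling; a sketch with a hypothetical Wrapper class:

import pickle

class Wrapper:
    def __init__(self):
        # Unpicklable local object, standing in for whatever the
        # wrapped triton object now carries.
        self.hook = lambda *args: None

    def __getstate__(self):
        # Exclude attributes pickle cannot serialize by reference.
        state = self.__dict__.copy()
        state.pop("hook", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.hook = lambda *args: None  # re-create on unpickle

w = pickle.loads(pickle.dumps(Wrapper(), pickle.HIGHEST_PROTOCOL))
assert w.hook is not None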

Metadata

Labels: actionable, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
