Only call triton in worker process, kick off worker processes earlier, during inductor codegen by jamesjwu · Pull Request #146417 · pytorch/pytorch · GitHub

Conversation

jamesjwu
Contributor

@jamesjwu jamesjwu commented Feb 4, 2025

Stack from ghstack (oldest at bottom):

Big idea

This PR extends #144288 by combining calling triton in worker processes with the future cache: we kick off triton compilation in the worker processes earlier, during inductor codegen. Instead of calling async_compile.triton for the first time only after the entire source has been generated, we start compiling as soon as we know we will need a kernel. Then, when loading the generated inductor code, we simply read from the in-memory future cache, considerably increasing parallelism.
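
To make the "cache the future" idea concrete, here is a minimal, self-contained sketch. The names (CompiledKernelCache, compile_kernel) and the keying scheme are illustrative stand-ins, not the actual inductor/async_compile internals; the point is only the shape of the pattern: codegen warms the cache, and the later call made while loading the generated module reuses the pending future instead of starting a new compile.

```python
# Illustrative sketch only: a process pool compiles kernel source in workers, and an
# in-memory cache keyed by source holds the pending future. Codegen "warms" the cache
# early; the later call made while loading the generated module hits the same future.
from concurrent.futures import Future, ProcessPoolExecutor


def compile_kernel(source: str) -> str:
    # Stand-in for the work a worker does (e.g. the triton.compile call).
    return f"compiled({source})"


class CompiledKernelCache:
    """Hypothetical analogue of the in-memory future cache this PR builds on."""

    def __init__(self, pool: ProcessPoolExecutor) -> None:
        self._pool = pool
        self._futures: dict[str, Future] = {}

    def async_compile(self, source: str) -> Future:
        # First call (made during codegen) submits the job; later calls for the
        # same source (made while loading the generated module) reuse the future.
        if source not in self._futures:
            self._futures[source] = self._pool.submit(compile_kernel, source)
        return self._futures[source]


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as pool:
        cache = CompiledKernelCache(pool)
        fut = cache.async_compile("kernel_src_0")          # kicked off during codegen
        assert cache.async_compile("kernel_src_0") is fut  # cache hit at load time
        print(fut.result())
```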

Implementation Overview

In total, the diff does the following:

  • Converts TritonFuture to LambdaFuture, so that triton.compile is only called on worker processes
  • Now that triton.compile() isn't called on the main process, runs TritonBundler on all compiled kernels as we get them back from the workers (a simplified sketch of this worker-side flow follows the list)
  • Extends @eellison's future cache into a class, mostly as a refactor
  • Finally, calls async_compile.triton ahead of time in Scheduler.codegen if workers are warmed up, so that the subsequent async_compile.triton call that occurs after codegen hits the in-memory cache on a cold start

In the diffs after this one, I will extend CompiledTritonKernels so that TritonBundler automatically populates the in-memory cache with the existing triton kernels on a warm start, avoiding calling triton altogether in that case.

Because LambdaFutures are much faster to kick off than TritonFutures (there's no need to load from TritonCodeCache at all), the time spent kicking off these worker jobs during inductor codegen is pretty minimal.
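
A rough sketch of the first two bullets. LambdaFutureSketch, worker_compile, and bundled_kernels are illustrative names rather than the real async_compile/TritonBundler code; the assumption shown is simply that the worker does the compilation and returns a picklable result, while the main process only holds a lightweight future plus a hook that bundles kernels as they arrive.

```python
# Illustrative sketch only: the heavy compile runs in the worker, and the parent wraps
# the pending result in a lightweight "lambda future" with a post-processing hook that
# mimics bundling compiled kernels as they come back from workers.
from concurrent.futures import ProcessPoolExecutor
from typing import Callable


def worker_compile(source: str) -> dict:
    # Runs in the worker process: stand-in for the triton.compile equivalent.
    # Whatever it returns must be picklable so it can travel back to the parent.
    return {"source": source, "binary": source.encode()}


class LambdaFutureSketch:
    """Wraps a zero-argument callable that yields the finished kernel on demand."""

    def __init__(self, result_fn: Callable[[], dict]) -> None:
        self.result_fn = result_fn

    def result(self) -> dict:
        return self.result_fn()


bundled_kernels = []  # stand-in for a bundler collecting kernels from workers


def async_compile_triton(pool: ProcessPoolExecutor, source: str) -> LambdaFutureSketch:
    task = pool.submit(worker_compile, source)

    def get_result() -> dict:
        kernel = task.result()
        bundled_kernels.append(kernel)  # bundle kernels as they come back
        return kernel

    return LambdaFutureSketch(get_result)


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=1) as pool:
        fut = async_compile_triton(pool, "kernel_src_0")
        kernel = fut.result()
        print(kernel["source"], len(bundled_kernels))
```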

Differential Revision: D69123174

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @chauhang @aakhundov

Can we split the diff for easier review?

It's best if this diff lands atomically with all of these changes, because compiling ahead of time during codegen is only performant once TritonFuture is replaced with LambdaFuture (we no longer need to load the triton kernel on the main process). However, I've made a diff stack for easier reviewing:
- D69070048 - Run async_compile.triton ahead of time in Scheduler.codegen
- D68633454 - Only call triton in worker process

@pytorch-bot

pytorch-bot bot commented Feb 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146417

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit c9299aa with merge base 2a55311:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D69123174

jamesjwu added a commit that referenced this pull request Feb 4, 2025

@jamesjwu jamesjwu added the topic: not user facing label Feb 4, 2025

jamesjwu added a commit that referenced this pull request Feb 5, 2025

@jamesjwu
Contributor Author

@pytorchbot merge -f

@pytorch-bot

pytorch-bot bot commented Feb 11, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot merge: error: argument -f/--force: expected one argument

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Try @pytorchbot --help for more info.

@huydhn
Contributor

huydhn commented Feb 11, 2025

@pytorchbot merge -f 'Pending unstable ROCm jobs'

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens when two merge requests were issued for the same PR, or when the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@saagarjha
Contributor

Hi! I think this PR breaks our code. Here's a reproducer:

```python
#!/usr/bin/env python3

import torch
import torch.nn.attention.flex_attention

torch.set_default_device("cuda")

N_CTX = 4096
SLIDING_WINDOW = 128

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx < SLIDING_WINDOW
    return causal_mask & window_mask

def rand_qkv(n_batch: int, n_head: int, n_ctx: int, d_qk: int, d_v: int):
    qk_shape = (n_batch, n_head, n_ctx, d_qk)
    v_shape = (n_batch, n_head, n_ctx, d_v)  # use d_v for the value tensor
    return (torch.randn(qk_shape), torch.randn(qk_shape), torch.randn(v_shape))

n_batch = 1
n_head = 1
local_bm = torch.nn.attention.flex_attention.create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
)

flex_attention = torch.compile(torch.nn.attention.flex_attention.flex_attention)
flex_attention(*rand_qkv(n_batch, n_head, N_CTX, d_qk=16, d_v=16), return_lse=True, block_mask=local_bm)
```

Here is the error we get:

E0211 21:13:34.994000 1581518 subproc_pool.py:321] Error in subprocess
E0211 21:13:34.994000 1581518 subproc_pool.py:321] concurrent.futures.process._RemoteTraceback:
E0211 21:13:34.994000 1581518 subproc_pool.py:321] """
E0211 21:13:34.994000 1581518 subproc_pool.py:321] Traceback (most recent call last):
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     r = call_item.fn(*call_item.args, **call_item.kwargs)
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 340, in do_job
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     return pickler.dumps(result)
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 100, in dumps
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
E0211 21:13:34.994000 1581518 subproc_pool.py:321] AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'
E0211 21:13:34.994000 1581518 subproc_pool.py:321] """
E0211 21:13:34.994000 1581518 subproc_pool.py:321]
E0211 21:13:34.994000 1581518 subproc_pool.py:321] The above exception was the direct cause of the following exception:
E0211 21:13:34.994000 1581518 subproc_pool.py:321]
E0211 21:13:34.994000 1581518 subproc_pool.py:321] Traceback (most recent call last):
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/home/ubuntu/pytorch/torch/_inductor/compile_worker/subproc_pool.py", line 319, in callback
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     result = future.result()
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     return self.__get_result()
E0211 21:13:34.994000 1581518 subproc_pool.py:321]   File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
E0211 21:13:34.994000 1581518 subproc_pool.py:321]     raise self._exception
E0211 21:13:34.994000 1581518 subproc_pool.py:321] AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'
W0211 21:13:34.996000 1581373 pytorch/torch/_inductor/utils.py:875] [0/0] on error, temporary cache dir kept at /tmp/torchinductor_ubuntu/tmpkwuio_wu
Traceback (most recent call last):
  File "/home/ubuntu/./test.py", line 28, in <module>
    flex_attention(*rand_qkv(n_batch, n_head, N_CTX, d_qk=16, d_v=16), return_lse=True, block_mask=local_bm)
  File "/home/ubuntu/pytorch/torch/_dynamo/eval_frame.py", line 574, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/ubuntu/pytorch/torch/_dynamo/output_graph.py", line 1487, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/ubuntu/pytorch/torch/_dynamo/output_graph.py", line 1466, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/ubuntu/pytorch/torch/_dynamo/repro/after_dynamo.py", line 131, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/ubuntu/pytorch/torch/__init__.py", line 2339, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 2163, in compile_fx
    return aot_autograd(
  File "/home/ubuntu/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 1168, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 1143, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/ubuntu/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 205, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/home/ubuntu/pytorch/torch/_functorch/aot_autograd.py", line 479, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 2038, in fw_compiler_base
    return inner_compile(
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 623, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/ubuntu/pytorch/torch/_dynamo/repro/after_aot.py", line 104, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 727, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 1402, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/ubuntu/pytorch/torch/_inductor/compile_fx.py", line 1122, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
  File "/home/ubuntu/pytorch/torch/_inductor/graph.py", line 1990, in compile_to_module
    return self._compile_to_module()
  File "/home/ubuntu/pytorch/torch/_inductor/graph.py", line 2032, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/home/ubuntu/pytorch/torch/_inductor/codecache.py", line 2758, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/home/ubuntu/pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_ubuntu/tmpkwuio_wu/2c/c2cwsb3k4rlb6akooercw4u4bjrnkofn6xx5cavzkj2swf2iyiii.py", line 552, in <module>
    async_compile.wait(globals())
  File "/home/ubuntu/pytorch/torch/_inductor/async_compile.py", line 421, in wait
    scope[key] = result.result()
  File "/home/ubuntu/pytorch/torch/_inductor/codecache.py", line 3237, in result
    return self.result_fn()
  File "/home/ubuntu/pytorch/torch/_inductor/async_compile.py", line 311, in get_result
    kernel = task.result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: Can't pickle local object 'JITFunction.__init__.<locals>.<lambda>'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

We did find that sometimes the function gets cached, and after that we don't see the bug, so you might want to run the reproducer with TORCHINDUCTOR_FORCE_DISABLE_CACHES=1.
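
One way to force that from inside the script, as a sketch (assuming the reproducer above is saved as a standalone file; TORCHINDUCTOR_FORCE_DISABLE_CACHES is the same environment variable mentioned above), is to set the variable before torch is imported:

```python
# Sketch: disable inductor caches for this process so the failure reproduces every run.
# Setting the variable before importing torch ensures inductor's config picks it up.
import os

os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"

import torch  # the rest of the reproducer follows unchanged from here
```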

@jamesjwu
Contributor Author

Ah hmm, taking a look now!

@oulgen
Contributor

oulgen commented Feb 11, 2025

@saagarjha what’s your triton version? Are you using triton from PyTorch’s bundle?

@jamesjwu
Contributor Author

@saagarjha I've run this test with Triton 3.2 (the version pinned for PyTorch's nightly bundle) and TORCHINDUCTOR_FORCE_DISABLE_CACHES=1, and cannot reproduce the issue. I also tried Triton 3.1, but it still doesn't seem to reproduce. Could you check your triton version for me?

@saagarjha
Contributor

Ah, sorry for not mentioning it. I’m using the latest PyTorch/Triton commits (that is, not the pinned version).

@jamesjwu
Contributor Author

Ah, it's possible that there are specific fields on JITFunction in bleeding-edge triton that aren't compatible with this PR. I'll look into making the pickling setup a bit less brittle to those cases. As a workaround, you could use the pinned triton version; I'll also file an issue for this, since it would block us from updating past the current triton version.
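
To illustrate the failure mode, here is a sketch under made-up names (KernelLike and TolerantWrapper are not the triton or inductor classes): an object that stores a lambda created in its __init__ cannot be pickled, which is exactly the "Can't pickle local object" error above, and one generic way to be less brittle is to drop non-picklable attributes before sending state across the worker boundary.

```python
# Illustrative only: why a lambda stored on an object breaks pickling across the
# worker boundary, and one generic way a wrapper can shed such fields.
import pickle


class KernelLike:
    """Made-up stand-in for an object that stashes a lambda in __init__."""

    def __init__(self) -> None:
        self.name = "kernel_0"
        self.repr = lambda _: self.name  # defined locally -> not picklable


class TolerantWrapper:
    """Carries a kernel-like object but sheds attributes pickle can't handle."""

    def __init__(self, obj) -> None:
        self.obj = obj

    def __getstate__(self):
        state = dict(self.obj.__dict__)
        for key, value in list(state.items()):
            try:
                pickle.dumps(value)
            except Exception:
                state.pop(key)  # e.g. KernelLike.__init__.<locals>.<lambda>
        return state

    def __setstate__(self, state) -> None:
        self.obj = state  # only the picklable parts survive the round trip


try:
    pickle.dumps(KernelLike())
except Exception as exc:
    print("direct pickling fails:", exc)

roundtripped = pickle.loads(pickle.dumps(TolerantWrapper(KernelLike())))
print("wrapper survives with:", roundtripped.obj)
```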

@saagarjha
Contributor

Yep, that sounds reasonable. We're not actually expecting you to track Triton main, so we do have workarounds :) Just thought I'd let you know early so you can fix this before bumping your pinned version.

pytorchmergebot pushed a commit that referenced this pull request Feb 12, 2025
Fix a bug introduced by D69123174: because triton kernels are now returned directly by the worker, each future created for a triton kernel should only be used once per compile. Otherwise, a long-running process that does something like the following:

```
compiled_1 = torch.compile(fn1, mode="max-autotune", fullgraph=True)
# run compiled_1 (`inputs` stands for whatever arguments fn1 takes)
out_compiled = compiled_1(*inputs)
compiled_2 = torch.compile(fn2, mode="max-autotune", fullgraph=True)
```

If fn1 and fn2 are very similar (i.e. they would generate the same triton kernel source code), this would result in us using the launcher from the first autotuning run, setting the launcher to None after running, and then reusing the same future/kernel without regenerating the launcher.
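
A hedged sketch of the use-once discipline this fix enforces (illustrative names only, not the real CompiledTritonKernels or inductor code): the in-memory cache of futures is cleared at the end of each compile, so a second torch.compile that happens to generate identical kernel source gets a fresh future and a fresh launcher.

```python
# Illustrative sketch only: futures are valid for a single compile, so the cache is
# dropped when a compile finishes instead of being reused across torch.compile calls.
from concurrent.futures import Future


class CompiledKernelsSketch:
    """Illustrative per-compile cache of kernel futures (not the real class)."""

    _cache: dict[str, Future] = {}

    @classmethod
    def get_or_start(cls, source: str, start_fn) -> Future:
        if source not in cls._cache:
            cls._cache[source] = start_fn(source)
        return cls._cache[source]

    @classmethod
    def end_compile(cls) -> None:
        # Once a kernel's launcher has been consumed, reusing its future would hand
        # out a kernel whose launcher is already gone, so drop everything here.
        cls._cache.clear()


def start_compile(source: str) -> Future:
    fut: Future = Future()
    fut.set_result({"source": source, "launcher": object()})
    return fut


first = CompiledKernelsSketch.get_or_start("same_kernel_src", start_compile)
CompiledKernelsSketch.end_compile()
second = CompiledKernelsSketch.get_or_start("same_kernel_src", start_compile)
assert first is not second  # identical source still gets a fresh future per compile
```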

Found this bug testing internal inference models.

This does not remove @eellison's caching for prologue benchmarking, because that happens under the same compile: #143408

Differential Revision: [D69476856](https://our.internmc.facebook.com/intern/diff/D69476856/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D69476856/)!

Pull Request resolved: #146925
Approved by: https://github.com/laithsakka, https://github.com/jansel
ghstack dependencies: #146417
Jokeren pushed a commit to triton-lang/triton that referenced this pull request Feb 13, 2025
PyTorch issue: pytorch/pytorch#146945

Functionality in PyTorch that started relying on serializability of
`JITFunction`: pytorch/pytorch#146417

I suppose there are different ways to solve this problem, but at least
the current lambdas are not necessary and can be easily rewritten.

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Feb 24, 2025
desai0007 pushed a commit to desai0007/test-repo-pytorch that referenced this pull request Feb 26, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
@github-actions github-actions bot deleted the gh/jamesjwu/106/head branch March 14, 2025 02:08