Clear CompiledTritonKernel cache after each inductor compile by jamesjwu · Pull Request #146925 · pytorch/pytorch · GitHub

Conversation

@jamesjwu
Contributor

@jamesjwu jamesjwu commented Feb 11, 2025

Stack from ghstack (oldest at bottom):

Fix a bug introduced by D69123174: because Triton kernels are now returned directly by the worker, each future created for a Triton kernel should only be used once per compile. Otherwise, a long-running process that does something like this:

```
compiled_1 = torch.compile(fn1, mode="max-autotune", fullgraph=True)
out_compiled = compiled_1(*inputs)  # run compiled_1
compiled_2 = torch.compile(fn2, mode="max-autotune", fullgraph=True)
```

where fn1 and fn2 are very similar (i.e., they would generate the same Triton kernel source code) would end up reusing the launcher from the first autotuning run: the launcher is set to None after running, so the same future/kernel would be used again without regenerating the launcher.
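
To make the failure mode concrete, here is a minimal toy sketch (illustrative names only, not Inductor's actual classes) of how a source-keyed kernel cache combined with a single-use launcher goes wrong:

```
class FakeKernel:
    """Toy stand-in for a compiled Triton kernel (hypothetical class)."""

    def __init__(self):
        self.launcher = lambda: "ran"

    def run(self):
        out = self.launcher()
        # Autotuning consumes the launcher: it is cleared after the first run.
        self.launcher = None
        return out


_kernel_cache: dict[str, FakeKernel] = {}  # keyed by generated Triton source


def get_kernel(source: str) -> FakeKernel:
    # BUG: without clearing this cache between compiles, a later compile of
    # identical source gets back a kernel whose launcher is already None.
    if source not in _kernel_cache:
        _kernel_cache[source] = FakeKernel()
    return _kernel_cache[source]


k1 = get_kernel("triton_src")
k1.run()                       # first compile: works, launcher is now None
k2 = get_kernel("triton_src")  # second compile of a similar fn: same object!
# k2.run() would raise TypeError: 'NoneType' object is not callable
```

Clearing the cache at the end of each inductor compile keeps each future scoped to the compile that created it.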

Found this bug testing internal inference models.

This does not remove @eellison's caching for prologue benchmarking, because that caching happens within the same compile: #143408

Differential Revision: D69476856

NOTE FOR REVIEWERS: This PR has internal Meta-specific changes or comments, please review them on Phabricator!

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Feb 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146925

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 5 Pending, 4 Unrelated Failures

As of commit 6cf05de with merge base 30cbf13:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D69476856

jamesjwu added a commit that referenced this pull request Feb 11, 2025
@jamesjwu jamesjwu added topic: not user facing topic category ciflow/trunk Trigger trunk jobs on your pull request labels Feb 11, 2025
Contributor

@laithsakka laithsakka left a comment


Also, if you can follow up with a unit test, that would be great.

```
compiled_graph.post_compile(example_inputs, cudagraphs, constants)

log.debug("FX codegen and compilation took %.3fs", time.time() - start)
# Clear Compiled Triton Kernels per inductor compile
```

Can you add a comment explaining why this is important? Emphasize that it must be done.
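
For illustration only, the requested comment plus the clearing call might look roughly like this; the `CompiledTritonKernels.cache_clear()` name and its import path are assumptions inferred from the PR title, not a verified API:

```
from torch._inductor.async_compile import CompiledTritonKernels  # assumed location

# Clear the compiled Triton kernel cache once per inductor compile. Kernel
# futures returned by the compile worker are single-use: the first autotuning
# run consumes the launcher and sets it to None, so a kernel cached here must
# never leak into a later compile that generates identical source.
CompiledTritonKernels.cache_clear()
```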

@jamesjwu
Contributor Author

The merge base is too old here; will rebase.

@jamesjwu
Contributor Author

Having a bit of trouble isolating the exact model that breaks when this happens (it's not just a simple add_mm or something, but a complicated internal model). I may land this and then add a unit test.
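
One possible shape for such a test (a sketch only, not the test that eventually landed; it assumes a CUDA device and two functions that lower to identical Triton source):

```
import torch


def fn1(a, b):
    return (a @ b).relu()


def fn2(a, b):  # intentionally lowers to the same Triton source as fn1
    return (a @ b).relu()


a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")

compiled_1 = torch.compile(fn1, mode="max-autotune", fullgraph=True)
compiled_1(a, b)  # the first run consumes the kernel futures' launchers

# Before the fix, this second compile could pick up the stale cached kernel
# (same generated source) whose launcher had already been set to None.
compiled_2 = torch.compile(fn2, mode="max-autotune", fullgraph=True)
compiled_2(a, b)
```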

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D69476856

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D69476856

jamesjwu added a commit that referenced this pull request Feb 11, 2025
@laithsakka
Contributor

> Having a bit of trouble isolating the exact model that breaks when this happens (it's not just a simple add_mm or something, but a complicated internal model). I may land this and then add a unit test.

If it's not too much trouble; I figure that sometimes it's hard to create unit test repros.

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 job has failed: inductor / unit-test / cuda12.4-py3.13-gcc9-sm86 / test (inductor, 1, 2, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team: raised by workflow job.

@huydhn
Contributor

huydhn commented Feb 12, 2025

@pytorchbot merge -i

@huydhn
Contributor

huydhn commented Feb 12, 2025

@pytorchbot merge -f 'Bypass ROCm unstable jobs'

1 similar comment

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; instead, consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions github-actions bot deleted the gh/jamesjwu/109/head branch March 23, 2025 02:17
