Parallelize epilogue/prologue benchmarking by eellison · Pull Request #143408 · pytorch/pytorch · GitHub

Conversation

@eellison
Contributor

@eellison eellison commented Dec 17, 2024

Stack from ghstack (oldest at bottom):

When we attempt prologue or epilogue fusion with a TritonTemplate, we benchmark the fused kernel at compile time in order to determine profitability. This avoids slowdowns and register spilling, and allows us to pick a fusion when the base Triton template is slower than cuBLAS but faster once the epilogue is considered. However, that fused benchmarking does not do the same async compilation as we do for the base TritonTemplate: the base TritonTemplate is async-compiled during lowering, then waited on and benchmarked later.

This PR extends a similar process to benchmarking fused TritonTemplates in the scheduler. We keep a list of pending fusions that have async compilations in flight, and we resolve any pending fusions a node participates in before attempting to fuse it with any other node.
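To make the mechanics concrete, here is a minimal sketch of that pending-fusion bookkeeping, assuming a thread-pool-based async compile; the names (`PendingFusion`, `resolve_pending`, `benchmark`) are illustrative and do not correspond to Inductor's actual internals:

```
# Minimal sketch only -- not Inductor's actual scheduler code.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable

pool = ThreadPoolExecutor()

class PendingFusion:
    def __init__(self, nodes: frozenset, compile_fn: Callable[[], Any]):
        self.nodes = nodes
        # Kick off compilation of the fused kernel without blocking fusion search.
        self.future: Future = pool.submit(compile_fn)

pending_fusions: list[PendingFusion] = []

def benchmark(nodes: frozenset, kernel: Any) -> None:
    # Stand-in for actually timing the compiled fused kernel.
    print(f"benchmarked fused kernel covering {len(nodes)} nodes")

def resolve_pending(node: Any) -> None:
    """Before fusing `node` with any other node, wait on and benchmark any
    pending fused compilation it participates in."""
    still_pending = []
    for pf in pending_fusions:
        if node in pf.nodes:
            kernel = pf.future.result()  # wait for the async compile
            benchmark(pf.nodes, kernel)  # record timing for the fusion decision
        else:
            still_pending.append(pf)
    pending_fusions[:] = still_pending
```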

Initially, I saw some slowdowns with this because we kick off async compilations of identical fusions in parallel. To address this, I added source-code caching at the async_compile level (we also already cache benchmark runs, but that caching would not deduplicate identical compilations running in parallel).
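As a rough sketch of that deduplication, assuming we key on the generated kernel source (the names here are hypothetical, not the actual async_compile API):

```
# Hedged sketch of source-level deduplication for async compiles.
from concurrent.futures import Future, ThreadPoolExecutor

pool = ThreadPoolExecutor()
_inflight: dict[str, Future] = {}

def compile_kernel(src: str):
    # Stand-in for the real (expensive) Triton compilation.
    return compile(src, "<generated-kernel>", "exec")

def async_compile_cached(src: str) -> Future:
    # Keyed on the generated source code: identical fusions kicked off in
    # parallel share a single future, so the kernel is compiled only once.
    fut = _inflight.get(src)
    if fut is None:
        fut = pool.submit(compile_kernel, src)
        _inflight[src] = fut
    return fut
```

A real implementation would also guard the cache with a lock for thread safety.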

Compilation speedups:

[image: compilation speedup results]

This should also let us be more aggressive with configs, or with benchmarking other fusions whose profitability is hard to determine.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Dec 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143408

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 04386ff with merge base c986eba:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Dec 17, 2024
ghstack-source-id: f3c2436
Pull Request resolved: #143408
eellison added a commit that referenced this pull request Dec 18, 2024
ghstack-source-id: b43c825
Pull Request resolved: #143408
eellison added a commit that referenced this pull request Dec 18, 2024
ghstack-source-id: 5c85cf5
Pull Request resolved: #143408
eellison added a commit that referenced this pull request Dec 27, 2024
ghstack-source-id: 3d0ea3f
Pull Request resolved: #143408
@eellison eellison marked this pull request as ready for review January 7, 2025 05:18
@eellison eellison changed the title [WIP] parallelize epilogue/prologue benchmarking Parallelize epilogue/prologue benchmarking Jan 7, 2025
@eellison
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 28, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jan 28, 2025
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

Previously, we would finalize the config of a Triton template after its first fusion. This change maintains multiple configs, in case we epilogue-fuse, then prologue-fuse, and prologue fusion has a new, better config.

Pull Request resolved: #145103
Approved by: https://github.com/jansel, https://github.com/shunting314
ghstack dependencies: #143408
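As a hedged illustration of that follow-up's idea (not the actual #145103 implementation), one way to keep the decision open is to retain every timed candidate config and re-pick the winner after each fusion:

```
# Illustrative sketch only: keep all timed candidate configs alive so a
# later prologue fusion can change which config wins.
import random

def time_with_config(fused_src: str, cfg: str) -> float:
    # Stand-in for compiling `fused_src` with `cfg` and timing it.
    return random.random()

candidate_timings: dict[str, float] = {}

def rebenchmark(fused_src: str, configs: list[str]) -> str:
    """Re-time all candidate configs for the current fused kernel and
    return the fastest one."""
    for cfg in configs:
        candidate_timings[cfg] = time_with_config(fused_src, cfg)
    return min(candidate_timings, key=candidate_timings.__getitem__)
```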
jamesjwu added a commit that referenced this pull request Feb 11, 2025
pytorchmergebot pushed a commit that referenced this pull request Feb 12, 2025
Fix a bug introduced by D69123174: because Triton kernels are now returned directly by the worker, each future created for a Triton kernel should only be used once per compile. Otherwise, a long-running process that does something like:

```
compiled_1 = torch.compile(fn, mode="max-autotune", fullgraph=True)
out_compiled = compiled_1(inputs)  # run compiled_1
compiled_2 = torch.compile(fn2, mode="max-autotune", fullgraph=True)
```

where fn and fn2 are very similar (i.e., they would generate the same Triton kernel source code), would use the launcher for the first autotuning run, set the launcher to None after running, and then reuse the same future/kernel without regenerating the launcher.

Found this bug testing internal inference models.

This does not break @eellison's caching for prologue benchmarking, because that caching happens within the same compile: #143408
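A minimal sketch of the use-once contract described above, with hypothetical names rather than the actual fix:

```
# Illustrative sketch only: a compile-result future whose kernel drops its
# launcher after the first run must not be handed out to a later compile.
from typing import Any

class KernelFuture:
    def __init__(self, kernel: Any):
        self._kernel = kernel
        self._consumed = False

    def result(self) -> Any:
        if self._consumed:
            # The first consumer already ran autotuning and set the kernel's
            # launcher to None; reuse would hand back a broken kernel, so a
            # later compile must regenerate/recompile instead.
            raise RuntimeError("kernel future already consumed; recompile")
        self._consumed = True
        return self._kernel
```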

Differential Revision: [D69476856](https://our.internmc.facebook.com/intern/diff/D69476856/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D69476856/)!

Pull Request resolved: #146925
Approved by: https://github.com/laithsakka, https://github.com/jansel
ghstack dependencies: #146417
desai0007 pushed a commit to desai0007/test-repo-pytorch that referenced this pull request Feb 26, 2025
@github-actions github-actions bot deleted the gh/eellison/745/head branch February 28, 2025 02:12