[inductor] Parallelize Max Autotune step 1: refactor autotune_process by masnesral · Pull Request #109126 · pytorch/pytorch · GitHub

Conversation


@masnesral masnesral commented Sep 12, 2023

Stack from ghstack (oldest at bottom):

Summary: Step 1 in revamping subprocess autotune to support multiple GPUs. This diff just does some refactoring to autotune_process.py in order to prepare for the next diff:

  • Move all logic for managing the sub-process (like detecting sub-process crashes) into the TuningProcess class.
  • Use log.debug statements instead of print statements
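A minimal sketch of the shape this refactor aims for (names and details here are illustrative, not the actual autotune_process.py code): the class owns the child process plus its request/response queues, and crash detection lives behind `get()` instead of at call sites. The sketch uses the POSIX `fork` start method to stay self-contained; the real code's start method may differ.

```python
import logging
import multiprocessing as mp
import queue

log = logging.getLogger(__name__)


def _benchmark_loop(request_queue, response_queue):
    # Child process: echo back a fake result for each request; a sentinel
    # of None shuts the loop down. The real worker would run benchmarks.
    while True:
        req = request_queue.get()
        if req is None:
            break
        response_queue.put(("ok", req))


class TuningProcessSketch:
    """Owns the child process and its queues; callers only see put/get."""

    def __init__(self):
        ctx = mp.get_context("fork")  # POSIX-only; keeps the sketch simple
        self.request_queue = ctx.Queue()
        self.response_queue = ctx.Queue()
        self.process = ctx.Process(
            target=_benchmark_loop,
            args=(self.request_queue, self.response_queue),
        )
        self.process.start()

    def put(self, req):
        self.request_queue.put(req)

    def get(self):
        # Detect a crashed child instead of blocking forever: poll the
        # response queue and check the child's liveness between polls.
        while True:
            try:
                return self.response_queue.get(timeout=1.0)
            except queue.Empty:
                if not self.process.is_alive():
                    log.debug(
                        "child process died with exitcode %s",
                        self.process.exitcode,
                    )
                    raise RuntimeError("tuning subprocess crashed")

    def terminate(self):
        self.request_queue.put(None)
        self.process.join()
```

With this layout, callers never touch `self.process` directly, so the crash-detection policy can change in one place when multi-GPU support lands in the next diff.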

Test Plan: python test/inductor/test_max_autotune.py

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

pytorch-bot bot commented Sep 12, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109126

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 143ebfa with merge base 264f1e7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@masnesral masnesral added the topic: not user facing topic category label Sep 12, 2023
@masnesral masnesral marked this pull request as ready for review September 12, 2023 19:22
@masnesral
Contributor Author

@shunting314 FYI: this is a redo of #107982. We had to revert that change because it didn't play well in fbcode, where everything is packaged as a .xar file and we got feedback that we can't necessarily guarantee the proper environment for a subprocess started via Popen. So this change goes back to using multiprocessing and multiprocessing queues; it just does some reorganization to make the next diff in the stack easier to review.
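To illustrate the Popen-vs-multiprocessing distinction (a toy sketch, not the PR's code): a forked multiprocessing child inherits the parent interpreter's in-memory state, whereas Popen launches a fresh interpreter whose environment, sys.path, and entry point all have to be reconstructed — exactly what can't be guaranteed inside a .xar.

```python
import multiprocessing as mp

# State established in the parent at runtime, standing in for whatever
# environment setup the packaged (.xar) parent process has done.
CONFIG = {"ready": False}


def child(q):
    # Under the POSIX "fork" start method the child is a memory copy of
    # the parent, so it sees the mutated CONFIG. A fresh interpreter
    # started via subprocess.Popen would re-import this module and see
    # only the default value.
    q.put(CONFIG["ready"])


CONFIG["ready"] = True  # parent-side setup after module import
ctx = mp.get_context("fork")  # POSIX-only
q = ctx.Queue()
p = ctx.Process(target=child, args=(q,))
p.start()
inherited = q.get()
p.join()
```

The same inheritance applies to loaded modules and interpreter configuration, which is why multiprocessing sidesteps the .xar environment problem that Popen ran into.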

pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2023
Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

Pull Request resolved: #109127
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #109126
@facebook-github-bot facebook-github-bot deleted the gh/masnesral/9/head branch September 17, 2023 14:24