[inductor] Parallelize Max Autotune step 2: Use all GPUs by masnesral · Pull Request #107983 · pytorch/pytorch · GitHub

Conversation

@masnesral
Contributor

@masnesral masnesral commented Aug 25, 2023

Stack from ghstack (oldest at bottom):

Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov
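The shape of the change, in brief: a pool of persistent benchmarking subprocesses, one per visible GPU, with benchmark requests distributed across them and the timings collected back in the parent. The sketch below illustrates only the idea; the names (`TuningPool`, `_worker_loop`, etc.) are illustrative and are not the classes actually added in this PR.

```python
from __future__ import annotations

import multiprocessing as mp
import os
from typing import Any, Callable, List


def _worker_loop(device: int, requests: mp.Queue, responses: mp.Queue) -> None:
    # Pin this child to a single GPU before any CUDA initialization happens.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
    while True:
        job = requests.get()
        if job is None:  # shutdown sentinel
            break
        try:
            responses.put(job())  # run the benchmark callable, ship back the timing
        except Exception as exc:  # surface failures instead of hanging the parent
            responses.put(exc)


class TuningPool:
    """Illustrative pool: one persistent benchmarking subprocess per GPU."""

    def __init__(self, device_count: int) -> None:
        ctx = mp.get_context("spawn")
        self.workers = []
        for device in range(device_count):
            req, resp = ctx.Queue(), ctx.Queue()
            proc = ctx.Process(target=_worker_loop, args=(device, req, resp), daemon=True)
            proc.start()
            self.workers.append((proc, req, resp))

    def benchmark(self, jobs: List[Callable[[], Any]]) -> List[Any]:
        # Round-robin the benchmark calls across the per-GPU workers, then
        # collect the results in the same order the jobs were submitted.
        for i, job in enumerate(jobs):
            self.workers[i % len(self.workers)][1].put(job)
        return [self.workers[i % len(self.workers)][2].get() for i in range(len(jobs))]

    def shutdown(self) -> None:
        for proc, req, _ in self.workers:
            req.put(None)
            proc.join()
```

Anything pushed through the request queues has to be picklable, so the jobs would be module-level callables or simple request objects rather than closures.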

@pytorch-bot

pytorch-bot bot commented Aug 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107983

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit ff6a940 with merge base 0f88d93:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

masnesral added a commit that referenced this pull request Aug 25, 2023
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

ghstack-source-id: cdf42ba
Pull Request resolved: #107983
@masnesral masnesral requested a review from eellison August 25, 2023 22:30
…l GPUs"

Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
masnesral added a commit that referenced this pull request Aug 26, 2023
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

ghstack-source-id: 8dcf2be
Pull Request resolved: #107983
Contributor

@eellison eellison left a comment

Looks like you have test failures. We might need to turn this off by default for now, or only enable it during aot_compilation. We can do that in a subsequent PR.

…une step 2: Use all GPUs"

Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
…all GPUs"

Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
masnesral added a commit that referenced this pull request Aug 29, 2023
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

ghstack-source-id: b5cfd23
Pull Request resolved: #107983
@@ -1,10 +1,15 @@
from __future__ import annotations
Contributor Author

This is to support the type hint Queue[TuningProcess] in the 3.8 builds since apparently a subscriptable Queue type was added later.
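For reference, a minimal sketch of the mechanism (TuningProcess is the class from this PR; the pool attribute is illustrative): with the `__future__` import, annotations are stored as strings rather than evaluated, so the subscripted form parses on Python 3.8 even though `queue.Queue` only gained `__class_getitem__` in 3.9.

```python
from __future__ import annotations  # PEP 563: annotations are not evaluated at runtime

from queue import Queue


class TuningProcess:  # stand-in for the real class
    ...


class TuningProcessPool:
    # Without the __future__ import, this line raises
    # "TypeError: 'type' object is not subscriptable" on Python 3.8.
    processes: Queue[TuningProcess]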

env=env,
)

# register the exit handler for the parent process so it will terminate
Contributor Author

This moves to the new TuningProcessPool class below.
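As a rough sketch of that refactor (a hedged illustration, not the PR's actual code; the placeholder worker command stands in for the real benchmarking subprocess): the pool owns all the children and registers a single exit handler on the parent, instead of registering one handler per launched subprocess.

```python
import atexit
import subprocess
import sys
from typing import List

# Placeholder child command; the real pool would launch the benchmarking
# worker entry point instead.
WORKER_CMD = [sys.executable, "-c", "import sys; sys.stdin.read()"]


class TuningProcessPool:
    """Illustrative: owns all benchmarking subprocesses and tears them down at exit."""

    def __init__(self, count: int) -> None:
        self.processes: List[subprocess.Popen] = [
            subprocess.Popen(WORKER_CMD, stdin=subprocess.PIPE) for _ in range(count)
        ]
        # One handler registered on the parent cleans up every child when the
        # parent process exits.
        atexit.register(self.terminate)

    def terminate(self) -> None:
        for p in self.processes:
            p.terminate()
            p.wait()
        self.processes.clear()
```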

print(f"{len(choices)} tuning requests:")

def benchmark_in_current_process(choice):
if DEBUG:
Contributor Author

The diff looks a little weird here. The debug_str() definition moved up from below, and this was the first few lines of the previous implementation. But start_ts was never actually used, so I removed these two lines.

torch.cuda.synchronize() # shake out any CUDA errors
return result

def benchmark_in_sub_process(choice):
Contributor Author

The diff looks a little weird here too: the benchmark_in_sub_process implementation is below.

@masnesral masnesral requested a review from eellison August 30, 2023 15:14
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

[ghstack-poisoned]
masnesral added a commit that referenced this pull request Aug 30, 2023
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

ghstack-source-id: 4698d74
Pull Request resolved: #107983
Contributor

@eellison eellison left a comment

Cool! Tagging @shunting314, who did a lot of the initial implementation.

return

count = count or torch.cuda.device_count()
assert count > 0
Contributor

We should cap count so it is no greater than torch.cuda.device_count().
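A sketch of the suggested clamping (the function name is illustrative):

```python
from typing import Optional

import torch


def resolve_worker_count(requested: Optional[int]) -> int:
    # Never launch more benchmarking subprocesses than there are visible GPUs.
    available = torch.cuda.device_count()
    count = min(requested or available, available)
    assert count > 0, "subprocess autotuning needs at least one visible CUDA device"
    return count
```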

@eellison eellison requested a review from shunting314 September 1, 2023 00:57
@shunting314
Contributor

Can you also share how much speedup we can get by leveraging multi-GPU? I actually had a similar PR some time back (https://github.com/pytorch/pytorch/pull/96807/files) extending the sub-process autotuning to multi-process/multi-GPU autotuning. We didn't push it forward at the time because people thought a global cache was more promising.


def benchmark_in_current_process(options):
timings = {}
for choice in options:
Contributor

Nit: should we uniformly call these 'choice' and 'choices'?

Contributor Author

@shunting314, sure, I'll change it. That's what I had initially, but it turns out this is a nested function and the outer method also has a param called 'choices', so I didn't love the ambiguity: mix up the two names, for example, and you'd be none the wiser because we'd silently reference the 'choices' from the outer scope. I guess this is one of the reasons I personally don't love nested functions.
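A toy example of the shadowing hazard being described (not the PR's code):

```python
def autotune(choices):
    # Inner parameter deliberately named `options` to avoid shadowing `choices`.
    def benchmark_in_current_process(options):
        # Bug: `choices` was written where `options` was meant. Python silently
        # resolves the name to the enclosing scope, so this "benchmarks" the
        # full outer list instead of the subset it was handed, with no error.
        return {choice: 0.0 for choice in choices}

    return benchmark_in_current_process(choices[:2])


print(autotune(["mm", "bmm", "addmm"]))  # 3 entries instead of the expected 2
```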

@masnesral
Contributor Author

> Can you also share how much speedup we can get by leveraging multi-GPU?

@shunting314: here's a side-by-side comparison using the hf_Bart benchmark:

							      >	+ TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1
+ TORCHINDUCTOR_MAX_AUTOTUNE=1					+ TORCHINDUCTOR_MAX_AUTOTUNE=1
+ python benchmarks/dynamo/torchbench.py --device cuda --perf	+ python benchmarks/dynamo/torchbench.py --device cuda --perf
loading model: 0it [00:09, ?it/s]				loading model: 0it [00:09, ?it/s]
cuda eval  hf_Bart						cuda eval  hf_Bart
AUTOTUNE mm(512x768, 768x768)					AUTOTUNE mm(512x768, 768x768)
  mm 0.0148 ms 100.0%					      |	  mm 0.0141 ms 100.0%
  triton_mm_8 0.0154 ms 96.5%				      |	  triton_mm_8 0.0141 ms 99.5%
  triton_mm_3 0.0164 ms 90.8%				      |	  triton_mm_3 0.0161 ms 87.3%
  triton_mm_4 0.0171 ms 86.9%				      |	  triton_mm_4 0.0163 ms 86.4%
  triton_mm_9 0.0173 ms 85.6%				      |	  triton_mm_9 0.0167 ms 84.5%
  triton_mm_6 0.0175 ms 84.7%				      |	  triton_mm_6 0.0175 ms 80.3%
  triton_mm_5 0.0181 ms 82.2%				      |	  triton_mm_5 0.0179 ms 78.5%
  triton_mm_1 0.0197 ms 75.4%				      |	  triton_mm_1 0.0203 ms 69.5%
  triton_mm_2 0.0204 ms 72.7%				      |	  triton_mm_2 0.0203 ms 69.3%
  triton_mm_0 0.0300 ms 49.6%				      |	  triton_mm_0 0.0301 ms 46.8%
SingleProcess AUTOTUNE takes 5.1835 seconds		      |	SubProcess AUTOTUNE takes 4.7125 seconds
AUTOTUNE bmm(12x512x64, 12x64x512)				AUTOTUNE bmm(12x512x64, 12x64x512)
  triton_bmm_26 0.0138 ms 100.0%			      |	  triton_bmm_31 0.0132 ms 100.0%
  triton_bmm_25 0.0138 ms 99.5%				      |	  triton_bmm_26 0.0133 ms 99.8%
  triton_bmm_28 0.0146 ms 94.1%				      |	  triton_bmm_25 0.0136 ms 97.2%
  triton_bmm_31 0.0146 ms 94.1%				      |	  triton_bmm_32 0.0137 ms 96.5%
  triton_bmm_27 0.0147 ms 93.5%				      |	  triton_bmm_28 0.0139 ms 95.2%
  triton_bmm_32 0.0147 ms 93.5%				      |	  bmm 0.0148 ms 89.8%
  bmm 0.0148 ms 93.2%					      |	  triton_bmm_27 0.0154 ms 86.1%
  triton_bmm_24 0.0151 ms 90.9%				      |	  triton_bmm_24 0.0156 ms 84.7%
  triton_bmm_34 0.0164 ms 84.0%				      |	  triton_bmm_34 0.0169 ms 78.6%
  triton_bmm_33 0.0182 ms 75.4%				      |	  triton_bmm_33 0.0173 ms 76.5%
SingleProcess AUTOTUNE takes 4.6074 seconds		      |	SubProcess AUTOTUNE takes 1.0350 seconds
AUTOTUNE bmm(12x512x512, 12x512x64)				AUTOTUNE bmm(12x512x512, 12x512x64)
  triton_bmm_51 0.0154 ms 100.0%			      |	  triton_bmm_56 0.0141 ms 100.0%
  triton_bmm_56 0.0154 ms 99.6%				      |	  triton_bmm_51 0.0147 ms 96.3%
  triton_bmm_57 0.0155 ms 99.0%				      |	  triton_bmm_57 0.0154 ms 91.7%
  triton_bmm_54 0.0161 ms 95.4%				      |	  triton_bmm_53 0.0157 ms 90.0%
  triton_bmm_53 0.0162 ms 95.0%				      |	  triton_bmm_49 0.0160 ms 88.0%
  triton_bmm_52 0.0163 ms 94.5%				      |	  triton_bmm_52 0.0162 ms 87.0%
  bmm 0.0166 ms 92.3%					      |	  triton_bmm_54 0.0163 ms 86.6%
  triton_bmm_49 0.0174 ms 88.4%				      |	  bmm 0.0166 ms 85.0%
  triton_bmm_50 0.0175 ms 87.8%				      |	  triton_bmm_50 0.0179 ms 78.7%
  triton_bmm_48 0.0244 ms 62.8%				      |	  triton_bmm_48 0.0238 ms 59.3%
SingleProcess AUTOTUNE takes 4.2848 seconds		      |	SubProcess AUTOTUNE takes 0.7750 seconds
AUTOTUNE mm(512x768, 768x3072)					AUTOTUNE mm(512x768, 768x3072)
  mm 0.0214 ms 100.0%					      |	  mm 0.0206 ms 100.0%
  triton_mm_74 0.0230 ms 92.8%				      |	  triton_mm_74 0.0230 ms 89.9%
  triton_mm_73 0.0231 ms 92.6%				      |	  triton_mm_73 0.0238 ms 86.6%
  triton_mm_76 0.0251 ms 85.1%				      |	  triton_mm_76 0.0244 ms 84.6%
  triton_mm_75 0.0252 ms 85.0%				      |	  triton_mm_75 0.0246 ms 83.8%
  triton_mm_80 0.0296 ms 72.2%				      |	  triton_mm_80 0.0292 ms 70.6%
  triton_mm_72 0.0338 ms 63.3%				      |	  triton_mm_72 0.0328 ms 62.9%
  triton_mm_79 0.0377 ms 56.7%				      |	  triton_mm_79 0.0383 ms 53.9%
  triton_mm_82 0.0420 ms 50.9%				      |	  triton_mm_82 0.0416 ms 49.6%
  triton_mm_81 0.0465 ms 46.0%				      |	  triton_mm_81 0.0456 ms 45.3%
SingleProcess AUTOTUNE takes 4.7278 seconds		      |	SubProcess AUTOTUNE takes 0.8443 seconds
AUTOTUNE mm(512x3072, 3072x768)					AUTOTUNE mm(512x3072, 3072x768)
  mm 0.0277 ms 100.0%					      |	  mm 0.0283 ms 100.0%
  triton_mm_92 0.0358 ms 77.4%				      |	  triton_mm_92 0.0363 ms 78.0%
  triton_mm_87 0.0420 ms 66.1%				      |	  triton_mm_88 0.0414 ms 68.4%
  triton_mm_88 0.0420 ms 66.1%				      |	  triton_mm_87 0.0419 ms 67.6%
  triton_mm_89 0.0440 ms 63.0%				      |	  triton_mm_89 0.0440 ms 64.3%
  triton_mm_90 0.0451 ms 61.6%				      |	  triton_mm_90 0.0445 ms 63.6%
  triton_mm_93 0.0461 ms 60.2%				      |	  triton_mm_93 0.0462 ms 61.3%
  triton_mm_86 0.0563 ms 49.3%				      |	  triton_mm_86 0.0566 ms 50.1%
  triton_mm_85 0.0567 ms 49.0%				      |	  triton_mm_85 0.0566 ms 50.0%
  triton_mm_84 0.0708 ms 39.2%				      |	  triton_mm_84 0.0693 ms 40.9%
SingleProcess AUTOTUNE takes 4.5848 seconds		      |	SubProcess AUTOTUNE takes 1.1149 seconds
AUTOTUNE mm(512x768, 768x50265)					AUTOTUNE mm(512x768, 768x50265)
  triton_mm_1585 0.2616 ms 100.0%			      |	  triton_mm_1585 0.2570 ms 100.0%
  triton_mm_1586 0.2622 ms 99.8%			      |	  triton_mm_1586 0.2654 ms 96.8%
  triton_mm_1591 0.3036 ms 86.2%			      |	  triton_mm_1591 0.3005 ms 85.5%
  triton_mm_1587 0.3041 ms 86.0%			      |	  triton_mm_1587 0.3045 ms 84.4%
  triton_mm_1588 0.3084 ms 84.8%			      |	  triton_mm_1588 0.3071 ms 83.7%
  triton_mm_1584 0.3270 ms 80.0%			      |	  triton_mm_1584 0.3268 ms 78.6%
  triton_mm_1592 0.4618 ms 56.6%			      |	  triton_mm_1592 0.4638 ms 55.4%
  triton_mm_1594 0.4884 ms 53.6%			      |	  triton_mm_1594 0.4878 ms 52.7%
  triton_mm_1593 0.5923 ms 44.2%			      |	  triton_mm_1593 0.5705 ms 45.0%
  triton_mm_1590 0.6220 ms 42.1%			      |	  triton_mm_1590 0.6091 ms 42.2%
SingleProcess AUTOTUNE takes 5.7595 seconds		      |	SubProcess AUTOTUNE takes 1.0035 seconds
running benchmark: 100%|█████████████████████████████████████ |	running benchmark: 100%|█████████████████████████████████████
2.971x							      |	2.950x

… all GPUs"

Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

[ghstack-poisoned]
@masnesral masnesral added the topic: not user facing topic category label Sep 8, 2023
@masnesral
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 8, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x c5352a8bf95ac9df4b74bbcc19461ae9e5bcc1e7 returned non-zero exit code 1

Auto-merging torch/_inductor/codecache.py
CONFLICT (content): Merge conflict in torch/_inductor/codecache.py
Auto-merging torch/_inductor/config.py
Auto-merging torch/_inductor/select_algorithm.py
error: could not apply c5352a8bf95... [inductor] Parallelize Max Autotune step 2: Use all GPUs
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
Details for Dev Infra team · Raised by workflow job

Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

[ghstack-poisoned]
masnesral added a commit that referenced this pull request Sep 8, 2023
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

ghstack-source-id: b7fb445
Pull Request resolved: #107983
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

[ghstack-poisoned]
masnesral added a commit that referenced this pull request Sep 9, 2023
Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

ghstack-source-id: 0743a2d
Pull Request resolved: #107983
@masnesral
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

# Launch the child processes and push a msg to "warm up"
self.processes = Queue()
for device in range(count):
p = TuningProcess(device=device if config.autotune_multi_device else None)
Contributor

@masnesral @eellison I believe there is a subtlety here worth revisiting.

Above we iterate over range(count) and pass the device counter variable downstream, where it is set as the value of the CUDA_VISIBLE_DEVICES env variable when launching the subprocess. Now if the user has set CUDA_VISIBLE_DEVICES=3,4,5 for the parent process, the torch.cuda.device_count() above will be 3, and we'll initialize the 3 subprocesses with CUDA_VISIBLE_DEVICES=0, =1, and =2. I guess this is not what the user would expect?

In AITemplate, we actually parse the CUDA_VISIBLE_DEVICES env variable (fetched by target.dev_select_flag() here) to detect the user-specified devices. Perhaps we should do something similar here, too? Or maybe there is a torch.cuda.<something> to detect the visible device IDs (I didn't find one from a quick check)?
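A hedged sketch of what that follow-up might look like (torch does not appear to expose the visible-device IDs directly, so this parses the environment variable, similar to the AITemplate approach mentioned above):

```python
import os

import torch


def visible_gpu_ids():
    """Return the device IDs the user actually exposed to this process."""
    raw = os.environ.get("CUDA_VISIBLE_DEVICES")
    if not raw:
        return list(range(torch.cuda.device_count()))
    # e.g. CUDA_VISIBLE_DEVICES=3,4,5 -> ["3", "4", "5"]; UUID entries pass through.
    return [d.strip() for d in raw.split(",") if d.strip()]


# With CUDA_VISIBLE_DEVICES=3,4,5 in the parent, the children would be launched
# with CUDA_VISIBLE_DEVICES=3, =4, and =5 rather than =0, =1, and =2.
for device in visible_gpu_ids():
    child_env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(device)}
    # ... launch the tuning subprocess with child_env ...
```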

Contributor Author

@aakhundov, ah sure. That sounds like a better behavior. Lemme work on a follow-up.

Comment on lines +241 to +243
process.put(choice.bmreq)
try:
return process.get()
Contributor

Should there be a timeout? Can this get stuck forever otherwise?

Contributor Author

@aakhundov, I'm not sure. Do you have a scenario in mind that would lead to a sub-process hanging? I'd love a repro to play around with. I'm new here and my naive assumption was that these kernels are fairly constrained and we wouldn't be concerned about arbitrary code. If the process crashes or exits unexpectedly, then we handle that scenario.

Contributor

I don't have an exact scenario in mind. Just, given that we're relying on STDIN-/STDOUT-based communication, I'm assuming something could go wrong there and lead to indefinitely long waiting. In AIT, we have a timeout for the whole profiling session (which, to be fair, we had to increase a couple of times due to longer e2e profiling than was initially expected).
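For illustration, a hedged sketch of a timeout guard around the response read (`put`/`get` mirror the snippet above; `terminate`/`restart` are hypothetical helpers, and the budget is arbitrary):

```python
import queue

BENCHMARK_TIMEOUT_SECONDS = 600.0  # illustrative budget, in the spirit of AIT's session timeout


def benchmark_with_timeout(process, request):
    """Send one benchmark request and wait for the timing, bailing out on a hang."""
    process.put(request)
    try:
        # multiprocessing queues accept a timeout and raise queue.Empty on expiry.
        return process.get(timeout=BENCHMARK_TIMEOUT_SECONDS)
    except queue.Empty:
        # Treat a stuck child like a crashed one: recycle it and report an
        # "infinite" timing so this choice simply loses the autotune comparison.
        process.terminate()  # hypothetical helper on the tuning-process wrapper
        process.restart()    # hypothetical helper on the tuning-process wrapper
        return float("inf")
```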

@masnesral
Contributor Author

@pytorchbot revert -m "fbcode failures"

@pytorch-bot

pytorch-bot bot commented Sep 12, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@masnesral
Contributor Author

@pytorchbot revert -m "fbcode failures" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@masnesral your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Sep 12, 2023
@facebook-github-bot facebook-github-bot deleted the gh/masnesral/3/head branch September 14, 2023 14:22
michiboo pushed a commit to michiboo/pytorch that referenced this pull request Sep 17, 2023

Summary: Step 2 in revamping subprocess autotune to support multiple GPUs: use a pool of subprocesses and distribute benchmark calls across them.

Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

Pull Request resolved: pytorch#107983
Approved by: https://github.com/eellison, https://github.com/shunting314
ghstack dependencies: pytorch#107982
