[inductor] multi-gpu max autotuning by shunting314 · Pull Request #96807 · pytorch/pytorch · GitHub
Closed
6 changes: 5 additions & 1 deletion test/inductor/test_max_autotune.py
@@ -124,7 +124,11 @@ def mm_plus_mm(a, b, c, d):
d = torch.randn(k, n).cuda()

with config.patch(
{"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
{
"max_autotune": True,
"autotune_in_subproc": autotune_in_subproc,
"ignore_max_autotune_cache": True,
}
):
torch.compile(mm_plus_mm)(a, b, c, d)

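For context, the test opts into the new ignore_max_autotune_cache flag so a previous autotune run cannot satisfy the lookup from cache, and the subprocess benchmarking path is always exercised. A minimal standalone sketch of the same pattern (the mm function and the shapes below are illustrative, not taken from the PR):

import torch
from torch._inductor import config

def mm(a, b):
    return a @ b

a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")

with config.patch(
    {
        "max_autotune": True,
        "autotune_in_subproc": True,
        # force a cache miss so every candidate kernel is re-benchmarked
        "ignore_max_autotune_cache": True,
    }
):
    torch.compile(mm)(a, b)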
174 changes: 122 additions & 52 deletions torch/_inductor/autotune_process.py
@@ -1,22 +1,21 @@
import atexit
import dataclasses
import queue
import time
import warnings
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import torch
from torch import multiprocessing
from torch._dynamo.testing import rand_strided

from torch._inductor import ir
from torch._inductor import config, ir
from torch._inductor.codecache import PyCodeCache

from torch._inductor.select_algorithm import ChoiceCaller
from .utils import do_bench
from .virtualized import V

DEBUG = False
EXIT_HANDLER_REGISTERED = False

# Used to synchronize between parent and child processes
class Ping:
@@ -32,10 +31,12 @@ class TuningProcess:
process: multiprocessing.Process = None
request_queue: multiprocessing.Queue = None
response_queue: multiprocessing.Queue = None
dev_id: Optional[int] = None

@staticmethod
def process_main(request_queue, response_queue):
print("enter child process main")
def process_main(request_queue, response_queue, dev_id):
print(f"enter child process main for dev {dev_id}")
torch.cuda.set_device(dev_id)
while True:
obj = request_queue.get()

@@ -58,13 +59,17 @@ def valid(self):
def clear(self):
self.process = self.request_queue = self.response_queue = None

def initialize(self):
def initialize(self, dev_id=None):
"""
Create child process, request/response queues and do the warm up.
"""
if self.valid():
return

if dev_id is not None:
self.dev_id = dev_id

assert self.dev_id is not None and self.dev_id >= 0
# cuda runtime does not work with "fork", use "spawn" to start processes.
ctx = multiprocessing.get_context("spawn")
self.request_queue = ctx.Queue()
@@ -75,19 +80,11 @@ def initialize(self):
args=(
self.request_queue,
self.response_queue,
self.dev_id,
),
)
self.process.start()

# register the exit handler for the parent process so it will terminate
# the child processes
global EXIT_HANDLER_REGISTERED
if not EXIT_HANDLER_REGISTERED:
EXIT_HANDLER_REGISTERED = True
import atexit

atexit.register(lambda: self.terminate())

# wait for the initialization to be done
self.request_queue.put(Ping())
resp = self.response_queue.get()
@@ -98,8 +95,58 @@ def terminate(self):
self.request_queue.put(None)
self.process.join()

def __hash__(self):
return id(self)


class TuningProcessPool:
"""
Tuning process pool maintaining one process for each GPU. Crashed processes
are recreated.
"""

def __init__(self):
self.avail_procs = set()
self.all_procs = []

# register the exit handler for the parent process so it will terminate
# the child processes
atexit.register(lambda: self.teardown())

def initialize(self):
"""
Not putting this in __init__ so we don't need to create subprocesses when
importing the module.
"""
if len(self.all_procs) > 0: # already initialized
return

ngpu = torch.cuda.device_count()
assert ngpu > 0
print(f"Createing {ngpu} autotuning sub process one for each GPU")
self.all_procs = [TuningProcess() for _ in range(ngpu)]
for dev_id in range(ngpu):
self.all_procs[dev_id].initialize(dev_id)

self.avail_procs = set(self.all_procs)

def has_avail_proc(self):
return len(self.avail_procs) > 0

def allocate_proc(self):
assert self.has_avail_proc()
return self.avail_procs.pop()

def return_proc(self, proc):
self.avail_procs.add(proc)

tuning_process = TuningProcess()
def teardown(self):
for proc in self.all_procs:
proc.terminate()


if config.autotune_in_subproc:
tuning_process_pool = TuningProcessPool()
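The pool above keeps one long-lived child process per visible GPU, created with the "spawn" start method and pinned to its device before it does any CUDA work. A self-contained sketch of that underlying pattern, separate from the PR's classes (all names here are illustrative):

import torch
from torch import multiprocessing

def _bench_worker(dev_id, request_queue, response_queue):
    # Pin this child to its own GPU before any CUDA work happens.
    torch.cuda.set_device(dev_id)
    while True:
        req = request_queue.get()
        if req is None:  # shutdown sentinel
            break
        # req is assumed to be a picklable callable that runs one benchmark
        response_queue.put(req())

if __name__ == "__main__":
    # The CUDA runtime does not survive fork(), hence the "spawn" context.
    ctx = multiprocessing.get_context("spawn")
    workers = []
    for dev_id in range(torch.cuda.device_count()):
        req_q, resp_q = ctx.Queue(), ctx.Queue()
        proc = ctx.Process(target=_bench_worker, args=(dev_id, req_q, resp_q))
        proc.start()
        workers.append((proc, req_q, resp_q))
    for proc, req_q, _ in workers:
        req_q.put(None)  # ask each worker to exit
        proc.join()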


@dataclasses.dataclass
@@ -131,7 +178,8 @@ def to_tensor(self) -> torch.Tensor:
return rand_strided(
self.sizes,
self.strides,
device=self.device,
# don't use self.device since we may benchmark on a different GPU
device="cuda",
dtype=self.dtype,
extra_size=self.offset,
)
@@ -190,42 +238,64 @@ def worker():
if DEBUG:
bench_elapse = time.time() - start_ts
print(
f"InChidProcess {self.module_cache_key}: load {load_elapse}, "
f"InChidProcess-{torch.cuda.current_device()} {self.module_cache_key}: load {load_elapse}, "
+ f"create tensor {create_tensor_elapse}, bench {bench_elapse}"
)
return out


def benchmark_in_sub_process(
choice: ChoiceCaller,
) -> float:
"""
Do benchmarking in subprocess and return the perf number (latency).
"""
assert choice.bmreq is not None
tuning_process.initialize()
assert tuning_process.valid()

tuning_process.request_queue.put(choice.bmreq)

while True:
try:
timing = tuning_process.response_queue.get(timeout=1.0)
except queue.Empty:
status = tuning_process.process.exitcode
if status is None:
# child process is still running
continue
# child process fail
assert status != 0

warnings.warn(
f"Fail to benchmark choice '{choice}'. It will be ignored. Please debug the root cause in case the choice can bring perf gains." # noqa: B950 line too long
)

tuning_process.clear()

# return a large value so this choice will be ignored
return float("inf")

return timing
def benchmark_in_sub_process(choices):
timings = {}
if len(choices) == 0:
return timings

if DEBUG:
print(f"Tuning {len(choices)} choices in sub processes")

pending_tasks = {} # map choice to proc
reqlist = [choice.bmreq for choice in choices]
nextreqidx = 0
while nextreqidx < len(reqlist) or len(pending_tasks) > 0:
while nextreqidx < len(reqlist) and tuning_process_pool.has_avail_proc():
proc = tuning_process_pool.allocate_proc()
bmreq = reqlist[nextreqidx]

proc.request_queue.put(bmreq)
pending_tasks[choices[nextreqidx]] = proc
nextreqidx += 1

for choice, proc in pending_tasks.items():
try:
# small timeout so the parent process does not get stuck for too long if
# the child process is still busy doing its work.
timing = proc.response_queue.get(timeout=0.001)
except queue.Empty:
status = proc.process.exitcode
if status is None:
# still running
continue
# otherwise a crash happens
assert (
status != 0
), f"Child process should be crashed but get status code {status}"
warnings.warn(
f"Fail to benchmark choice '{choice}'. It will be ignored. Please debug the root cause in case the choice can bring perf gains." # noqa: B950 line too long
)

timing = float("inf")

# must reinitialize proc
proc.clear()
proc.initialize()

# fall through

timings[choice] = timing
tuning_process_pool.return_proc(proc)
pending_tasks = {
choice: proc
for choice, proc in pending_tasks.items()
if choice not in timings
}

return timings
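Note the interface change: the old benchmark_in_sub_process took a single ChoiceCaller and returned one latency, while the new version takes the whole list of choices, farms the requests out across the per-GPU processes, and returns a dict mapping each choice to its timing, with float("inf") marking a choice whose child process crashed. A hedged sketch of a call site (hypothetical, assuming config.autotune_in_subproc was enabled so the module-level pool exists and has been initialized):

from torch._inductor.autotune_process import benchmark_in_sub_process

def pick_best(choices):
    # choices: list of ChoiceCaller objects with .bmreq populated
    timings = benchmark_in_sub_process(choices)
    # crashed choices are reported as float("inf"), so they can never win
    return min(timings, key=timings.get)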
4 changes: 3 additions & 1 deletion torch/_inductor/codecache.py
@@ -130,6 +130,8 @@ def lookup(

def check_cache(cache, callback=None):
"""Check if `cache` contains data for all the choices"""
if config.ignore_max_autotune_cache:
return False
hit = True
for choice in choices:
choice_hash = choice.hash_key()
@@ -152,8 +154,8 @@ def check_cache(cache, callback=None):
self.get_global_cache(), callback=gc_log
):
# re-benchmark everything to try to get consistent numbers from the same machine
timings.update(benchmark(choices))
for choice in choices:
timings[choice] = benchmark(choice)
local_cache.setdefault(name, {})
local_cache[name].setdefault(inputs, {})
local_cache[name][inputs][choice.hash_key()] = timings[choice]
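With ignore_max_autotune_cache set, check_cache reports a miss for both the local and the global cache, so lookup always falls through to the batched benchmark call; the benchmark callback now receives the full list of choices at once (matching the new benchmark_in_sub_process interface) rather than being invoked once per choice. A minimal sketch of the short-circuit, assuming a cache that maps choice hash keys to timings as in the hunk above:

# simplified sketch, not the verbatim check_cache implementation
def check_cache_sketch(cache, choices, ignore_cache):
    if ignore_cache:  # config.ignore_max_autotune_cache
        # pretend nothing is cached so every choice gets re-benchmarked
        return False
    return all(choice.hash_key() in cache for choice in choices)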
4 changes: 4 additions & 0 deletions torch/_inductor/config.py
@@ -57,6 +57,10 @@
# enable slow autotuning passes to select gemm algorithms
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"

ignore_max_autotune_cache = (
os.environ.get("TORCHINDUCTOR_IGNORE_MAX_AUTOTUNE_CACHE") == "1"
)

# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

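Like the other inductor toggles in this file, the new flag is read from an environment variable when the config module is imported; it can also be flipped programmatically. A small usage sketch (only names shown in the hunk above are assumed):

import os

# must be set before torch._inductor.config is first imported,
# since the environment is read at import time
os.environ["TORCHINDUCTOR_IGNORE_MAX_AUTOTUNE_CACHE"] = "1"

from torch._inductor import config

# or flip it after import, e.g. from a test
config.ignore_max_autotune_cache = True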