
Performance regression when the GIL is not released in the cpp wrapper #123517


Description

@zhuhaozhe

🐛 Describe the bug

After #122554, the GIL is no longer released in the cpp wrapper. When ThroughputBenchmark is used to validate model performance with many calling threads, holding the GIL makes the benchmark very slow.
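
One way to check whether the compiled call is holding the GIL is to run a pure-Python counter thread alongside it. This is only a sketch (the GilWatchdog helper and the tick-rate interpretation are illustrative, not part of the original report); the commented usage assumes the model, x1 and x2 names from the repro below.

import threading
import time

class GilWatchdog:
    # Background thread that increments a counter in pure Python. The counter
    # can only advance when the foreground thread does not hold the GIL, so a
    # low tick rate while a call is running suggests the call holds the GIL
    # for its whole duration.
    def __init__(self):
        self.count = 0
        self._stop = False
        self._thread = threading.Thread(target=self._spin, daemon=True)

    def _spin(self):
        while not self._stop:
            self.count += 1

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop = True
        self._thread.join()

# Usage, assuming `model`, `x1`, `x2` from the repro below:
# with GilWatchdog() as w:
#     start = time.time()
#     for _ in range(100):
#         model(x1, x2)
#     elapsed = time.time() - start
# print(f"watchdog ticks/s: {w.count / elapsed:.0f}")
# A call that never releases the GIL yields a far lower tick rate than an
# eager-mode baseline measured the same way.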


Minified repro

# bench_gil.py: minified repro for the GIL-related slowdown with cpp_wrapper
import torch
from torch._inductor import config as inductor_config
from torch.utils import ThroughputBenchmark

# Enable the Inductor C++ wrapper codegen (the code path that stops releasing the GIL).
inductor_config.cpp_wrapper = True


class SimpleM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(100, 100)
        self.linear2 = torch.nn.Linear(100, 100)

    def forward(self, x, y):
        return self.linear1(x) + self.linear1(y)


model = torch.compile(SimpleM().bfloat16())
x1 = torch.randn(100, 100).bfloat16()
x2 = torch.randn(100, 100).bfloat16()

# Warm up: the first calls trigger compilation of the cpp-wrapper code.
with torch.no_grad():
    y = model(x1, x2)
    y = model(x1, x2)

# 24 calling threads share the compiled module; if the compiled call holds
# the GIL, the threads serialize and the reported latency regresses.
bench = ThroughputBenchmark(model)
bench.add_input(x1, x2)
with torch.no_grad():
    stats = bench.benchmark(
        num_calling_threads=24,
        num_warmup_iters=100,
        num_iters=2400,
    )
print(stats)
TORCH_COMPILE_DEBUG=1 OMP_NUM_THREADS=1 numactl -C 0-23 -m 0 python bench_gil.py

Output

The benchmarked latency regresses from about 3 ms to about 6.8 ms on my system.
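
For reference, a similar slowdown can be observed without ThroughputBenchmark by driving the compiled model from plain Python threads. The sketch below is an illustration rather than part of the report: it reuses model, x1 and x2 from the repro above, and the thread and iteration counts are arbitrary. With the GIL held for the whole compiled call the threads serialize, so wall time grows roughly linearly with the thread count; a call that releases the GIL lets the threads overlap across cores (given OMP_NUM_THREADS=1 and enough cores).

import time
from concurrent.futures import ThreadPoolExecutor

import torch  # the repro above has already built the compiled `model`

def worker(n_iters):
    # Each thread repeatedly calls the shared compiled module.
    with torch.no_grad():
        for _ in range(n_iters):
            model(x1, x2)

for n_threads in (1, 4, 24):
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        futures = [pool.submit(worker, 100) for _ in range(n_threads)]
        for f in futures:
            f.result()
    elapsed = time.perf_counter() - start
    # Serialized threads make this grow ~linearly with n_threads;
    # overlapping threads keep it roughly flat.
    print(f"{n_threads:>2} threads x 100 iters: {elapsed * 1e3:8.1f} ms wall time")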

Versions

The regression appears after commit 537cd66#diff-050ffbf46a890cee19edf85f989025805a1b5b20b26dc5cbe719f9539051b6bb.

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @desertfire @chenyang78

Labels

module: aotinductor, oncall: pt2, triaged