[dynamo] Slow compile times for optimizers due to for loops · Issue #110506 · pytorch/pytorch

[dynamo] Slow compile times for optimizers due to for loops #110506

@jon-chuang

Description

🐛 Describe the bug

Dynamo tracing times are:

  • 70 seconds for Adam with 200 params
  • 162 seconds for SGD with 1000 params

As identified in #110353 (comment), this is due to Dynamo needing to trace an expensive per-parameter for loop.

If this loop could instead be written in a form that is cheap to trace (e.g. by tracing a map of a single lambda over the parameters, similar to how foreach batches the optimizer main loop), we would likely speed up compilation times across all optimizers by a significant factor; see the sketch below the snippet.

Example here:

for p in group['params']:
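
For illustration only (hypothetical params/grads lists, not the actual optimizer code), the difference between the shape Dynamo currently has to trace and a foreach-style shape might look like this:

import torch

# Hypothetical stand-ins for group['params'] and their gradients.
params = [torch.randn(10, 10) for _ in range(1000)]
grads = [torch.randn(10, 10) for _ in range(1000)]
lr = 0.01

# Current shape: Dynamo unrolls the loop and re-traces the body once per
# parameter, so tracing time grows linearly with the parameter count.
for p, g in zip(params, grads):
    p.add_(g, alpha=-lr)

# foreach-style shape: a single batched op over the whole list, so the
# Python that Dynamo has to trace no longer scales with the parameter count.
torch._foreach_add_(params, grads, alpha=-lr)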

Example Dynamo logs (the same few bytecode instructions are traced over and over, once per parameter):

[2023-10-04 08:14:53,891] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST grad [ListIteratorVariable()]
[2023-10-04 08:14:53,891] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_sparse [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE YIELD_VALUE None [ListIteratorVariable(), ConstantVariable(bool)]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_TOP None [ListIteratorVariable(), ConstantVariable(NoneType)]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE JUMP_ABSOLUTE 4 [ListIteratorVariable()]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE FOR_ITER 18 [ListIteratorVariable()]
[2023-10-04 08:14:53,895] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST grad [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,895] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST grad [ListIteratorVariable()]
[2023-10-04 08:14:53,895] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_sparse [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE YIELD_VALUE None [ListIteratorVariable(), ConstantVariable(bool)]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_TOP None [ListIteratorVariable(), ConstantVariable(NoneType)]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE JUMP_ABSOLUTE 4 [ListIteratorVariable()]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE FOR_ITER 18 [ListIteratorVariable()]
[2023-10-04 08:14:53,900] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST grad [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,900] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST grad [ListIteratorVariable()]
[2023-10-04 08:14:53,900] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_sparse [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE YIELD_VALUE None [ListIteratorVariable(), ConstantVariable(bool)]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_TOP None [ListIteratorVariable(), ConstantVariable(NoneType)]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE JUMP_ABSOLUTE 4 [ListIteratorVariable()]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE FOR_ITER 18 [ListIteratorVariable()]

CC: @mlazos @jansel

Repro

import time
import torch
from torch.optim import Adam, SGD

def compile_opt(opt_compiled):
    torch._dynamo.eval_frame.TorchPatcher.patch()

    # step() is wrapped (e.g. by profiler hooks); compile the underlying
    # function so dynamo traces the actual optimizer step body.
    step_fn = opt_compiled.step.__wrapped__
    def fn():
        step_fn(opt_compiled)

    return torch.compile(fn, backend="inductor", fullgraph=True)

optim_cls = SGD  # switch to Adam (and NUM_PARAMS = 200) to reproduce the Adam timing above
NUM_PARAMS = 1000
kwargs = {"lr": 0.01, "foreach": True}

torch._dynamo.reset()
# torch._inductor.metrics.reset()
input = torch.ones([10, 10], device="cuda:0")
model = torch.nn.Sequential(
    *[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(NUM_PARAMS)]
)

# Populate .grad on every parameter so the optimizer step has work to do.
model(input).sum().backward()
opt_compiled = optim_cls(model.parameters(), **kwargs)
compiled_step = compile_opt(opt_compiled)

with torch.set_grad_enabled(False):
    start_time = time.time()
    compiled_step()
    print("compile opt took: %s seconds", time.time() - start_time)

Versions

main

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @Xia-Weiwen @aakhundov


Labels

module: inductor · module: optimizer (Related to torch.optim) · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
