Description
🐛 Describe the bug
Dynamo tracing time is:
- 70 seconds for Adam with 200 params
- 162 seconds for SGD with 1000 params
As identified in #110353 (comment), this is due to Dynamo needing to trace an expensive for loop.
If this loop can instead be written in a way that is easy to trace (e.g. by tracing a map over a single lambda, similar to how foreach batches the optimizer main loop), we are likely to speed up compilation times across all optimizers by a significant factor; see the sketch after the code reference below.
Example here (line 39 in 31d6358):
for p in group['params']:
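As a minimal sketch of the proposed rewrite (illustrative only: step_loop, step_foreach, and the plain SGD-style update below are assumptions, not the actual optimizer code), a single batched foreach call replaces the per-parameter Python loop and only needs to be traced once:

import torch

lr = 0.01
params = [torch.randn(10) for _ in range(1000)]
grads = [torch.randn(10) for _ in range(1000)]

# Per-parameter loop: Dynamo unrolls and traces the body once per
# parameter, so tracing time grows linearly with the param count.
def step_loop(params, grads):
    for p, g in zip(params, grads):
        p.add_(g, alpha=-lr)

# Batched rewrite: one foreach op covers all parameters, so Dynamo
# traces a single call regardless of the param count.
def step_foreach(params, grads):
    torch._foreach_add_(params, grads, alpha=-lr)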
Example Dynamo logs (the same computation is traced again and again, once per parameter):
[2023-10-04 08:14:53,891] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST grad [ListIteratorVariable()]
[2023-10-04 08:14:53,891] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_sparse [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE YIELD_VALUE None [ListIteratorVariable(), ConstantVariable(bool)]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_TOP None [ListIteratorVariable(), ConstantVariable(NoneType)]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE JUMP_ABSOLUTE 4 [ListIteratorVariable()]
[2023-10-04 08:14:53,892] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE FOR_ITER 18 [ListIteratorVariable()]
[2023-10-04 08:14:53,895] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST grad [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,895] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST grad [ListIteratorVariable()]
[2023-10-04 08:14:53,895] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_sparse [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE YIELD_VALUE None [ListIteratorVariable(), ConstantVariable(bool)]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_TOP None [ListIteratorVariable(), ConstantVariable(NoneType)]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE JUMP_ABSOLUTE 4 [ListIteratorVariable()]
[2023-10-04 08:14:53,896] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE FOR_ITER 18 [ListIteratorVariable()]
[2023-10-04 08:14:53,900] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST grad [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,900] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST grad [ListIteratorVariable()]
[2023-10-04 08:14:53,900] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR is_sparse [ListIteratorVariable(), TensorVariable()]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE YIELD_VALUE None [ListIteratorVariable(), ConstantVariable(bool)]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE POP_TOP None [ListIteratorVariable(), ConstantVariable(NoneType)]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE JUMP_ABSOLUTE 4 [ListIteratorVariable()]
[2023-10-04 08:14:53,901] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE FOR_ITER 18 [ListIteratorVariable()]
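For reference, bytecode-level trace logs like the above can be enabled via the TORCH_LOGS environment variable when running the repro below (one way to do it; repro.py is a placeholder name):

TORCH_LOGS="+dynamo" python repro.py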
Repro
import time

import torch
from torch.optim import Adam, SGD

def compile_opt(opt_compiled):
    torch._dynamo.eval_frame.TorchPatcher.patch()
    # Compile the unwrapped step function so Dynamo traces the optimizer
    # loop directly rather than its outer wrapper.
    step_fn = opt_compiled.step.__wrapped__

    def fn():
        step_fn(opt_compiled)

    return torch.compile(fn, backend="inductor", fullgraph=True)

optim_cls = SGD
NUM_PARAMS = 1000
kwargs = {"lr": 0.01, "foreach": True}

torch._dynamo.reset()
# torch._inductor.metrics.reset()

model = torch.nn.Sequential(
    *[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(NUM_PARAMS)]
)
input = torch.ones([10, 10], device="cuda:0")
model(input).sum().backward()

opt_compiled = optim_cls(model.parameters(), **kwargs)
compiled_step = compile_opt(opt_compiled)

with torch.set_grad_enabled(False):
    start_time = time.time()
    compiled_step()
    print(f"compile opt took: {time.time() - start_time} seconds")
Versions
main
cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @Xia-Weiwen @aakhundov