KEMBAR78
[dynamo] add infinite generators `itertools.{count, repeat, cycle}` by jon-chuang · Pull Request #110967 · pytorch/pytorch · GitHub
Skip to content
97 changes: 97 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7578,6 +7578,103 @@ def fn(x):
self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_infinite_repeat(self):
counters.clear()

def fn(x):
r = itertools.repeat(100.0)
idx = 0
for i in r:
x += i
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_infinite_repeat_mutation(self):
counters.clear()

def fn(x):
r = itertools.repeat(x)
idx = 0
for i in r:
x += i
i += 1
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_infinite_count(self):
for args in ([], [10], [5, -1]):
counters.clear()

def fn(x):
r = itertools.count(*args)
idx = 0
for i in r:
x += i
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_infinite_cycle(self):
counters.clear()

def fn(x):
for iterator in (
iter([]),
iter([10, 11.0]),
itertools.repeat(-1, 3),
itertools.count(10),
):
r = itertools.cycle(iterator)
idx = 0
x += 1
for i in r:
x += i
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_accumulate_symint_default_sum(self):
# https://github.com/pytorch/pytorch/issues/110287
counters.clear()
Expand Down
12 changes: 6 additions & 6 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,11 +1065,10 @@ def POP_FINALLY(self, inst):

def FOR_ITER(self, inst):
it = self.pop()
if isinstance(it, ListIteratorVariable):
if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
self.output.guards.update(it.guards)
try:
val, next_iter = it.next_variables()
self.replace_all(it, next_iter)
val, next_iter = it.next_variables(self)
self.push(next_iter)
self.push(val)
except StopIteration:
Expand Down Expand Up @@ -2559,11 +2558,12 @@ def YIELD_FROM(self, inst):
if isinstance(tos, ConstantVariable) and tos.value is None:
self.pop()
return
if isinstance(tos, ListIteratorVariable):
if isinstance(
tos, (variables.ListIteratorVariable, variables.IteratorVariable)
):
self.output.guards.update(tos.guards)
try:
val, next_iter = tos.next_variables()
self.replace_all(tos, next_iter)
val, next_iter = tos.next_variables(self)
self.push(val)
# TODO(voz): Unclear if we need the push None in YIELD_VALUE?
self.YIELD_VALUE(inst)
Expand Down
10 changes: 10 additions & 0 deletions torch/_dynamo/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .iter import (
CountIteratorVariable,
CycleIteratorVariable,
IteratorVariable,
RepeatIteratorVariable,
)
from .lists import (
BaseListVariable,
ListIteratorVariable,
Expand Down Expand Up @@ -79,6 +85,10 @@
"GetAttrVariable",
"GradModeVariable",
"InspectSignatureVariable",
"IteratorVariable",
"RepeatIteratorVariable",
"CountIteratorVariable",
"CycleIteratorVariable",
"LambdaVariable",
"ListIteratorVariable",
"ListVariable",
Expand Down
13 changes: 10 additions & 3 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,12 @@ def _dyn_proxy(self, tx, *args, **kwargs):
def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
if self._dynamic_args(*args, **kwargs):
return self._dyn_proxy(tx, *args, **kwargs)

if isinstance(obj, variables.IteratorVariable):
# For non-list iterators, we will guard on vars that
# determine the control flow
return obj

# TODO This should probably be treated as a dict, or dicts should also be treated here
if self.fn == set:
cls = SetVariable
Expand Down Expand Up @@ -965,9 +971,10 @@ def call_super(self, tx, a, b):
return variables.SuperVariable(a, b)

def call_next(self, tx, arg):
if isinstance(arg, variables.ListIteratorVariable):
val, next_iter = arg.next_variables()
tx.replace_all(arg, next_iter)
if isinstance(
arg, (variables.ListIteratorVariable, variables.IteratorVariable)
):
val, next_iter = arg.next_variables(tx)
return val
elif isinstance(arg, variables.BaseListVariable):
return arg.items[0].add_options(self, arg)
Expand Down
101 changes: 101 additions & 0 deletions torch/_dynamo/variables/iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
MAX_CYCLE = 3000

from typing import List, Optional

from ..exc import unimplemented

from .base import VariableTracker
from .constant import ConstantVariable


class IteratorVariable(VariableTracker):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def next_variables(self, tx):
unimplemented("abstract method, must implement")


class RepeatIteratorVariable(IteratorVariable):
def __init__(self, item: VariableTracker, **kwargs):
super().__init__(**kwargs)
self.item = item

# Repeat needs no mutation, clone self
def next_variables(self, tx):
# add_options will clone self.item
return self.item.add_options(self), self


class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs):
super().__init__(**kwargs)
if not isinstance(item, VariableTracker):
item = ConstantVariable.create(item)
if not isinstance(step, VariableTracker):
step = ConstantVariable.create(step)
self.item = item
self.step = step

def next_variables(self, tx):
assert self.mutable_local
next_item = self.item.call_method(tx, "__add__", [self.step], {})
next_iter = self.clone(item=next_item)
tx.replace_all(self, next_iter)
return self.item.add_options(self), next_iter


class CycleIteratorVariable(IteratorVariable):
def __init__(
self,
iterator: IteratorVariable,
saved: List[VariableTracker] = None,
saved_index: int = 0,
item: Optional[VariableTracker] = None,
**kwargs,
):
if saved is None:
saved = []
super().__init__(**kwargs)
self.iterator = iterator
self.saved = saved
self.saved_index = saved_index
self.item = item

def next_variables(self, tx):
assert self.mutable_local

if self.iterator is not None:
try:
new_item, next_inner_iter = self.iterator.next_variables(tx)
tx.replace_all(self.iterator, next_inner_iter)
if len(self.saved) > MAX_CYCLE:
unimplemented(
"input iterator to itertools.cycle has too many items"
)
next_iter = self.clone(
iterator=next_inner_iter,
saved=self.saved + [new_item],
item=new_item,
)

tx.replace_all(self, next_iter)
if self.item is None:
return next_iter.next_variables(tx)
return self.item.add_options(self), next_iter
except StopIteration:
next_iter = self.clone(iterator=None)
# this is redundant as next_iter will do the same
# but we do it anyway for safety
tx.replace_all(self, next_iter)
return next_iter.next_variables(tx)
elif len(self.saved) > 0:
next_iter = self.clone(
saved_index=(self.saved_index + 1) % len(self.saved),
item=self.saved[self.saved_index],
)
tx.replace_all(self, next_iter)
return self.item.add_options(self), next_iter
else:
raise StopIteration
return self.item.add_options(self), next_iter
6 changes: 4 additions & 2 deletions torch/_dynamo/variables/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,16 +683,18 @@ def __init__(self, items, index: int = 0, **kwargs):
self.items = items
self.index = index

def next_variables(self):
def next_variables(self, tx):
assert self.mutable_local
if self.index >= len(self.items):
raise StopIteration()
return self.items[self.index].add_options(self), ListIteratorVariable(
next_iter = ListIteratorVariable(
self.items,
self.index + 1,
mutable_local=MutableLocal(),
**VariableTracker.propagate([self]),
)
tx.replace_all(self, next_iter)
return self.items[self.index].add_options(self), next_iter

def as_python_constant(self):
if self.index > 0:
Expand Down
14 changes: 9 additions & 5 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,16 +885,20 @@ def wraps(fn):
fn, args=rest_args, keywords=kwargs, **options
)
elif self.value is itertools.repeat:
from .builder import SourcelessBuilder

if len(args) < 2:
# We cannot risk infinite generator being consumed to exhaustion by dynamo
# (i.e. infinite loop)
unimplemented("Infinite repeat is not supported")
return variables.RepeatIteratorVariable(
*args, mutable_local=MutableLocal()
)

from .builder import SourcelessBuilder

return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
elif self.value is itertools.cycle:
return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
else:
try:
path = inspect.getfile(self.value)
Expand Down