KEMBAR78
[dynamo] Add `itertools.repeat` via polyfill by jon-chuang · Pull Request #110953 · pytorch/pytorch · GitHub
Skip to content

Conversation

@jon-chuang
Copy link
Collaborator

@jon-chuang jon-chuang commented Oct 10, 2023

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110953

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 144ae40 with merge base de3ae93 (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 10, 2023

Didn't find following labels among repository labels: release note: dynamo

@jon-chuang
Copy link
Collaborator Author

@pytorchbot label "release notes: dynamo"


if len(args) < 2:
# We cannot risk infinite generator being consumed to exhaustion by dynamo
# (i.e. infinite loop)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come? Shouldn't this be triggering the generator protocol?

Copy link
Collaborator Author

@jon-chuang jon-chuang Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, for some reason, it hangs.

I implemented it as:

def repeat(item, count):
    if count is None:
      while True:
        yield item

# Eager
>>> for i, j in enumerate(repeat(3, None)):
...   print(i, j)
...   if i > 5:
...     break
... 
0 3
1 3
2 3
3 3
4 3
5 3
6 3

# dynamo
# *hangs*

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason is that this does not get wrapped into a generator, but rather a ListIteratorVariable for some reason.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In particular, InliningGeneratorTranslator simply yields the entire generator into a list, which is then returned as a ListIteratorVariable:

self.generated_items = []

Copy link
Collaborator Author

@jon-chuang jon-chuang Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will implement general infinite iterators (count, repeat, cycle) in a follow up PR.

These are not preferred as they return an opaque variabletracker. In particular, one cannot do enumerate(repeat(1)). repeat(1, 10) benefits from the integration enjoyed by ListVariableIterator

@jon-chuang
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 10, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@jon-chuang
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Nov 1, 2023
…110967)

Fixes https://github.com/pytorch/pytorch/pull/110953/files#r1352868935

Depends on: #110953

Why not use these for `repeat(item, count)`:
> These are not preferred as they return an opaque VariableTracker. In particular, one cannot do `enumerate(repeat(1))`. `repeat(1, 10)` benefits from the integration enjoyed by `ListVariableIterator`

Follow ups:
- [ ] make listiterator an IteratorVariable, define iterator integrations on base IteratorVariable where unspecialized #110967 (comment)
    - Please make a new issue for this
- [ ] explore integrating cpython itertools test suite #110967 (comment)
- [ ] Use something other than `StopIteration` to handle iterator termination #110967 (comment)
- [ ] Add test case for consuming iterator simultaneously from two code points https://github.com/pytorch/pytorch/pull/110967/files#r1358325511

Pull Request resolved: #110967
Approved by: https://github.com/ezyang
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…ytorch#110967)

Fixes https://github.com/pytorch/pytorch/pull/110953/files#r1352868935

Depends on: pytorch#110953

Why not use these for `repeat(item, count)`:
> These are not preferred as they return an opaque VariableTracker. In particular, one cannot do `enumerate(repeat(1))`. `repeat(1, 10)` benefits from the integration enjoyed by `ListVariableIterator`

Follow ups:
- [ ] make listiterator an IteratorVariable, define iterator integrations on base IteratorVariable where unspecialized pytorch#110967 (comment)
    - Please make a new issue for this
- [ ] explore integrating cpython itertools test suite pytorch#110967 (comment)
- [ ] Use something other than `StopIteration` to handle iterator termination pytorch#110967 (comment)
- [ ] Add test case for consuming iterator simultaneously from two code points https://github.com/pytorch/pytorch/pull/110967/files#r1358325511

Pull Request resolved: pytorch#110967
Approved by: https://github.com/ezyang
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…ytorch#110967)

Fixes https://github.com/pytorch/pytorch/pull/110953/files#r1352868935

Depends on: pytorch#110953

Why not use these for `repeat(item, count)`:
> These are not preferred as they return an opaque VariableTracker. In particular, one cannot do `enumerate(repeat(1))`. `repeat(1, 10)` benefits from the integration enjoyed by `ListVariableIterator`

Follow ups:
- [ ] make listiterator an IteratorVariable, define iterator integrations on base IteratorVariable where unspecialized pytorch#110967 (comment)
    - Please make a new issue for this
- [ ] explore integrating cpython itertools test suite pytorch#110967 (comment)
- [ ] Use something other than `StopIteration` to handle iterator termination pytorch#110967 (comment)
- [ ] Add test case for consuming iterator simultaneously from two code points https://github.com/pytorch/pytorch/pull/110967/files#r1358325511

Pull Request resolved: pytorch#110967
Approved by: https://github.com/ezyang
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[dynamo] support itertools.repeat: call_function repeat in skip_files Builtin repeat, skip reason: filename is None

4 participants