KEMBAR78
fix correctness for dynamo inlining RangeVariable __contains__ by bdhirsh · Pull Request #122751 · pytorch/pytorch · GitHub
Skip to content

Conversation

@bdhirsh
Copy link
Contributor

@bdhirsh bdhirsh commented Mar 27, 2024

Fixes #122379

It looks like iter_contains() in dynamo expects to take in something like iter_contains(List[VariableTracker], VariableTracker]). Previously, when we called this function where the list in question was a RangeVariable, we would pass in RangeVariable.items as our list.

This is wrong, though since RangeVariable.items just contains the underlying [start, stop, step]. It looks like unpack_var_sequence does the right thing of "materializing" the range into a list of VariableTrackers, so I used that instead.

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122751

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 81edfe3 with merge base 69c6e0b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


def test_contains_range_constprop(self):
def fn(x):
# dynamo should const prop to False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops, True*

@albanD albanD removed their request for review March 27, 2024 22:21
@ezyang ezyang requested review from anijain2305 and jansel and removed request for ezyang March 28, 2024 13:20
…ns__"

Fixes #122379

It looks like `iter_contains()` in dynamo expects to take in something like `iter_contains(List[VariableTracker], VariableTracker])`. Previously, when we called this function where the list in question was a `RangeVariable`, we would pass in `RangeVariable.items` as our list.

This is wrong, though since `RangeVariable.items` just contains the underlying [start, stop, step]. It looks like `unpack_var_sequence` does the right thing of "materializing" the range into a list of `VariableTrackers`, so I used that instead.




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
bdhirsh added 3 commits April 4, 2024 08:45
…ns__"

Fixes #122379

It looks like `iter_contains()` in dynamo expects to take in something like `iter_contains(List[VariableTracker], VariableTracker])`. Previously, when we called this function where the list in question was a `RangeVariable`, we would pass in `RangeVariable.items` as our list.

This is wrong, though since `RangeVariable.items` just contains the underlying [start, stop, step]. It looks like `unpack_var_sequence` does the right thing of "materializing" the range into a list of `VariableTrackers`, so I used that instead.




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
…ns__"

Fixes #122379

It looks like `iter_contains()` in dynamo expects to take in something like `iter_contains(List[VariableTracker], VariableTracker])`. Previously, when we called this function where the list in question was a `RangeVariable`, we would pass in `RangeVariable.items` as our list.

This is wrong, though since `RangeVariable.items` just contains the underlying [start, stop, step]. It looks like `unpack_var_sequence` does the right thing of "materializing" the range into a list of `VariableTrackers`, so I used that instead.




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
…ns__"

Fixes #122379

It looks like `iter_contains()` in dynamo expects to take in something like `iter_contains(List[VariableTracker], VariableTracker])`. Previously, when we called this function where the list in question was a `RangeVariable`, we would pass in `RangeVariable.items` as our list.

This is wrong, though since `RangeVariable.items` just contains the underlying [start, stop, step]. It looks like `unpack_var_sequence` does the right thing of "materializing" the range into a list of `VariableTrackers`, so I used that instead.




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Apr 12, 2024
Fixes #123298

I was also seeing some crashes in torchtrain due to dynamic shapes, even when I set `compile(dynamic=False)` (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @wanchaol). This doesn't fix the underlying dynamic shape issues with compile + DTensor, but it does prevent dynamic shapes from leaking in.

Pull Request resolved: #123348
Approved by: https://github.com/ezyang
ghstack dependencies: #122502, #122751
pytorchmergebot pushed a commit that referenced this pull request Apr 15, 2024
…#123347)

Fixes #122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above that start some logs showing that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining __new__ ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node examples values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the __new__ at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo

Pull Request resolved: #123347
Approved by: https://github.com/zou3519
ghstack dependencies: #122502, #122751, #123348
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
…ch#122751)

Fixes pytorch#122379

It looks like `iter_contains()` in dynamo expects to take in something like `iter_contains(List[VariableTracker], VariableTracker])`. Previously, when we called this function where the list in question was a `RangeVariable`, we would pass in `RangeVariable.items` as our list.

This is wrong, though since `RangeVariable.items` just contains the underlying [start, stop, step]. It looks like `unpack_var_sequence` does the right thing of "materializing" the range into a list of `VariableTrackers`, so I used that instead.

Pull Request resolved: pytorch#122751
Approved by: https://github.com/anijain2305, https://github.com/jansel
ghstack dependencies: pytorch#122502
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
…123348)

Fixes pytorch#123298

I was also seeing some crashes in torchtrain due to dynamic shapes, even when I set `compile(dynamic=False)` (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @wanchaol). This doesn't fix the underlying dynamic shape issues with compile + DTensor, but it does prevent dynamic shapes from leaking in.

Pull Request resolved: pytorch#123348
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#122502, pytorch#122751
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
…pytorch#123347)

Fixes pytorch#122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above that start some logs showing that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining __new__ ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node examples values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the __new__ at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo

Pull Request resolved: pytorch#123347
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#122502, pytorch#122751, pytorch#123348
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…ch#122751)

Fixes pytorch#122379

It looks like `iter_contains()` in dynamo expects to take in something like `iter_contains(List[VariableTracker], VariableTracker])`. Previously, when we called this function where the list in question was a `RangeVariable`, we would pass in `RangeVariable.items` as our list.

This is wrong, though since `RangeVariable.items` just contains the underlying [start, stop, step]. It looks like `unpack_var_sequence` does the right thing of "materializing" the range into a list of `VariableTrackers`, so I used that instead.

Pull Request resolved: pytorch#122751
Approved by: https://github.com/anijain2305, https://github.com/jansel
ghstack dependencies: pytorch#122502
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…123348)

Fixes pytorch#123298

I was also seeing some crashes in torchtrain due to dynamic shapes, even when I set `compile(dynamic=False)` (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @wanchaol). This doesn't fix the underlying dynamic shape issues with compile + DTensor, but it does prevent dynamic shapes from leaking in.

Pull Request resolved: pytorch#123348
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#122502, pytorch#122751
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…pytorch#123347)

Fixes pytorch#122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above that start some logs showing that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining __new__ ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node examples values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the __new__ at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo

Pull Request resolved: pytorch#123347
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#122502, pytorch#122751, pytorch#123348
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants