compile: ban mutations on non-compositional uses of as_strided by bdhirsh · Pull Request #122502 · pytorch/pytorch · GitHub

Conversation

bdhirsh
Contributor

@bdhirsh bdhirsh commented Mar 22, 2024

Fixes #104505

I was originally going to ban all usages of as_strided + mutation in functionalization. But I'm pretty sure that as_strided + mutation is fine when we are calling as_strided on a base tensor.

So in this PR I added a slightly more conservative check: if we see as_strided + mutation where the input to the as_strided was itself the output of another view op, functionalization errors loudly and links to the GitHub issue above (in case anyone runs into this in the real world).
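
A minimal sketch (my own toy example, not taken from the PR's tests; the backend and shapes are chosen arbitrarily) of the two patterns being distinguished:

```python
import torch

@torch.compile(backend="aot_eager")
def ok(base):
    # as_strided called directly on the base tensor, then mutated: still allowed.
    v = base.as_strided((2, 2), (2, 1))
    v.mul_(2)
    return base

@torch.compile(backend="aot_eager")
def banned(base):
    # as_strided whose input is the output of *another* view op, then mutated:
    # after this PR, functionalization is expected to error loudly and point at #104505.
    v = base.view(2, 2).as_strided((2, 2), (2, 1))
    v.mul_(2)
    return base

x = torch.arange(4.0)
ok(x.clone())        # fine
# banned(x.clone())  # expected to raise the new functionalization error
```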

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

@pytorch-bot

pytorch-bot bot commented Mar 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122502

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dbbeb3a with merge base 69c6e0b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…ided"

Fixes #104505

I was originally going to ban all usages of as_strided + mutation in functionalization. But I'm pretty sure that as_strided + mutation is fine when we are calling as_strided on a base tensor.

So in this PR I added a slightly more conservative check: if we see an as_strided + mutation, where the input to an as_strided was **another** view op, then I error loudly in functionalization and link to the github issue above (in case anyone runs into this in the real world)




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
Contributor

@zou3519 zou3519 left a comment

test failures look real

…ided"

Fixes #104505

I was originally going to ban all usages of as_strided + mutation in functionalization. But I'm pretty sure that as_strided + mutation is fine when we are calling as_strided on a base tensor.

So in this PR I added a slightly more conservative check: if we see an as_strided + mutation, where the input to an as_strided was **another** view op, then I error loudly in functionalization and link to the github issue above (in case anyone runs into this in the real world)




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Mar 22, 2024
@albanD albanD removed their request for review March 29, 2024 02:29
…ided"

Fixes #104505

I was originally going to ban all usages of as_strided + mutation in functionalization. But I'm pretty sure that as_strided + mutation is fine when we are calling as_strided on a base tensor.

So in this PR I added a slightly more conservative check: if we see an as_strided + mutation, where the input to an as_strided was **another** view op, then I error loudly in functionalization and link to the github issue above (in case anyone runs into this in the real world)




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
…ided"

Fixes #104505

I was originally going to ban all usages of as_strided + mutation in functionalization. But I'm pretty sure that as_strided + mutation is fine when we are calling as_strided on a base tensor.

So in this PR I added a slightly more conservative check: if we see an as_strided + mutation, where the input to an as_strided was **another** view op, then I error loudly in functionalization and link to the github issue above (in case anyone runs into this in the real world)




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
}},
/*is_multi_output=*/{str(is_multi_output_view).lower()}
/*is_multi_output=*/{str(is_multi_output_view).lower()},
/*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
Contributor

man, this discount c++ boolean printing LOL

Contributor Author

yeah... lol
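
For context, a minimal sketch (my own, not the actual torchgen helpers) of the "discount C++ boolean printing" pattern the generated snippet above relies on: Python bools are turned into C++ literals by lowercasing `str(bool)`.

```python
def cpp_bool(value: bool) -> str:
    # True -> "true", False -> "false", matching C++ boolean literals
    return str(value).lower()

# Hypothetical emission mirroring the generated ViewMeta arguments above:
is_multi_output_view = False
is_as_strided = True
print(
    f"/*is_multi_output=*/{cpp_bool(is_multi_output_view)},\n"
    f"/*is_as_strided=*/{cpp_bool(is_as_strided)}"
)
# /*is_multi_output=*/false,
# /*is_as_strided=*/true
```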

…ided"

Fixes #104505

I was originally going to ban all usages of as_strided + mutation in functionalization. But I'm pretty sure that as_strided + mutation is fine when we are calling as_strided on a base tensor.

So in this PR I added a slightly more conservative check: if we see an as_strided + mutation, where the input to an as_strided was **another** view op, then I error loudly in functionalization and link to the github issue above (in case anyone runs into this in the real world)




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
if (out_idx == this->out_index) return *this;
return ViewMeta(forward_fn, reverse_fn, is_multi_output, out_idx);
return ViewMeta(forward_fn, reverse_fn, is_multi_output, is_as_strided, out_idx);
Contributor Author

@bdhirsh bdhirsh Apr 10, 2024

this was such a pain to debug, bleh (out_idx was getting default-initialized to zero everywhere, so we were silently using the wrong out idx in multi-view inverses)

Contributor

oh oops, should have caught this in code review. Next time you should add new optional parameters at the end
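
A Python analogy (illustrative only; the real code is the C++ `ViewMeta` constructor above) of why inserting a defaulted parameter before an existing one is risky: old positional call sites keep running, but their last argument silently binds to the new parameter and the old one falls back to its default.

```python
class ViewMetaLike:
    # New `is_as_strided` inserted *before* `out_index` instead of at the end.
    def __init__(self, forward_fn, reverse_fn, is_multi_output,
                 is_as_strided=False, out_index=0):
        self.is_as_strided = is_as_strided
        self.out_index = out_index

# Pre-existing call site written against the old signature
# (forward_fn, reverse_fn, is_multi_output, out_index):
meta = ViewMetaLike(lambda x: x, lambda x: x, False, 2)
print(meta.is_as_strided, meta.out_index)  # 2 0 -- out_index silently lost
```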

input_geometry.sym_sizes(), input_geometry.sym_strides());
auto result_slice =
result.as_strided_symint(sizes, strides, std::move(storage_offset));
auto result_buffer = grad_.new_zeros_symint(input_geometry.sym_sizes());
Contributor Author

maybe cc @albanD - it looks like we were in fact using "non-compositional as_strided + mutation" in the backward formula for as_strided_scatter (lol), which failed a few tests.

I think this change is a no-op, but tagging you just in case.

Collaborator

Pretty cool code :)

…ided"

Fixes #104505

I was originally going to ban all usages of as_strided + mutation in functionalization. But I'm pretty sure that as_strided + mutation is fine when we are calling as_strided on a base tensor.

So in this PR I added a slightly more conservative check: if we see an as_strided + mutation, where the input to an as_strided was **another** view op, then I error loudly in functionalization and link to the github issue above (in case anyone runs into this in the real world)




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]

@bdhirsh bdhirsh added the release notes: composability and topic: bc breaking labels Apr 11, 2024
pytorchmergebot pushed a commit that referenced this pull request Apr 12, 2024
Fixes #122379

It looks like `iter_contains()` in dynamo expects to take in something like `iter_contains(List[VariableTracker], VariableTracker)`. Previously, when we called this function where the list in question was a `RangeVariable`, we would pass in `RangeVariable.items` as our list.

This is wrong, though, since `RangeVariable.items` just contains the underlying `[start, stop, step]`. It looks like `unpack_var_sequence` does the right thing of "materializing" the range into a list of `VariableTracker`s, so I used that instead.
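
A plain-Python sketch (not the actual dynamo `VariableTracker` machinery) of the distinction the fix relies on:

```python
r = range(0, 10, 2)

# What RangeVariable.items conceptually holds: just the defining fields.
items = [r.start, r.stop, r.step]   # [0, 10, 2]

# What an unpack_var_sequence-style materialization produces: the elements.
unpacked = list(r)                  # [0, 2, 4, 6, 8]

# Membership checks must use the materialized sequence:
print(4 in items)     # False -- the bug: 4 is not one of start/stop/step
print(4 in unpacked)  # True  -- the fix: 4 really is in the range
```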

Pull Request resolved: #122751
Approved by: https://github.com/anijain2305, https://github.com/jansel
ghstack dependencies: #122502
pytorchmergebot pushed a commit that referenced this pull request Apr 12, 2024
Fixes #123298

I was also seeing some crashes in torchtrain due to dynamic shapes, even when I set `compile(dynamic=False)` (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @wanchaol). This doesn't fix the underlying dynamic shape issues with compile + DTensor, but it does prevent dynamic shapes from leaking in.

Pull Request resolved: #123348
Approved by: https://github.com/ezyang
ghstack dependencies: #122502, #122751
pytorchmergebot pushed a commit that referenced this pull request Apr 15, 2024
…#123347)

Fixes #122459, pytorch/torchtitan#61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above, and the logs show that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.
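
A minimal sketch (a toy wrapper subclass, not the actual DTensor implementation) of that workaround: keep dynamo out of `__new__` so it never traces a partially initialized subclass.

```python
import torch

class MySubclass(torch.Tensor):
    @staticmethod
    @torch._dynamo.disable  # dynamo skips this frame instead of inlining it
    def __new__(cls, elem):
        # construction details elided; the point is only the decorator placement
        return torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype)
```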

Why does having dynamo inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining `__new__` ("creating a subclass" is a black-box operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph whose example value is a "partially initialized tensor subclass", because the subclass has been created but its fields have not been assigned yet.

This breaks a bunch of invariants throughout dynamo: there are many places where, if we have a tensor subclass node, we want to look at its inner tensors to see if they are FakeTensors, what their FakeTensorMode is, and whether they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node example values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace `__new__` at all. Hence the `torch._dynamo.disable`.

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`.

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo.

Pull Request resolved: #123347
Approved by: https://github.com/zou3519
ghstack dependencies: #122502, #122751, #123348
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024