KEMBAR78
[dynamo] Identify pre-existing captured cells by cell id rather than content id by StrongerXi · Pull Request #140436 · pytorch/pytorch · GitHub
Skip to content

Conversation

@StrongerXi
Copy link
Contributor

@StrongerXi StrongerXi commented Nov 12, 2024

Stack from ghstack (oldest at bottom):

In match_nested_cell, Dynamo tried to identify pre-existing captured
cells by (cell_name, id(cell_contents)). This works in most cases, but
as the test added in this patch shows, it's not a complete solution.

This patch

  1. changes match_nested_cell to lookup_variable_for_captured_cell,
    and does the lookup based on id of cell objects, not their contents.
    This requires plumbing a tuple of captured cell objects from
    different CPython versions all the way to
    InstructionTranslator.__init__, where we store a mapping from the
    ids of these cell objects, and use it later in
    UserFunctionVariable.bind_args to look for these unboxed cells.
  2. builds off (1) -- rather than using a VariableTracker that
    represents the content of the unboxed cells, use ClosureVariable,
    which enables codegen in case these cells escape as closure of a
    NestedUserFunctionVariable.

The patch adds a regression test for each of the scenarios above:

  1. test_write_to_cells_with_name_shadowing where Dynamo mistakenly
    thought the program is writing to a cell captured by root frame (which
    it doesn't support atm), which resulted in
  File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 3340, in STORE_DEREF
    unimplemented("write to __closure__ while inlining")
  File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: write to __closure__ while inlining
  1. test_existing_func_that_creates_capturing_nested_func where Dynamo
    ended up trying to codegen a NestedUserFunctionVariable that
    captures a cell which was also captured by the root frame, so it was
    unboxed and ends up emitting LOAD_DEREF rather than
    LOAD_FAST/LOAD_CLOSURE during codegen, resulting in
  File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 105, in _create_nested_fn
    func = FunctionType(code, f_globals, name, defaults, closure)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: arg 5 (closure) expected cell, found int

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140436

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit f2c003e with merge base f98c601 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
[ghstack-poisoned]
@StrongerXi
Copy link
Contributor Author

Add another regression test.

[ghstack-poisoned]
@StrongerXi StrongerXi requested a review from jansel November 13, 2024 12:46
[ghstack-poisoned]
PyCodeObject* code = self->frame->f_code;
// Why this check? See
// https://github.com/python/cpython/blob/5f24da9d75bb0150781b17ee4706e93e6bb364ea/Objects/frameobject.c#L1058-L1065
if (code->co_flags & CO_OPTIMIZED) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's some code duplication with

if (co->co_flags & CO_OPTIMIZED) {
, but it's not a huge deal

[ghstack-poisoned]
smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Nov 15, 2024
pytorch#140435)

Registed tensor hooks contain `NestedUserFunctionVariable` which might
capture a `NewCellVariable` for cell objects created during Dynamo
tracing, so we must make sure it doesn't get pruned away.

Pull Request resolved: pytorch#140435
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436
smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Nov 15, 2024
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects:
1. For cells captured and created by the root frame, represent them as
   their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and
   `STORE_DEREF` update directly, without going through `SideEffects`.
2. `ClosureVariable`: this is created when cells from (1) are captured
   by a newly created function Dynamo is about to inline. It's a handle
   with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1),
   to make `root_tx.symbolic_locals` up-to-date.
3. For cells that are captured by both the root frame and some
   pre-existing function Dynamo is about to inline, represent those
   cells as contents, and do not allow writes to them.

Note that (2) and (3) are mainly to conform with (1) -- to make sure
Dynamo has a consistent modeling of cells for the same cell objects.

In this patch, we represent all of these cells as `NewCellVariable`. The
main new code paths introduced are:
- using `NewCellVariable` to model cell objects created by the root
  frame (the cells are passed in as input to `InstructionTranslator`),
  this is what allows us to get rid of all 3 legacy paths above.
- adding a new `AutoDerefLocalSource` to deal with the python-code
  level (guards) and bytecode level (codegen) auto-dereferencing
  behavior, when accessing pre-existing python cells. This also
  involves a tiny update to guard manager generation.
- plumbing some extra info into `LocalSource` and `CellVariable` so that
  we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead
  of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`),
  which is important for readability, performance, and some
  assumptions `bytecode_transformation.py` makes.

As a result, this patch removes a lot of the now-dead code paths and
TODOs. Notably, it significantly simplified the `prune_dead_locals`
function, which was duplicating a lot of the logic from
`prune_dead_object_new`; this conveniently closes pytorch#137123.

Pull Request resolved: pytorch#140153
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435
smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Nov 15, 2024
…140154)

Now that all cells are modeled as `NewCellVariable` in Dynamo, we no
longer need to put cell variables into this special `closure_cells`,
rather we just merge `closure_cells` with `symbolic_locals`.

This allows us to merge and remove some code paths, notably make
`LOAD_CLOSURE` the same as `LOAD_FAST`, and `LOAD_DEREF` & `STORE_DEREF`
the same for inlining or regular `InstructionTranslator`.

Pull Request resolved: pytorch#140154
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153
smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Nov 15, 2024
…ytorch#140155)

This is no longer needed now that we've replaced `ClosureVariable` with
`NewCellVariable`, i.e., Dynamo now treats `LOAD_CLOSURE` the same as
`LOAD_FAST`.

Pull Request resolved: pytorch#140155
Approved by: https://github.com/jansel, https://github.com/williamwen42
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153, pytorch#140154
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…content id (pytorch#140436)

In `match_nested_cell`, Dynamo tried to identify pre-existing captured
cells by `(cell_name, id(cell_contents))`. This works in most cases, but
as the test added in this patch shows, it's not a complete solution.

This patch
1. changes `match_nested_cell` to `lookup_variable_for_captured_cell`,
   and does the lookup based on id of cell objects, not their contents.
   This requires plumbing a tuple of captured cell objects from
   different CPython versions all the way to
   `InstructionTranslator.__init__`, where we store a mapping from the
   ids of these cell objects, and use it later in
   `UserFunctionVariable.bind_args` to look for these unboxed cells.
2. builds off (1) -- rather than using a `VariableTracker` that
   represents the content of the unboxed cells, use `ClosureVariable`,
   which enables codegen in case these cells escape as closure of a
   `NestedUserFunctionVariable`.

The patch adds a regression test for each of the scenarios above:
1. `test_write_to_cells_with_name_shadowing` where Dynamo mistakenly
   thought the program is writing to a cell captured by root frame (which
   it doesn't support atm), which resulted in
```
  File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 3340, in STORE_DEREF
    unimplemented("write to __closure__ while inlining")
  File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: write to __closure__ while inlining
```
2. `test_existing_func_that_creates_capturing_nested_func` where Dynamo
   ended up trying to codegen a `NestedUserFunctionVariable` that
   captures a cell which was also captured by the root frame, so it was
   unboxed and ends up emitting `LOAD_DEREF` rather than
   `LOAD_FAST/LOAD_CLOSURE` during codegen, resulting in
```
  File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 105, in _create_nested_fn
    func = FunctionType(code, f_globals, name, defaults, closure)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: arg 5 (closure) expected cell, found int
```

Pull Request resolved: pytorch#140436
Approved by: https://github.com/jansel, https://github.com/williamwen42
ghstack dependencies: pytorch#140330, pytorch#140152
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pytorch#140435)

Registed tensor hooks contain `NestedUserFunctionVariable` which might
capture a `NewCellVariable` for cell objects created during Dynamo
tracing, so we must make sure it doesn't get pruned away.

Pull Request resolved: pytorch#140435
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects:
1. For cells captured and created by the root frame, represent them as
   their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and
   `STORE_DEREF` update directly, without going through `SideEffects`.
2. `ClosureVariable`: this is created when cells from (1) are captured
   by a newly created function Dynamo is about to inline. It's a handle
   with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1),
   to make `root_tx.symbolic_locals` up-to-date.
3. For cells that are captured by both the root frame and some
   pre-existing function Dynamo is about to inline, represent those
   cells as contents, and do not allow writes to them.

Note that (2) and (3) are mainly to conform with (1) -- to make sure
Dynamo has a consistent modeling of cells for the same cell objects.

In this patch, we represent all of these cells as `NewCellVariable`. The
main new code paths introduced are:
- using `NewCellVariable` to model cell objects created by the root
  frame (the cells are passed in as input to `InstructionTranslator`),
  this is what allows us to get rid of all 3 legacy paths above.
- adding a new `AutoDerefLocalSource` to deal with the python-code
  level (guards) and bytecode level (codegen) auto-dereferencing
  behavior, when accessing pre-existing python cells. This also
  involves a tiny update to guard manager generation.
- plumbing some extra info into `LocalSource` and `CellVariable` so that
  we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead
  of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`),
  which is important for readability, performance, and some
  assumptions `bytecode_transformation.py` makes.

As a result, this patch removes a lot of the now-dead code paths and
TODOs. Notably, it significantly simplified the `prune_dead_locals`
function, which was duplicating a lot of the logic from
`prune_dead_object_new`; this conveniently closes pytorch#137123.

Pull Request resolved: pytorch#140153
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…140154)

Now that all cells are modeled as `NewCellVariable` in Dynamo, we no
longer need to put cell variables into this special `closure_cells`,
rather we just merge `closure_cells` with `symbolic_locals`.

This allows us to merge and remove some code paths, notably make
`LOAD_CLOSURE` the same as `LOAD_FAST`, and `LOAD_DEREF` & `STORE_DEREF`
the same for inlining or regular `InstructionTranslator`.

Pull Request resolved: pytorch#140154
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ytorch#140155)

This is no longer needed now that we've replaced `ClosureVariable` with
`NewCellVariable`, i.e., Dynamo now treats `LOAD_CLOSURE` the same as
`LOAD_FAST`.

Pull Request resolved: pytorch#140155
Approved by: https://github.com/jansel, https://github.com/williamwen42
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153, pytorch#140154
@github-actions github-actions bot deleted the gh/StrongerXi/35/head branch December 19, 2024 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants