[pipelining] fix py ref cycle in stage_backward by wconstab · Pull Request #136507 · pytorch/pytorch · GitHub

Conversation

@wconstab
Contributor

@wconstab wconstab commented Sep 24, 2024

Stack from ghstack (oldest at bottom):

TLDR; found forward activation tensors were being kept alive "forever"
(or until GC ran), and tracked it down to a cycle involving
`stage_backward.<locals>.extract_tensors_with_grads`.

The reference cycle in question is below. (Constructed using `gc.get_referrers` after doing a `gc.collect` in gc debug mode.)

tensor is kept alive by
`[(<class 'cell'>, '0x7f7360234400')]`

tuple of cell objects
`(<cell at 0x7f73602343d0: function object at 0x7f734fff0ee0>, <cell at 0x7f7360234400: list object at 0x7f734e4d9a80>, <cell at 0x7f73602a4190: list object at 0x7f734eff8b00>)`
is kept alive by
`[(<class 'function'>, '0x7f734fff0ee0')]`

`<function stage_backward.<locals>.extract_tensors_with_grads at 0x7f734fff0ee0>`
is kept alive by
`[(<class 'cell'>, '0x7f73602343d0')]`

Put into plainer terms:


```
def stage_backward(...):
    ...
    stage_output_tensors = []

    # a cell object will exist that contains the variables defined in stage_backward and used by
    # both stage_backward and nested functions
    # in this case, the cell object contains 'stage_output_tensors' but also, as noted below,
    # 'extract_tensors_with_grads' itself

    # this function object will hold a reference to a 'cell' that contains any vars from
    # the parent scope not explicitly passed into the function as args.
    def extract_tensors_with_grads(...):
        ...
            # extract_tensors_with_grads refers to stage_output_tensors, so stage_output_tensors
            # is in the cell
            stage_output_tensors.append(output_val)
        ...
            # but extract_tensors_with_grads ALSO refers to itself (extract_tensors_with_grads),
            # so `extract_tensors_with_grads` will be in the cell
            extract_tensors_with_grads(...)
```
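The same pattern can be reproduced in isolation. The sketch below is illustrative only (the names mirror the PR, but none of this is the actual pipelining code); it shows that the captured list survives the function's return and is only freed by the cycle collector:

```
import gc
import weakref


class FakeActivation:
    """Stands in for a forward activation tensor."""


def make_cycle():
    stage_output_tensors = []  # captured by the nested function below

    def extract(val):
        # 'stage_output_tensors' is a free variable of extract, so it lives in a cell
        stage_output_tensors.append(val)
        if isinstance(val, list):
            for v in val:
                extract(v)  # the self-reference puts 'extract' itself in a cell too

    act = FakeActivation()
    extract([act])
    # return only a weak reference, so nothing outside the cycle keeps 'act' alive
    return weakref.ref(act)


gc.disable()                # rule out an incidental automatic collection
ref = make_cycle()
print(ref() is not None)    # True: the function <-> cell cycle keeps the list (and act) alive
gc.collect()
print(ref() is None)        # True: only the cycle collector frees it
gc.enable()
```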

More debug details:
https://docs.google.com/document/d/1QPH1Lz0tnieIFPM2tyHrjVB-bjlnHuDgjx1p2am3cmE/edit?usp=sharing

In pdb:

```
gc.collect()
g = gc.garbage
g[-1]
[rank0]:(Pdb) [rank0]:<function
stage_backward.<locals>.extract_tensors_with_grads at 0x7fee5c3392d0>
g[-2]
[rank0]:(Pdb) [rank0]:(<cell at 0x7fee7abbcf40: function object at
0x7fee5c3392d0>, <cell at 0x7fee7abbcf70: list object at
0x7fee7ab68940>, <cell at 0x7fee5c3210c0: list object at 0x7fee5e1
d6340>)
g[-3]
[rank0]:(Pdb) [rank0]:[tensor([[[-4.1127e-06, -3.3826e-06,  2.6226e-06,
...,  6.4969e-06,
[rank0]:          -4.4405e-06, -4.7684e-06],
```
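For completeness, the referrer chain quoted earlier can be reconstructed with a small helper along these lines (a sketch; `print_referrer_chain` is not an existing utility):

```
import gc
import types


def print_referrer_chain(obj, depth=3):
    """Walk gc.get_referrers() a few levels up from obj, printing what keeps it alive."""
    seen = {id(obj)}
    current = [obj]
    for level in range(depth):
        nxt = []
        for o in current:
            for ref in gc.get_referrers(o):
                # skip frames (including this helper's own) and our bookkeeping lists
                if isinstance(ref, types.FrameType) or ref is current or ref is nxt:
                    continue
                if id(ref) in seen:
                    continue
                seen.add(id(ref))
                print("  " * level, type(ref).__name__, hex(id(ref)))
                nxt.append(ref)
        current = nxt


# keep otherwise-collectable cycles around in gc.garbage for inspection
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
# e.g. print_referrer_chain(gc.garbage[-3]) to see the cell/function chain above
```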

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 24, 2024
@pytorch-bot

pytorch-bot bot commented Sep 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136507

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7277b30 (merge base: failed to retrieve merge base, please contact dev infra):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Sep 24, 2024
ghstack-source-id: e671d63
Pull Request resolved: #136507
@wconstab wconstab added the release notes: distributed (pipeline) release notes category label Sep 24, 2024
Comment on lines 323 to 324
```
del stage_output_tensors
del output_grad_tensors
```
Contributor


Interesting, is this related to how `stage_backward` is called, i.e. the lambda function:

```
if backward_type == "full":
    return lambda: stage_backward(
        bwd_kwargs["stage_output"],
        bwd_kwargs["output_grads"],
        bwd_kwargs["input_values"],
    )
```

Contributor Author


No, this is unrelated.
I'm adding additional information to the doc and the PR description.

```
torch.autograd.backward(
    stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
)
del stage_output_tensors
```
Contributor Author


After digging deeper into what the ref cycle was, I prefer to change this to `del extract_tensors_with_grads`, which breaks the cycle and ends up freeing the tensors as well as the other pyobjs that would have remained in a cycle if I only did `del stage_output_tensors` and `del output_grad_tensors`.

Contributor Author


I've updated the PR with a fix that is unfortunately a bit uglier, but arguably more correct. If folks have a better idea, let me know.

Contributor Author


OK, one more update: I don't really need to pass `stage_output_tensors` and `output_grad_tensors` into `extract_tensors_with_grads`; I'll let those be captured in the cell. I just need to ensure the cell will die, and for that I only have to break the cycle caused by the self-reference. And that's a little less ugly.
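For readers skimming the thread, a simplified sketch of the shape that fix takes (illustrative only, not the actual diff): the recursive call goes through an explicitly passed-in reference, so the function no longer captures itself.

```
import torch


def stage_backward_sketch(stage_output, output_grads):
    stage_output_tensors = []
    output_grad_tensors = []

    # The two lists are still captured through closure cells (that is fine); what
    # matters is that the function below no longer refers to *itself*, so no
    # function <-> cell cycle is formed and everything dies when the frame returns.
    def extract_tensors_with_grads(output_val, grad_val, extract_fn):
        if isinstance(output_val, torch.Tensor):
            if output_val.requires_grad and grad_val is not None:
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
        elif isinstance(output_val, (tuple, list)):
            for ov, gv in zip(output_val, grad_val):
                extract_fn(ov, gv, extract_fn)

    extract_tensors_with_grads(stage_output, output_grads, extract_tensors_with_grads)

    if stage_output_tensors:
        torch.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors
        )
```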

wconstab added a commit that referenced this pull request Sep 24, 2024

ghstack-source-id: ede5009
Pull Request resolved: #136507
Contributor

@kwen2501 kwen2501 left a comment


Thanks for digging further! Wow, what a new learning!

@wconstab
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 24, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

```
# and to itself (extract_tensors_with_grads) since it makes a recursive call
# 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad
# fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore
extract_tensors_with_grads(
```
Member


If you passed in `stage_output_tensors` and `output_grad_tensors` without the `extract_tensors_with_grads` ref, would that also fix the reference cycle? Since then `extract_tensors_with_grads` would no longer have to access variables in the outer scope, there should be no cycle. Or is my understanding off?

Contributor Author


No, that would fix the tensor lifetime problem, which is almost as good. Technically, however, there would still be a reference cycle between `extract_tensors_with_grads` and the cell, and that cycle would live on until GC kicked in. So breaking the cycle is both cleaner (fewer vars passed in than passing in the two lists) and better (it not only prevents leaking the tensors, but also cleans up the pyobjs earlier).
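To illustrate the point about lingering pyobjs: even with the lists passed in as arguments, a self-referential closure still forms a cycle that only the garbage collector clears (an illustrative sketch, not the pipelining code):

```
import gc
import weakref


def outer():
    results = []

    # 'results' is passed in explicitly, so it is NOT captured in a cell...
    def walk(vals, acc):
        for v in vals:
            if isinstance(v, list):
                walk(v, acc)   # ...but 'walk' still refers to itself,
            else:              # so the function <-> cell cycle remains
                acc.append(v)

    walk([1, [2, 3]], results)
    return weakref.ref(walk)


gc.disable()                 # rule out an incidental automatic collection
wref = outer()
print(wref() is not None)    # True: the function and its cell linger after outer() returns
gc.collect()
print(wref() is None)        # True: they are only freed by the cycle collector
gc.enable()
```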

Contributor


Following up: can we move extract_tensors_with_grads out of the scope of stage_backward?

Contributor Author


Yes, but then we'd be forced to add the other two lists back as args to `extract_tensors_with_grads`; on the other hand, we could drop the `extract_tensors_with_grads` argument, since `extract_tensors_with_grads` would be a global instead of a cellvar. If you guys want to make that change, it's fine too.
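That alternative would look roughly like the sketch below (hypothetical; the name and parameters are illustrative): the helper lives at module scope, takes the two lists as arguments, and recurses through the module-level name, so no closure cell exists at all.

```
import torch


def _extract_tensors_with_grads(output_val, grad_val, stage_output_tensors, output_grad_tensors):
    # Module-level: there is no closure cell; the lists arrive as arguments and the
    # recursive call resolves through the module global, so no reference cycle forms.
    if isinstance(output_val, torch.Tensor):
        if output_val.requires_grad and grad_val is not None:
            stage_output_tensors.append(output_val)
            output_grad_tensors.append(grad_val)
    elif isinstance(output_val, (tuple, list)):
        for ov, gv in zip(output_val, grad_val):
            _extract_tensors_with_grads(ov, gv, stage_output_tensors, output_grad_tensors)
```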

H-Huang added a commit that referenced this pull request Sep 25, 2024
…s in tests"


Fix two more leaks of the same variety as #136507 (see that PR desc and attached gdoc for debug details).

This time, also add a test-time check that helped to discover new leaks and ensure we won't accidentally regress.

Adds `check_tensor_leak` util which internally asserts no tensors are being kept alive by other objects involved in py ref cycles.

Uses objgraph for a nice debug utility when a leak is found.

Credit to H-Huang for pointing out objgraph and helping debug the `param_group["intermediates"]` leak.

I manually confirmed that all 3 of the leaks identified/fixed so far are caught by the unit test and checker.

Sample output, if I re-introduce a leak by commenting out `del param_group["intermediates"]` in _backward.py,
and run `python test/distributed/pipelining/test_schedule_multiproc.py -k test_schedule_with_native_zero_bubble`:

```
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5341: UserWarning: 34 tensors were found in the garbage. Did you introduce a reference cycle?
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5347: UserWarning: Dumping first 1 objgraphs of leaked tensors rendered to png
Graph written to /tmp/objgraph-ztz642h3.dot (19 nodes)
Graph viewer (xdot) not found, generating a png instead
Image generated as /tmp/objgraph-ztz642h3.png
```

rendering of ` /tmp/objgraph-ztz642h3.png`:
<img width="1671" alt="image" src="https://github.com/user-attachments/assets/9098ff29-224c-4533-935b-83c210ac2e22">

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 d4l3k c-p-i-o

[ghstack-poisoned]
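The leak checker described in that commit could be sketched roughly as follows (a hypothetical sketch; the real `check_tensor_leak` lives in PyTorch's test utilities and differs in detail):

```
import gc
import os
import tempfile
import warnings

import torch

try:
    import objgraph  # optional: only needed to render the backref graph
except ImportError:
    objgraph = None


def check_tensor_leak_sketch(max_graphs=1):
    """Warn if any tensors are only kept alive by Python reference cycles."""
    gc.set_debug(gc.DEBUG_SAVEALL)  # keep collectable cycles in gc.garbage
    gc.collect()
    leaked = [o for o in gc.garbage if isinstance(o, torch.Tensor)]
    if leaked:
        warnings.warn(
            f"{len(leaked)} tensors were found in the garbage. "
            "Did you introduce a reference cycle?"
        )
        if objgraph is not None:
            for t in leaked[:max_graphs]:
                fd, png = tempfile.mkstemp(prefix="objgraph-", suffix=".png")
                os.close(fd)
                objgraph.show_backrefs([t], max_depth=10, filename=png)
    gc.set_debug(0)
    gc.garbage.clear()
```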
H-Huang added a commit that referenced this pull request Sep 25, 2024
H-Huang added a commit that referenced this pull request Sep 25, 2024
H-Huang added a commit that referenced this pull request Sep 25, 2024
BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this pull request Sep 25, 2024
Pull Request resolved: pytorch#136507
Approved by: https://github.com/awgu, https://github.com/kwen2501
pytorchmergebot pushed a commit that referenced this pull request Sep 26, 2024

Pull Request resolved: #136584
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
ghstack dependencies: #136507

Co-authored-by: Howard Huang <howardhuang@fb.com>
pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2024
@github-actions github-actions bot deleted the gh/wconstab/336/head branch October 25, 2024 02:09