[pipelining] fix py ref cycle in stage_backward #136507
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136507
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7277b30 (merge base: failed to retrieve merge base, please contact dev infra).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
TLDR; found forward activation tensors were being kept alive "forever" (or until GC ran), and tracked it down to a cycle involving `stage_backward.<locals>.extract_tensors_with_grads`.

More debug details: https://docs.google.com/document/d/1QPH1Lz0tnieIFPM2tyHrjVB-bjlnHuDgjx1p2am3cmE/edit?usp=sharing

In pdb:
```
gc.collect()
g = gc.garbage
g[-1]
[rank0]:(Pdb) [rank0]:<function stage_backward.<locals>.extract_tensors_with_grads at 0x7fee5c3392d0>
g[-2]
[rank0]:(Pdb) [rank0]:(<cell at 0x7fee7abbcf40: function object at 0x7fee5c3392d0>, <cell at 0x7fee7abbcf70: list object at 0x7fee7ab68940>, <cell at 0x7fee5c3210c0: list object at 0x7fee5e1d6340>)
g[-3]
[rank0]:(Pdb) [rank0]:[tensor([[[-4.1127e-06, -3.3826e-06,  2.6226e-06,  ...,  6.4969e-06,
[rank0]:          -4.4405e-06, -4.7684e-06],
```

ghstack-source-id: e671d63
Pull Request resolved: #136507
```
del stage_output_tensors
del output_grad_tensors
```
Interesting, is this related to how `stage_backward` is called, i.e. via a lambda function?
pytorch/torch/distributed/pipelining/stage.py, lines 490 to 495 in 3be1506:
```
if backward_type == "full":
    return lambda: stage_backward(
        bwd_kwargs["stage_output"],
        bwd_kwargs["output_grads"],
        bwd_kwargs["input_values"],
    )
```
No, this is unrelated.
I'm adding additional information to the doc and the PR description.
```
torch.autograd.backward(
    stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
)
del stage_output_tensors
```
After digging deeper into what the ref cycle was, I prefer to change this to `del extract_tensors_with_grads`, which breaks the cycle and ends up freeing the tensors as well as other pyobjs that would have remained in a cycle if I only `del stage_output_tensors` and `del output_grad_tensors`.
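To make that concrete, here is a minimal standalone sketch (toy names, not the PyTorch code) of why deleting the nested function's name is enough: it clears the closure cell that holds the self-reference, so plain refcounting frees the captured objects as soon as the outer function returns, without waiting for `gc.collect()`.

```python
import gc
import weakref


class Payload:  # stands in for an activation tensor
    pass


def outer():
    data = [Payload()]

    def helper(depth):
        if depth > 0:
            helper(depth - 1)  # self-reference -> a cycle through the closure cell
        return len(data)       # `data` is captured through the same cell mechanism

    helper(1)
    ref = weakref.ref(data[0])
    del helper  # clears the cell holding `helper`, breaking the cycle
    return ref


gc.disable()        # rule out the cyclic GC; only refcounting is at work
r = outer()
print(r() is None)  # True: everything was freed without a gc.collect()
gc.enable()
```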
I've updated the PR with a fix that is unfortunately a bit uglier, but arguably more correct. If folks have a better idea, let me know.
OK, one more update: I don't really need to pass `stage_output_tensors` and `output_grad_tensors` into `extract_tensors_with_grads`; I'll let those be captured in the cell. I just need to ensure the cell will die, and for that I just have to break the cycle caused by the self-reference. And that's a little less ugly.
Thanks for digging further! Wow, what a new learning!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```
# and to itself (extract_tensors_with_grads) since it makes a recursive call
# 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad
# fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore
extract_tensors_with_grads(
```
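For readers following along, a rough sketch of the shape of this fix (simplified, with assumed traversal logic; not the actual `stage_backward` implementation): the two lists are still captured through the closure cell, but the recursive call goes through an explicitly passed function reference, so the inner function never refers to its own name and no cycle is created.

```python
import torch


def stage_backward_sketch(stage_output, output_grads):
    stage_output_tensors = []
    output_grad_tensors = []

    # The lists are captured via the closure cell (that part is fine); only the
    # self-reference is removed by threading the function through `extract`.
    def extract_tensors_with_grads(output_val, grad_val, extract):
        if isinstance(output_val, torch.Tensor):
            if output_val.requires_grad and grad_val is not None:
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
        elif isinstance(output_val, (tuple, list)):
            for ov, gv in zip(output_val, grad_val):
                extract(ov, gv, extract)  # recurse via the argument, not the enclosing name

    extract_tensors_with_grads(stage_output, output_grads, extract_tensors_with_grads)
    torch.autograd.backward(stage_output_tensors, grad_tensors=output_grad_tensors)
```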
If you passed in `stage_output_tensors` and `output_grad_tensors` without the `extract_tensors_with_grads` ref, would that also fix the reference cycle? Since `extract_tensors_with_grads` would then no longer have to access variables in the outer scope, there should be no cycle. Or is my understanding off?
No, that would fix the tensor lifetime problem, which is almost as good. Technically, however, there would still be a reference cycle between `extract_tensors_with_grads` and the cell, and that cycle would live on until GC kicked in. So breaking the cycle is both cleaner (fewer vars passed in than passing in the two lists) and better (it not only prevents leaking the tensors, but also cleans up the pyobjs earlier).
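A small standalone illustration of that distinction (toy code, not the PyTorch implementation): even when the accumulator is passed in as an argument, a nested function that calls itself by name still forms a function/cell cycle that only the cyclic GC can reclaim.

```python
import gc
import weakref


def outer():
    def walk(vals, acc, depth=0):
        if depth < 1:
            walk(vals, acc, depth + 1)  # self-reference -> `walk` lives in a closure cell
        acc.extend(vals)

    acc = []
    walk([1, 2, 3], acc)
    return weakref.ref(walk), acc  # `acc` itself is not trapped in the cycle


gc.disable()
fn_ref, result = outer()
print(result)                # [1, 2, 3, 1, 2, 3]: the data came out fine
print(fn_ref() is not None)  # True: the function object lingers in the cycle
gc.enable()
gc.collect()                 # only the cyclic collector reclaims it
print(fn_ref() is None)      # True
```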
Following up: can we move `extract_tensors_with_grads` out of the scope of `stage_backward`?
Yes, but then we'd be forced to add the other two lists back as args to `extract_tensors_with_grads`. On the other hand, we could drop the `extract_tensors_with_grads` argument, since `extract_tensors_with_grads` would then be a global instead of a cellvar. If you guys want to make that change, it's fine too.
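A quick sketch of what that module-level variant might look like (assumed names and simplified traversal, not a concrete proposal for the actual file): with the helper at module scope there is no closure cell at all, but both accumulator lists have to be threaded through as arguments.

```python
import torch


def _extract_tensors_with_grads(output_val, grad_val, stage_output_tensors, output_grad_tensors):
    if isinstance(output_val, torch.Tensor):
        if output_val.requires_grad and grad_val is not None:
            stage_output_tensors.append(output_val)
            output_grad_tensors.append(grad_val)
    elif isinstance(output_val, (tuple, list)):
        for ov, gv in zip(output_val, grad_val):
            # the recursive call resolves the name as a module global, not a cellvar
            _extract_tensors_with_grads(ov, gv, stage_output_tensors, output_grad_tensors)


def stage_backward_sketch(stage_output, output_grads):
    stage_output_tensors, output_grad_tensors = [], []
    _extract_tensors_with_grads(
        stage_output, output_grads, stage_output_tensors, output_grad_tensors
    )
    torch.autograd.backward(stage_output_tensors, grad_tensors=output_grad_tensors)
```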
TLDR; found forward activation tensors were being kept alive "forever"
(or until GC ran), and tracked it down to a cycle involving
`stage_backward.<locals>.extract_tensors_with_grads`.
The reference cycle in question is below (constructed using `gc.get_referrers` after doing a `gc.collect` in gc debug mode).
The tensor is kept alive by
`[(<class 'cell'>, '0x7f7360234400')]`

That cell is part of a tuple of cell objects
`(<cell at 0x7f73602343d0: function object at 0x7f734fff0ee0>, <cell at 0x7f7360234400: list object at 0x7f734e4d9a80>, <cell at 0x7f73602a4190: list object at 0x7f734eff8b00>)`
which is kept alive by
`[(<class 'function'>, '0x7f734fff0ee0')]`

`<function stage_backward.<locals>.extract_tensors_with_grads at 0x7f734fff0ee0>`
is in turn kept alive by
`[(<class 'cell'>, '0x7f73602343d0')]`
Put into more plain terms,
```
def stage_backward(...):
    ...
    stage_output_tensors = []

    # A cell object will exist that contains the variables defined in stage_backward
    # and used by both stage_backward and its nested functions; in this case, the
    # cell contains 'stage_output_tensors'.
    # This function object will hold a reference to the 'cell' that contains any vars
    # from the parent scope not explicitly passed into the function as args.
    def extract_tensors_with_grads(...):
        ...
        # extract_tensors_with_grads refers to stage_output_tensors, so
        # stage_output_tensors is in the cell
        stage_output_tensors.append(output_val)
        ...
        # but extract_tensors_with_grads ALSO refers to itself, so
        # `extract_tensors_with_grads` will be in the cell
        extract_tensors_with_grads(...)
```
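The cell structure is easy to see on a small self-contained example (toy code, not the actual function): the inner function's free variables include its own name, and the corresponding closure cell points straight back at the function object, which is exactly the cycle described above.

```python
def outer_fn():
    items = []

    def inner(x):
        if isinstance(x, list):
            for v in x:
                inner(v)       # self-reference
        else:
            items.append(x)    # reference to the enclosing list

    inner([1, [2, 3]])
    return inner


f = outer_fn()
print(f.__code__.co_freevars)  # e.g. ('inner', 'items'): both names are captured
idx = f.__code__.co_freevars.index("inner")
print(f.__closure__[idx].cell_contents is f)  # True: the function sits in its own closure cell
```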
More debug details:
https://docs.google.com/document/d/1QPH1Lz0tnieIFPM2tyHrjVB-bjlnHuDgjx1p2am3cmE/edit?usp=sharing
In pdb:
```
gc.collect()
g = gc.garbage
g[-1]
[rank0]:(Pdb) [rank0]:<function stage_backward.<locals>.extract_tensors_with_grads at 0x7fee5c3392d0>
g[-2]
[rank0]:(Pdb) [rank0]:(<cell at 0x7fee7abbcf40: function object at 0x7fee5c3392d0>, <cell at 0x7fee7abbcf70: list object at 0x7fee7ab68940>, <cell at 0x7fee5c3210c0: list object at 0x7fee5e1d6340>)
g[-3]
[rank0]:(Pdb) [rank0]:[tensor([[[-4.1127e-06, -3.3826e-06,  2.6226e-06,  ...,  6.4969e-06,
[rank0]:          -4.4405e-06, -4.7684e-06],
```
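Roughly how a referrer chain like the one above can be reconstructed from a session like this (a sketch of the approach, not the exact commands used for this PR): collect with `DEBUG_SAVEALL` so cyclic garbage is parked in `gc.garbage` instead of being freed, then walk `gc.get_referrers()` upward from each object of interest.

```python
import gc

gc.set_debug(gc.DEBUG_SAVEALL)  # collected cycles land in gc.garbage instead of being freed
gc.collect()

for obj in gc.garbage:
    # The current frame also shows up as a referrer (it holds `obj`), so filter it out
    # along with the gc.garbage list itself.
    referrers = [
        (type(r).__name__, hex(id(r)))
        for r in gc.get_referrers(obj)
        if r is not gc.garbage and type(r).__name__ != "frame"
    ]
    print(type(obj).__name__, hex(id(obj)), "is kept alive by", referrers)
```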
Pull Request resolved: pytorch#136507
Approved by: https://github.com/awgu, https://github.com/kwen2501
Fix two more leaks of the same variety as #136507 (see that PR desc and attached gdoc for debug details). This time, also add a test-time check that helped to discover new leaks and ensure we won't accidentally regress.

Adds a `check_tensor_leak` util which internally asserts no tensors are being kept alive by other objects involved in py ref cycles. Uses objgraph for a nice debug utility when a leak is found. Credit to @H-Huang for pointing out objgraph and helping debug the `param_group["intermediates"]` leak.

I manually confirmed that all 3 of the leaks identified/fixed so far are caught by the unit test and checker.

Sample output, if I re-introduce a leak by commenting out `del param_group["intermediates"]` in _backward.py, and run `python test/distributed/pipelining/test_schedule_multiproc.py -k test_schedule_with_native_zero_bubble`:
```
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5341: UserWarning: 34 tensors were found in the garbage. Did you introduce a reference cycle?
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5347: UserWarning: Dumping first 1 objgraphs of leaked tensors rendered to png
Graph written to /tmp/objgraph-ztz642h3.dot (19 nodes)
Graph viewer (xdot) not found, generating a png instead
Image generated as /tmp/objgraph-ztz642h3.png
```

Rendering of `/tmp/objgraph-ztz642h3.png`:
<img width="1671" alt="image" src="https://github.com/user-attachments/assets/9098ff29-224c-4533-935b-83c210ac2e22">

Pull Request resolved: #136584
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
ghstack dependencies: #136507

Co-authored-by: Howard Huang <howardhuang@fb.com>
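For illustration, a rough sketch of the kind of test-time check described here (assumed shape and names; not the actual `check_tensor_leak` util in `common_utils.py`): run a GC pass with `DEBUG_SAVEALL`, warn if any tensors turn up in `gc.garbage`, and optionally dump an objgraph back-reference graph for the first few leaked tensors.

```python
import gc
import warnings

import torch


def check_tensor_leak_sketch(max_graphs=1):
    gc.set_debug(gc.DEBUG_SAVEALL)  # park collected cycles in gc.garbage for inspection
    gc.collect()
    leaked = [o for o in gc.garbage if isinstance(o, torch.Tensor)]
    if leaked:
        warnings.warn(
            f"{len(leaked)} tensors were found in the garbage. "
            "Did you introduce a reference cycle?"
        )
        try:
            import objgraph  # optional dependency, only needed for the debug dump

            for t in leaked[:max_graphs]:
                # writes a .dot (and a .png if graphviz is available) showing what refers to `t`
                objgraph.show_backrefs([t], max_depth=8)
        except ImportError:
            warnings.warn("objgraph not installed; skipping backref graph dump")
    gc.set_debug(0)
    gc.garbage.clear()
```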
Pull Request resolved: #136678
Approved by: https://github.com/wconstab, https://github.com/kwen2501
ghstack dependencies: #136507, #136584
Stack from ghstack (oldest at bottom):
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o