[Pipelining] fix extra memory usage in zero bubble by H-Huang · Pull Request #138119 · pytorch/pytorch · GitHub

Conversation

@H-Huang
Member

@H-Huang H-Huang commented Oct 16, 2024

Stack from ghstack (oldest at bottom):

Full debugging details in here: https://docs.google.com/document/d/1Pe_E0KWAfsJ6MCvKZ5aR28rTXX-rYLg13XxwXd6AALw/edit?usp=sharing

In zero bubble, we have two methods, `stage_backward_input` and `stage_backward_weight`. During `stage_backward_input` we compute the gradients of the stage outputs with respect to the stage inputs and retain the autograd graph (unlike 1F1B, where `retain_graph=False`). However, the output/loss was still being kept alive across the next schedule `step()`, because we return the loss to the user and reuse the output in the next step. To allow autograd to free the variables in the graph, we need to detach the output/loss once we no longer need it for autograd.
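Here is a minimal, self-contained sketch of the idea (illustrative only, not the PR's actual code; the tensors are stand-ins for the stage's weights, input, and output/loss):

```python
import torch

w = torch.randn(4, 4, requires_grad=True)   # stands in for stage weights
x = torch.randn(2, 4, requires_grad=True)   # stands in for the stage input
out = (x @ w).sum()                         # stands in for the stage output / loss

# dI: gradient w.r.t. the input only, keeping the graph alive so the weight
# gradient can still be computed later (this mirrors retain_graph=True in
# stage_backward_input).
(dI,) = torch.autograd.grad(out, x, retain_graph=True)

# dW: computed later, in a separate step of the schedule.
(dW,) = torch.autograd.grad(out, w)

# Even after both backward passes, `out` still holds a grad_fn. Because the
# schedule returns the loss to the user and reuses the output in the next
# step(), that reference kept (part of) the autograd graph alive. Detaching
# in place severs the link so the graph can be freed promptly.
out.detach_()
```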

Pre-fix:
[memory snapshot screenshot]

Post-fix:
[memory snapshot screenshot]

Without AC (activation checkpointing), 7B model on titan: 10% memory improvement

With AC, 7B model on titan: 50% memory improvement

cc @XilunWu @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

H-Huang added a commit that referenced this pull request Oct 16, 2024
ghstack-source-id: 5bca83f
Pull Request resolved: #138119
@pytorch-bot

pytorch-bot bot commented Oct 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138119

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 3fc6898 with merge base deaf041:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed label Oct 16, 2024
@H-Huang H-Huang added the release notes: distributed (pipeline) and module: pipelining labels Oct 16, 2024
@wconstab
Contributor

I'm wondering what we can do as far as unit testing for memory.

What if we put schedules aside and really focus on unit-level testing? We have a stage with a simple model, we know roughly how much activation memory is needed, and we set up hooks on the CUDA caching allocator to see whether we allocate/free memory as expected after calling `stage._forward` / dI / dW. And we do this with/without AC?
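Something along these lines, maybe (a rough sketch; `forward_one_chunk`, `backward_input`, and `backward_weight` are placeholder names here, not the real stage API):

```python
import torch

def check_stage_activation_memory(stage, module, microbatch, slack_bytes=1 << 20):
    """Unit-level check: forward allocates activation memory, and the
    dI + dW backward passes should free it, leaving only parameter grads."""
    torch.cuda.synchronize()
    baseline = torch.cuda.memory_allocated()

    out = stage.forward_one_chunk(microbatch)        # placeholder call
    activations = torch.cuda.memory_allocated() - baseline
    assert activations > 0, "forward should allocate activation memory"

    stage.backward_input(out)                        # dI (placeholder call)
    stage.backward_weight()                          # dW (placeholder call)
    torch.cuda.synchronize()

    # Only the parameter gradients (plus a little allocator slack) should
    # remain allocated once both backward passes are done.
    grad_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    leaked = torch.cuda.memory_allocated() - baseline - grad_bytes
    assert leaked <= slack_bytes, f"leaked ~{leaked} bytes of activations"
```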

@awgu
Collaborator

awgu commented Oct 16, 2024

I think you may not even need to hook into the caching allocator. Here is how we did unit testing for memory for FSDP2:

def _test_fully_shard_training_memory(
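In that style, the assertion boils down to comparing `torch.cuda.memory_allocated()` against an expected byte count, with no allocator hooks needed. A simplified sketch (not the actual FSDP2 test):

```python
import torch
import torch.nn as nn

def expected_train_step_bytes(model: nn.Module) -> int:
    # Parameters + gradients; a fuller test would also budget for optimizer
    # state and a rough activation estimate for the chosen batch size.
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return 2 * param_bytes

model = nn.Linear(1024, 1024, device="cuda")
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad(set_to_none=False)
torch.cuda.synchronize()

# Assert against the budget (with some slack for allocator rounding).
assert torch.cuda.memory_allocated() <= 1.1 * expected_train_step_bytes(model)
```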

@wconstab
Contributor

To allow autograd to free the variables in the graph, we need to detach the output/loss once we no longer need it for autograd.

I would probably benefit from a graph diagram showing what is changed by setting outputs to detach=True.

Taking a stab: I think 'outputs' is the place we start the .grad computation from when we first do the dInput calculation. Then after dInput, we have to save some intermediate nodes, but we shouldn't need to save the outputs, and we also shouldn't need to save any other intermediate nodes that aren't needed for dW. Then when we do the dW calculation, we do not start from the outputs; we start from the saved intermediates.

Is that right? And does detaching the outputs free just the output node, or does it also free other parts of the autograd graph that are not needed for dW?

Also, how does loss fit in? You didn't change 'loss' in this PR, but I think you mentioned that we might need to clear the loss tensor after printing it. Can you explain more about that?

@H-Huang H-Huang changed the title [PP] fix extra memory usage in zero bubble [Ppipelining] fix extra memory usage in zero bubble Oct 21, 2024
@H-Huang H-Huang changed the title [Ppipelining] fix extra memory usage in zero bubble [Pipelining] fix extra memory usage in zero bubble Oct 21, 2024
H-Huang added a commit that referenced this pull request Oct 21, 2024
ghstack-source-id: a726d83
Pull Request resolved: #138119
@H-Huang
Member Author

H-Huang commented Oct 21, 2024

Then after dInput, we have to save some intermediate nodes, but we shouldn't need to save the outputs, and we also shouldn't need to save any other intermediate nodes that aren't needed for dW. Then when we do the dW calculation, we do not start from the outputs; we start from the saved intermediates.

Yep! This is all correct.

Does detaching the outputs free just the output node, or does it also free other parts of the autograd graph that are not needed for dW?

It frees the other parts of the autograd graph as well.

Also, how does loss fit in?

For the last stage, the loss is passed in instead of the stage outputs as the tensor we call `autograd.grad()` on, so this change will also detach the loss tensor, since we don't need it anymore after the first backward. I was investigating why the last stage still uses more memory, and I think I found the issue in the follow-up PR.
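Roughly, for the last stage (a hedged sketch with made-up names, assuming, as zero bubble does, that the weight-gradient edges were already captured during this pass):

```python
import torch

def last_stage_backward_input(loss, stage_inputs):
    # On the last stage, the root of the dI pass is the loss itself (there
    # are no received output grads). The graph is retained so the dW pass
    # can later run from the captured intermediate nodes.
    dinputs = torch.autograd.grad(loss, stage_inputs, retain_graph=True)
    # The scalar loss value is still reported to the user, but its grad_fn
    # is no longer needed here; detaching in place lets autograd free the
    # graph instead of holding it until the next step().
    loss.detach_()
    return dinputs
```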

@H-Huang H-Huang marked this pull request as ready for review October 21, 2024 21:13
@H-Huang H-Huang requested review from kwen2501 and wconstab October 21, 2024 21:13
@kwen2501
Contributor

Then after dInput, we have to save some intermediate nodes, but we shouldn't need to save the outputs, and we also shouldn't need to save any other intermediate nodes that aren't needed for dW. Then when we do the dW calculation, we do not start from the outputs; we start from the saved intermediates.

Yep! This is all correct.

Does detaching the outputs free just the output node, or does it also free other parts of the autograd graph that are not needed for dW?

It frees the other parts of the autograd graph as well.

Also, how does loss fit in?

For the last stage, the loss is passed in instead of the stage outputs as the tensor we call `autograd.grad()` on, so this change will also detach the loss tensor, since we don't need it anymore after the first backward. I was investigating why the last stage still uses more memory, and I think I found the issue in the follow-up PR.

This Q&A is great!
I wonder if we could put it as a comment in the code? Thanks!

pytorchmergebot pushed a commit that referenced this pull request Oct 24, 2024
This fix is similar to the one in #138119, except it addresses an edge case for the last stage. For the last stage we perform backward on the `loss`, which we detached in the previous PR. However, we also keep the `stage_outputs` alive, because we return all the output chunks in `merge_output_chunks()` after the step is over. This also keeps the autograd graph alive, so detaching these tensors frees the memory earlier.

pre-fix:
<img width="1780" alt="image" src="https://github.com/user-attachments/assets/bb78bde7-fd5c-4eba-bfc9-f0359e20bbab">

post-fix:
<img width="1788" alt="image" src="https://github.com/user-attachments/assets/a26102d9-9db2-4fc8-946c-336b8430657c">

Pull Request resolved: #138504
Approved by: https://github.com/wconstab
ghstack dependencies: #138119
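A hedged sketch of that idea (names and structure are assumptions, not the PR's actual code):

```python
import torch

def detach_output_chunks(output_chunks):
    # The chunk values are still needed for merge_output_chunks(), but their
    # grad_fns are not once each microbatch's backward is done. Detaching in
    # place keeps the data while letting the autograd graph be freed earlier.
    for chunk in output_chunks:
        tensors = chunk if isinstance(chunk, (list, tuple)) else (chunk,)
        for t in tensors:
            if isinstance(t, torch.Tensor) and t.grad_fn is not None:
                t.detach_()
```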
pytorchmergebot pushed a commit that referenced this pull request Oct 25, 2024
Addressing the comments in previous PRs to update the variable names and add additional code comments

Pull Request resolved: #138735
Approved by: https://github.com/wconstab
ghstack dependencies: #138119, #138504
pytorchmergebot pushed a commit that referenced this pull request Oct 25, 2024

Labels

Merged
module: pipelining (Pipeline Parallelism)
oncall: distributed (distributed oncall triage queue)
release notes: distributed (pipeline)


5 participants