[Pipelining] fix extra memory usage in zero bubble #138119
I'm wondering what we can do as far as unit testing for memory. What if we put schedules aside and focus on unit-level testing? We have a stage with a simple model, we know roughly how much activation memory is needed, and we set up hooks on the CUDA caching allocator to check that we allocate and free memory as expected after calling stage._forward / dI / dW. And we do this with and without AC?
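One possible shape for such a unit-level check, as a rough sketch only: instead of hooking the CUDA caching allocator, it counts the bytes autograd saves for backward during a forward pass, which runs on CPU. The helper name `saved_activation_bytes` and the toy `Linear` model are assumptions made for illustration, not part of this PR, and the count is only a proxy for activation memory (a tensor saved by two nodes is counted twice).

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def saved_activation_bytes(fn):
    """Run fn() and return (bytes saved for backward, fn's result)."""
    total = 0

    def pack(t):
        # Called once for every tensor autograd saves for backward.
        nonlocal total
        total += t.numel() * t.element_size()
        return t

    with saved_tensors_hooks(pack, lambda t: t):
        out = fn()
    return total, out


model = torch.nn.Linear(16, 32)  # stand-in for a pipeline stage's module
x = torch.randn(8, 16)

# With grad enabled, the forward pass saves tensors for the later backward.
bytes_with_grad, _ = saved_activation_bytes(lambda: model(x))

# Without grad, no graph is built, so nothing is saved.
with torch.no_grad():
    bytes_no_grad, _ = saved_activation_bytes(lambda: model(x))
```

A test along these lines could compare the measured number against a rough expected activation size for the stage, with and without AC.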
I think you may not even need to hook into the caching allocator. Here is how we did unit testing for memory for FSDP2:
|
I would probably benefit from a graph diagram showing what is changed by setting outputs to detach=True. Taking a stab: I think 'outputs' is the place we start the .grad computation from when we first do the dInput calculation. Then, after dInput, we have to save some intermediate nodes, but we shouldn't need to save the outputs, and we also shouldn't need to save any other intermediate nodes that aren't needed for dW. Then, when we do the dW calculation, we do not start from the outputs; we start from the saved intermediates. Is that right? And does detaching the outputs free just the output node, or does it also free other parts of the autograd graph that are not needed for dW? Also, how does the loss fit in? You didn't change 'loss' in this PR, but I think you mentioned that we might need to clear the loss tensor after printing it. Can you explain more about that?
Full debugging details in here: https://docs.google.com/document/d/1Pe_E0KWAfsJ6MCvKZ5aR28rTXX-rYLg13XxwXd6AALw/edit?usp=sharing

In zero bubble, we have two methods, `stage_backward_input` and `stage_backward_weight`. During `stage_backward_input` we compute the gradients with respect to the stage inputs, starting from the stage outputs, and we retain the autograd graph (unlike 1F1B, where `retain_graph=False`). The output / loss was still being retained across the next schedule step() because we return the loss to the user and use the output in the next step. To allow autograd to free the variables in the graph, we need to detach the output/loss once autograd no longer needs it.

Pre-fix: <img width="1021" alt="image" src="https://github.com/user-attachments/assets/6c8bf469-32b1-4dac-85ff-b97991f9f0e3">

Post-fix: <img width="1039" alt="image" src="https://github.com/user-attachments/assets/a1875038-e80b-4dd4-84f2-38727d7792dc">

without AC (7B model on titan): 10% memory improvement

with AC (7B model on titan): 50% memory improvement

cc XilunWu awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o
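To make the split backward concrete, here is a minimal sketch using only public `torch.autograd.grad` calls. It is not the code of `stage_backward_input` / `stage_backward_weight`: for simplicity the weight pass here starts again from the output, and the toy model, shapes, and variable names are assumptions for illustration. The in-place `detach_()` at the end shows the idea behind the fix, dropping the output's reference to the graph once autograd no longer needs it.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3)
)
params = list(model.parameters())

x = torch.randn(2, 4, requires_grad=True)
out = model(x)
grad_out = torch.ones_like(out)  # stand-in for grads received from the next stage

# "dI" pass: gradient w.r.t. the input only; keep the graph for the weight pass.
(dx,) = torch.autograd.grad(out, x, grad_out, retain_graph=True)

# "dW" pass: gradients w.r.t. the parameters; the graph buffers are freed here.
dws = torch.autograd.grad(out, params, grad_out)

# The output no longer needs autograd, so drop its reference to the graph.
out.detach_()

# Reference: a single full backward pass on a fresh copy of the input.
x_ref = x.detach().clone().requires_grad_(True)
model(x_ref).backward(grad_out)
ref_dws = [p.grad for p in params]
```

The split passes should produce the same gradients as the single full backward, which is what makes it possible to schedule the dW computation later.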
Yep! This is all correct.
It frees the other parts of the autograd graph as well
For the last stage, |
This Q&A is great! |
This fix is similar to the one in #138119, except that it covers an edge case for the last stage. For the last stage we perform backward on the `loss`, which we detached in the previous PR. However, we also keep the `stage_outputs` alive because we return all the output chunks in `merge_output_chunks()` after the step is over. This still keeps the autograd graph alive, so detaching these tensors frees the memory earlier.

pre-fix: <img width="1780" alt="image" src="https://github.com/user-attachments/assets/bb78bde7-fd5c-4eba-bfc9-f0359e20bbab">

post-fix: <img width="1788" alt="image" src="https://github.com/user-attachments/assets/a26102d9-9db2-4fc8-946c-336b8430657c">

Pull Request resolved: #138504 Approved by: https://github.com/wconstab ghstack dependencies: #138119
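The last-stage case can be sketched roughly as follows. This is an illustration under assumptions, not the pipelining code: the toy model, the microbatch loop, and the use of `torch.cat` as a stand-in for `merge_output_chunks()` are all hypothetical. It only shows that outputs kept for the user still reference the autograd graph until they are detached in place; it does not measure memory. `Tensor._is_view()` is a private method, used here because views cannot be detached in place.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
microbatches = torch.randn(6, 4).chunk(3)

output_chunks = []  # kept until the end of the step to return to the user
losses = []
for mb in microbatches:
    x = mb.detach().requires_grad_(True)
    out = model(x)
    loss = out.pow(2).mean()
    # Input-gradient pass that keeps the graph, as in the zero bubble dI step.
    (dx,) = torch.autograd.grad(loss, x, retain_graph=True)
    output_chunks.append(out)
    losses.append(loss.detach())  # return a loss value that holds no graph

# Each stored output still references the autograd graph via grad_fn.
attached_before = [t.grad_fn is not None for t in output_chunks]

# Detach in place so the held outputs stop pinning the graph.
for t in output_chunks:
    if not t._is_view():  # views cannot be detached in place
        t.detach_()

merged = torch.cat(output_chunks)  # stand-in for merging the output chunks
```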
Addressing the comments in previous PRs to update the variable names and add additional code comments Pull Request resolved: #138735 Approved by: https://github.com/wconstab ghstack dependencies: #138119, #138504
Pull Request resolved: #138720 Approved by: https://github.com/wconstab ghstack dependencies: #138119, #138504, #138735
Stack from ghstack (oldest at bottom):
Full debugging details in here: https://docs.google.com/document/d/1Pe_E0KWAfsJ6MCvKZ5aR28rTXX-rYLg13XxwXd6AALw/edit?usp=sharing
In zero bubble, we have two methods, `stage_backward_input` and `stage_backward_weight`. During `stage_backward_input` we compute the gradients with respect to the stage inputs, starting from the stage outputs, and we retain the autograd graph (unlike 1F1B, where `retain_graph=False`). The output / loss was still being retained across the next schedule step() because we return the loss to the user and use the output in the next step. To allow autograd to free the variables in the graph, we need to detach the output/loss once autograd no longer needs it.

Pre-fix:

Post-fix:

without AC (7B model on titan):
10% memory improvement
with AC (7B model on titan)
50% memory improvement
cc @XilunWu @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o