[state_dict][11/N] Implement cpu_offload and full_state_dict for get_state_dict #112837
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112837
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 3794436 with merge base 31ded95.
FLAKY - The following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
- ``fsdp_state_dict_type``: if the model contains FSDP sharded submodules,
  what FSDP state_dict type should be used.
  The default value is SHARDED_STATE_DICT.
- ``full_state_dict``: if this is set to True, all the tensors in the
@fegin n00b question - Assuming AsyncCollectiveTensors are still returned without waiting?
I'll update the dtensor gathering to use the full_tensor() API in a follow-up PR, which will always return synchronously so we can make sure that the value in the state_dict is correct.
Still catching up on the underlying implementations but LGTM!
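To make the `full_tensor()` point above concrete, here is a minimal sketch (not part of this PR) of the difference between the two gathering paths. It assumes an initialized process group with one GPU per rank and uses the `torch.distributed._tensor` API; the mesh and tensor shapes are illustrative only.

```python
# Sketch only: assumes torchrun has started one process per GPU and
# dist.init_process_group("nccl") has already been called.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
dt = distribute_tensor(torch.randn(16, 16), mesh, placements=[Shard(0)])

# redistribute(...).to_local() may hand back an AsyncCollectiveTensor whose
# all-gather has not been waited on yet.
maybe_async = dt.redistribute(mesh, placements=[Replicate()]).to_local()

# full_tensor() performs the same gather but waits on the collective, so the
# result is a plain torch.Tensor that is safe to copy into a state_dict.
full = dt.full_tensor()
```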
elif not cpu_offload:
    with SimpleProfiler.profile("clone"):
-       value = value.detach.clone()
+       value = value.detach().clone()
lol. How come our tests haven't caught this?
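For what it's worth, the typo fails loudly rather than silently, so the branch was presumably just never exercised; a quick illustration (not taken from the PR's test suite):

```python
import torch

t = torch.randn(4)

# `t.detach` without parentheses is the bound method itself, not a tensor,
# so chaining `.clone()` onto it raises AttributeError instead of silently
# producing a wrong value.
try:
    t.detach.clone()
except AttributeError as e:
    print(e)  # e.g. 'builtin_function_or_method' object has no attribute 'clone'

# Corrected form: detach first, then clone the detached tensor.
copy = t.detach().clone()
```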
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[state_dict][11/N] Implement cpu_offload and full_state_dict for get_state_dict (pytorch#112837)
As title
Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
Pull Request resolved: pytorch#112837
Approved by: https://github.com/LucasLLC, https://github.com/wz337
ghstack dependencies: pytorch#112836, pytorch#112885
- ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
  ``full_state_dict`` is also true, then only the rank0 will get the
  state_dict and all other ranks will get empty state_dict.
Hi! This is a sensible default, but why wasn't rank0_only exposed in this options class in addition to cpu_offload? For setups with enough RAM, loading the full CPU model weights on all ranks could be desirable.
Thank you!
This is a subtle issue that not all users may be aware of, so we would like to avoid OOM issues as much as possible. And in many use cases, when users ask for a full_state_dict, only rank0 is going to save the state_dict. What's the use case for all the ranks to save the duplicated states?
For saving, I agree completely.
I was asking about loading. Doesn't this logic also apply during loading?
You might want to cpu-offload the loaded checkpoint, but still have all ranks load it into the model/optimizer.
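For context, here is a minimal usage sketch of the option combination the docstring above describes, assuming the `torch.distributed.checkpoint.state_dict` API from this stack; the FSDP-wrapped model, optimizer, and checkpoint path are placeholders for illustration only.

```python
# Sketch only: assumes torchrun launched one process per GPU and
# dist.init_process_group("nccl") has already been called; the tiny
# FSDP-wrapped Linear and the checkpoint path are placeholders.
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(16, 16).cuda())
optim = torch.optim.Adam(model.parameters())

options = StateDictOptions(full_state_dict=True, cpu_offload=True)

# With both flags set, tensors are gathered and offloaded to CPU; to avoid
# CPU OOM only rank 0 receives populated dicts, all other ranks get empty ones.
model_sd, optim_sd = get_state_dict(model, optimizers=optim, options=options)

if dist.get_rank() == 0:
    torch.save({"model": model_sd, "optim": optim_sd}, "checkpoint.pt")
```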
Stack from ghstack (oldest at bottom):
As title
Differential Revision: D50962991