[state_dict][11/N] Implement cpu_offload and full_state_dict for get_state_dict by fegin · Pull Request #112837 · pytorch/pytorch · GitHub

Conversation

@fegin fegin commented Nov 3, 2023

pytorch-bot bot commented Nov 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112837

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 3794436 with merge base 31ded95:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

fegin added a commit that referenced this pull request Nov 3, 2023
…state_dict

As title

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)

ghstack-source-id: 206371083
Pull Request resolved: #112837
@fegin fegin marked this pull request as draft November 3, 2023 08:05
@fegin fegin added the ciflow/trunk and ciflow/periodic labels Nov 3, 2023
…ct for get_state_dict"

As title

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)

[ghstack-poisoned]
fegin added a commit that referenced this pull request Nov 3, 2023
…state_dict

Pull Request resolved: #112837

As title
ghstack-source-id: 206465442
@exported-using-ghexport

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Nov 7, 2023
fegin added a commit that referenced this pull request Nov 7, 2023
…state_dict

Pull Request resolved: #112837

As title
ghstack-source-id: 206762418
@exported-using-ghexport

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
fegin added a commit that referenced this pull request Nov 7, 2023
…state_dict

Pull Request resolved: #112837

As title
ghstack-source-id: 206788209
@exported-using-ghexport

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
@fegin fegin marked this pull request as ready for review November 8, 2023 01:31
@fegin fegin requested a review from LucasLLC as a code owner November 8, 2023 01:31
fegin added a commit that referenced this pull request Nov 8, 2023
…state_dict

Pull Request resolved: #112837

As title
ghstack-source-id: 206890572
@exported-using-ghexport

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
- ``fsdp_state_dict_type``: if the model contains FSDP sharded submodules,
  what FSDP state_dict type should be used.
  The default value is SHARDED_STATE_DICT.
- ``full_state_dict``: if this is set to True, all the tensors in the

@fegin n00b question - Assuming AsyncCollectiveTensors are still returned without waiting?

I'll update the dtensor gathering to use the full_tensor() API in a follow-up PR, which will always return synchronously so we can make sure that the value in the state_dict is correct.

@LucasLLC LucasLLC left a comment

Still catching up on the underlying implementations but LGTM!

     elif not cpu_offload:
         with SimpleProfiler.profile("clone"):
-            value = value.detach.clone()
+            value = value.detach().clone()

lol. How come our tests haven't caught this?
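
For context on the fix above: Python resolves `value.detach.clone` only when the line actually runs, so the missing parentheses stay silent until that branch is executed. The sketch below illustrates this with a hypothetical `FakeTensor` stand-in and a simplified `copy_value` helper (both are assumptions for illustration, not PyTorch code). Whether the tests skipped this branch is not confirmed in the thread.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; it is not torch.Tensor."""

    def __init__(self, data):
        self.data = list(data)

    def detach(self):
        return FakeTensor(self.data)

    def clone(self):
        return FakeTensor(self.data)


def copy_value(value, cpu_offload, use_typo=False):
    """Simplified version of the branch in the diff (profiler omitted)."""
    if cpu_offload:
        # Assumption for this sketch: the offload path never reaches the clone line.
        return value
    if use_typo:
        # `value.detach` is a bound method object, which has no `clone` attribute.
        return value.detach.clone()
    # Corrected form: call detach() first, then clone the result.
    return value.detach().clone()


def raises_attribute_error(fn):
    """Return True if calling fn() raises AttributeError."""
    try:
        fn()
    except AttributeError:
        return True
    return False
```

In this sketch, calling `copy_value(t, cpu_offload=True, use_typo=True)` succeeds because the faulty line is never reached, while the same call with `cpu_offload=False` raises `AttributeError`.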

fegin added a commit that referenced this pull request Nov 9, 2023
…state_dict

Pull Request resolved: #112837

As title
ghstack-source-id: 207082807
@exported-using-ghexport

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
@fegin fegin removed the release notes: distributed (fsdp) release notes category label Nov 10, 2023
fegin added a commit that referenced this pull request Nov 10, 2023
…state_dict

Pull Request resolved: #112837

As title
ghstack-source-id: 207150105
@exported-using-ghexport

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
@facebook-github-bot

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
@facebook-github-bot facebook-github-bot deleted the gh/fegin/177/head branch November 16, 2023 15:30
Comment on lines +89 to +91
- ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
  ``full_state_dict`` is also true, then only the rank0 will get the
  state_dict and all other ranks will get empty state_dict.

Hi! This is a sensible default, but why wasn't rank0_only exposed in this options class in addition to cpu_offload? For setups with enough RAM, loading the full CPU model weights on all ranks could be desirable.

Thank you!

Contributor Author

This is a subtle issue that not all users may be aware of, so we would like to avoid OOM issues as much as possible. In many use cases, when users request a full_state_dict, only rank0 is going to save the state_dict. What's the use case for all the ranks to save the duplicated states?

@carmocca carmocca Feb 20, 2024

For saving, I agree completely.

I was asking about loading. Doesn't this logic also apply during loading?
You might want to cpu-offload the loaded checkpoint, but still have all ranks load it into the model/optimizer.
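
The rule debated in this thread can be sketched in plain Python. The helper `select_state_dict` and its `gathered` and `rank` parameters are hypothetical and only mirror the docstring quoted above; this is not the PyTorch implementation, and it assumes the gathering step has already produced a complete dictionary.

```python
def select_state_dict(gathered, rank, full_state_dict, cpu_offload):
    """Hypothetical helper illustrating the documented rule; not PyTorch code.

    `gathered` stands for the already-gathered state_dict on this rank.
    """
    if full_state_dict and cpu_offload and rank != 0:
        # Per the docstring: with both options set, only rank0 keeps the
        # state_dict, and every other rank gets an empty one to avoid CPU OOM.
        return {}
    # Otherwise each rank keeps what it has (shallow copy for safety).
    return dict(gathered)
```

Under these assumptions, `select_state_dict(sd, rank=1, full_state_dict=True, cpu_offload=True)` returns `{}`, which is the behavior the question about loading on all ranks is concerned with.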

Labels

ciflow/periodic, ciflow/trunk, Merged

6 participants