[modeling_utils] use less cpu memory with sharded checkpoint loading #16844
This PR lowers the peak CPU memory usage during sharded checkpoint loading.

The following demonstration tells the full story. I'm using `/usr/bin/time -f %M` to report max RSS, i.e. the peak CPU memory used by the process. The demo uses T0, which is 42GB in fp32: https://huggingface.co/bigscience/T0/tree/main
With normal (non-sharded) loading, the program needs 87GB of CPU RAM (42GB x 2, plus a few GB for temporaries).

After this PR, sharded loading peaks at 1x the model size (42GB here) + the largest shard (10GB) + some temporaries = 53GB.
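The two memory budgets can be sanity-checked with quick arithmetic (figures taken from the demo above; the allowance for temporaries is a rough assumption, not a measured value):

```python
# Memory budgets for loading a 42GB fp32 model (bigscience/T0), in GB.
model_size = 42       # full fp32 model
largest_shard = 10    # biggest checkpoint shard
temps = 1             # rough allowance for temporaries (assumption)

# Normal loading: full checkpoint copy + instantiated model live at once.
normal_peak = 2 * model_size + temps

# Sharded loading after this PR: model + one shard at a time.
sharded_peak = model_size + largest_shard + temps

print(normal_peak, sharded_peak)  # roughly matches the 87GB / 53GB observed
```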
Before this PR, sharded loading was peaking an additional 15GB (about 1.5x the largest shard) higher than that.
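The core idea behind the saving can be sketched as follows. This is a pure-Python stand-in, not the actual `modeling_utils` code: the helper names, the toy shards, and the dict standing in for model parameters are all invented for illustration. The point is that only one shard's state dict is alive at a time, and it is freed before the next shard is read:

```python
import gc

def load_sharded(model_params, shard_files, load_shard):
    """Load checkpoint shards one at a time, freeing each before the next.

    model_params: dict to fill (stands in for the model's parameters)
    shard_files: ordered shard identifiers
    load_shard: callable returning one shard's state dict
    """
    for shard_file in shard_files:
        state_dict = load_shard(shard_file)  # only this shard is in memory
        model_params.update(state_dict)      # copy weights into the model
        del state_dict                       # drop the shard's dict early,
        gc.collect()                         # so peak RSS ~ model + 1 shard

# Toy demonstration: three "shards" of a six-parameter "model".
shards = {
    "shard-0": {"w0": 0.1, "w1": 0.2},
    "shard-1": {"w2": 0.3, "w3": 0.4},
    "shard-2": {"w4": 0.5, "w5": 0.6},
}
params = {}
load_sharded(params, list(shards), lambda f: dict(shards[f]))
print(sorted(params))  # all six parameters end up in the model
```

In the real code path the per-shard load is a `torch.load` of one checkpoint file; keeping a reference to each shard's state dict past its `load` is what produced the extra ~1.5x-shard peak.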
@sgugger