[DSD] Implement broadcast_from_rank0 option for model state_dict #125338
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125338
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d113109 with merge base 196a0b1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
if pg is None:
    pg = dist.distributed_c10d._get_default_group()
dist._broadcast_coalesced(pg, tensors, 500, 0)
After this _broadcast_coalesced call, does every rank have the full state dict in GPU memory?
I think no matter what, we want to interleave the broadcast with sharding. Otherwise, we either use 8x CPU memory (not mmap'd) across the host or 1x GPU memory per GPU.
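For illustration, a minimal sketch of the interleaved approach (hypothetical helper and argument names, not the PR's actual code): rank 0 broadcasts one full tensor at a time and every rank shards it into a DTensor right away, so at most one full tensor is resident in GPU memory at any point.

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import distribute_tensor

def _broadcast_and_shard(full_state_dict, dtensor_templates, device, pg=None):
    """Hypothetical sketch: broadcast one full tensor at a time from rank 0 and
    shard it immediately, so only one full tensor lives in GPU memory at once."""
    sharded = {}
    for key, template in dtensor_templates.items():
        if dist.get_rank(pg) == 0:
            full = full_state_dict[key].to(device)
        else:
            # Non-zero ranks only need the metadata (global shape/dtype).
            full = torch.empty(template.shape, dtype=template.dtype, device=device)
        dist.broadcast(full, src=0, group=pg)
        # Keep only this rank's shard; the full copy is freed right after.
        sharded[key] = distribute_tensor(full, template.device_mesh, template.placements)
        del full
    return sharded
```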
Good point, I forgot to address overlapping. The new version has already taken care of this issue. However, do you still recommend avoiding broadcast_coalesced and instead using regular broadcast?
if pg is None:
    pg = dist.distributed_c10d._get_default_group()
dist._broadcast_coalesced(pg, tensors, 500, 0)
Another detail: I think we should be careful about using dist._broadcast_coalesced today since it will call recordStream on the input tensors. To avoid this, I think we can either manually use the coalescing context from Python or just use dist.broadcast.
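For reference, a minimal sketch of the first alternative, the Python coalescing context, shown with all_reduce since, as noted below in the thread, broadcast was not supported by it at the time; the helper name and arguments are illustrative only.

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _coalescing_manager

def coalesced_allreduce(tensors, device, pg=None):
    """Sketch of "manually using the coalescing context from Python".

    Shown with all_reduce because the coalescing manager did not support
    broadcast at the time. Assumes an initialized NCCL process group and that
    `tensors` already live on `device`.
    """
    with _coalescing_manager(group=pg, device=device):
        for t in tensors:
            dist.all_reduce(t, group=pg)
    # The collectives queued inside the context are flushed as one coalesced
    # operation when the context exits.
```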
Our current coalescing context does not support broadcast. I can take a look to see how we can add the support. But even with recordStream, is that really going to cause a problem if we only do this when loading the checkpoint?
We had a debugging session last week in torchtune around this. For 70B, recordStream during checkpoint loading peaks memory at 78 GB; it's reduced to 22 GB without recordStream.
We have time to figure out the coalescing ctx, though.
The above comment about avoiding the full state_dict in GPU memory is critical and time sensitive.
We were doing broadcast + distribute_tensor together in one for loop in torchtune; maybe that also applies here.
Oh, right. This implementation may blow out the GPU memory. The new version already changes this. If dist._broadcast_coalesced is really an issue, I can change it to the regular broadcast, since the current context manager version does not support broadcast.
@fegin I missed this. Users need to set TORCH_NCCL_AVOID_RECORD_STREAMS when using dist._broadcast_coalesced to release GPU memory in a timely manner. Maybe use regular broadcast to avoid recordStream until we have a solution for coalescing?
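As a sketch of that workaround (assuming the variable is picked up when the NCCL process group is created, so it should be set before init_process_group or exported in the launch environment):

```python
import os

# Workaround described in the comment above: disable recordStream in NCCL
# collectives so broadcast buffers are released promptly. Set this before the
# process group is created (or export it before launching the job).
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
```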
@pytorchbot merge
Merge failed. Reason: This PR needs a label. To add a label, you can comment to pytorchbot. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…25339)
Summary: This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.
Pull Request resolved: #125339
Approved by: https://github.com/weifengpy
ghstack dependencies: #125708, #125338
else:
    assert device == value.device
assert device is not None
_broadcast_state_dict(state_dict, local_state_dict, device=device)
@fegin For a meta-init model, we will see an error because device=meta. Is there a way to resolve it? E.g., passing the device from StateDictOptions?
    dtype=tensor_info.dtype,
)
tensors.append(full_tensor)
With this approach, the entire state dict will have to fit into a single GPU? But that's not guaranteed to work. Shouldn't this broadcast and distribute each tensor individually?
There doesn't seem to be a way currently to efficiently load a full state dict checkpoint into a DTensor model with the new torch.distributed.checkpoint APIs.
No, that's not true; we don't need to fit the entire state_dict into a single GPU. The full_tensor is not kept after the broadcast, and we now broadcast in groups of 10 tensors.
I'm simply getting OOM trying to load Llama 70B the way it is currently. So what I did instead is load each parameter individually, which works and saves memory:
https://github.com/Lightning-AI/pytorch-lightning/blob/bd2843f6cbac769ad71b0b6404e411e9844ea9ce/src/lightning/fabric/strategies/model_parallel.py#L541-L553
Stack from ghstack (oldest at bottom):
Summary:
This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.
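A minimal usage sketch of the option this PR adds (assumptions: the checkpoint path and model are placeholders, rank 0 loads the full state_dict to CPU, and non-zero ranks pass an empty dict):

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

def load_full_checkpoint(model, ckpt_path="checkpoint.pt"):
    # Rank 0 holds the full state_dict on CPU; the other ranks pass an empty
    # dict and receive their shards via broadcast from rank 0.
    full_sd = torch.load(ckpt_path, map_location="cpu") if dist.get_rank() == 0 else {}
    set_model_state_dict(
        model,
        full_sd,
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )
```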
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC