[DSD] Implement broadcast_from_rank0 option for model state_dict by fegin · Pull Request #125338 · pytorch/pytorch · GitHub

Conversation

fegin (Contributor) commented May 1, 2024

[ghstack-poisoned]
pytorch-bot bot commented May 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125338

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d113109 with merge base 196a0b1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


if pg is None:
    pg = dist.distributed_c10d._get_default_group()
dist._broadcast_coalesced(pg, tensors, 500, 0)
Collaborator

After this _broadcast_coalesced call, does every rank have the full state dict in GPU memory?

I think no matter what, we want to interleave the broadcast with sharding. Otherwise, we either use 8x CPU memory (not mmap'd) across the host or 1x GPU memory per GPU.

fegin (Contributor Author) commented May 2, 2024

Good point, I forgot to address overlapping. The new version has already taken care of this issue. However, do you still recommend avoiding broadcast_coalesced and instead using regular broadcast?

if pg is None:
    pg = dist.distributed_c10d._get_default_group()
dist._broadcast_coalesced(pg, tensors, 500, 0)

Collaborator

Another detail: I think we should be careful to use dist._broadcast_coalesced today since it will call recordStream on the input tensors. I think to avoid this, we can either manually use the coalescing context from Python or just use dist.broadcast.
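
For illustration, a minimal sketch of the plain dist.broadcast fallback suggested here (the helper name is made up; this is not the PR's implementation):

```python
import torch.distributed as dist

def _broadcast_tensors_plain(tensors, pg=None, src=0):
    # Per the discussion above, broadcasting each tensor with the regular
    # dist.broadcast sidesteps the recordStream behavior of
    # dist._broadcast_coalesced, so the caching allocator can reuse the
    # memory as soon as the caller drops its references.
    if pg is None:
        pg = dist.distributed_c10d._get_default_group()
    for t in tensors:
        dist.broadcast(t, src=src, group=pg)
```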

fegin (Contributor Author)

Our current coalescing context does not support broadcast. I can take a look to see how we can add that support. But even with recordStream, is that really going to cause a problem if we only do this when loading the checkpoint?

Contributor

We had a debugging session last week in torchtune around this. For 70B, recordStream during checkpoint loading peaks memory at 78GB; it drops to 22GB without recordStream.

Contributor

We have time to figure out the coalescing ctx, though.

The above comment about avoiding the full state_dict in GPU memory is critical and time-sensitive.

We were doing broadcast + distribute_tensor together in one for loop in torchtune; maybe that also applies here (a rough sketch follows below).
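
A rough sketch of that torchtune-style loop, assuming rank 0 holds the full state_dict and every rank knows each tensor's shape and dtype; device_mesh and placements describe the target sharding (all names are illustrative, not the PR's API):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import distribute_tensor

def broadcast_and_shard(full_sd, metadata, device_mesh, placements, device):
    # metadata: {name: (shape, dtype)} available on every rank.
    sharded = {}
    for name, (shape, dtype) in metadata.items():
        if dist.get_rank() == 0:
            full = full_sd[name].to(device)  # rank 0 has the real value
        else:
            full = torch.empty(shape, dtype=dtype, device=device)  # receive buffer
        dist.broadcast(full, src=0)
        # Shard right away so only one full tensor is alive on GPU at a time.
        sharded[name] = distribute_tensor(full, device_mesh, placements)
        del full
    return sharded
```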

fegin (Contributor Author) commented May 2, 2024

Oh, right. This implementation may blow out the GPU memory. The new version already changes this. If dist._broadcast_coalesced is really an issue, I can switch to the regular broadcast, since the current context manager version does not support broadcast.

Contributor

@fegin I missed this. Users need to set TORCH_NCCL_AVOID_RECORD_STREAMS when using dist._broadcast_coalesced to release GPU memory in a timely manner. Maybe use the regular broadcast to avoid recordStream until we have a solution for coalescing?
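
For reference, a quick sketch of applying that toggle from Python; it can equally be set in the launcher environment, and it should be in place before the NCCL process group is created:

```python
import os
import torch.distributed as dist

# Set before the NCCL process group is constructed so ProcessGroupNCCL picks it up.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
dist.init_process_group("nccl")
```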

fegin added 2 commits May 1, 2024 23:19
[ghstack-poisoned]
[ghstack-poisoned]
fegin (Contributor Author) commented May 8, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: Raised by workflow job

fegin (Contributor Author) commented May 8, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request May 8, 2024
…25339)

Summary:
This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.

Pull Request resolved: #125339
Approved by: https://github.com/weifengpy
ghstack dependencies: #125708, #125338
else:
    assert device == value.device
assert device is not None
_broadcast_state_dict(state_dict, local_state_dict, device=device)
Contributor

@fegin for a meta-init model, we will see an error because device=meta. Is there a way to resolve it? E.g., passing a device via StateDictOptions?
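
One possible workaround, sketched here purely for illustration (this is not something the PR implements): materialize the meta-device model onto a real device with nn.Module.to_empty before loading, so the broadcast has a concrete target:

```python
import torch
import torch.nn as nn

def materialize_meta_model(model: nn.Module) -> nn.Module:
    # Parameters created under torch.device("meta") have no storage, so a
    # broadcast has no concrete destination; to_empty() allocates
    # uninitialized storage on a real device, and the values are then
    # filled in by the broadcast from rank 0.
    device = torch.device("cuda", torch.cuda.current_device())
    return model.to_empty(device=device)
```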

dtype=tensor_info.dtype,
)

tensors.append(full_tensor)
Contributor

With this approach, the entire state dict will have to fit into a single GPU? But that's not guaranteed to work. Shouldn't this broadcast and distribute each tensor individually?

There doesn't seem to be a way currently to efficiently load a full state dict checkpoint into a DTensor model with the new torch.distributed.checkpoint APIs.

fegin (Contributor Author)

No, that's not true; we don't need to fit the entire state_dict into a single GPU. The full_tensor is not kept after the broadcast, and we now broadcast 10 tensors at a time.
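
A minimal sketch of that bucketing scheme, assuming every rank knows each tensor's shape and dtype (the bucket size of 10 and all helper names are illustrative):

```python
import torch
import torch.distributed as dist

def broadcast_in_buckets(keys, full_sd, metadata, device, consume, bucket=10, src=0):
    # consume(key, full_tensor) shards/copies the broadcast result into the
    # local state_dict; at most `bucket` full tensors are alive at once.
    for start in range(0, len(keys), bucket):
        staged = []
        for key in keys[start:start + bucket]:
            shape, dtype = metadata[key]
            if dist.get_rank() == src:
                full = full_sd[key].to(device)
            else:
                full = torch.empty(shape, dtype=dtype, device=device)
            dist.broadcast(full, src=src)
            staged.append((key, full))
        for key, full in staged:
            consume(key, full)
        staged.clear()  # drop the full tensors before the next bucket
```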

Contributor

I'm simply getting OOM trying to load Llama 70B the way it is currently. So what I did instead is load each parameter individually, which works and saves memory:
https://github.com/Lightning-AI/pytorch-lightning/blob/bd2843f6cbac769ad71b0b6404e411e9844ea9ce/src/lightning/fabric/strategies/model_parallel.py#L541-L553


Labels

ciflow/periodic, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (checkpoint)
