-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[DSD] Support flattening the optimizer state_dict when saving and unflattening when loading #127071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127071
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 36f0ed9 with merge base a60b06b ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
fqns = param_group.pop(PARAMS) | ||
for fqn in fqns: | ||
for k, v in param_group.items(): | ||
ret[f"{PG}.{fqn}.{k}"] = v |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like this flattening would not preserve the original param-group groupings. It is a true flattening. Is that correct?
I'm not sure if that's good or bad- does that mean that we make the optimizer states generic enough that resharing would be possible? (But I'm not sure how we'd reconstruct the original param_group groupings on load side.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok reading the 'unflattening' side again it makes sense. IIUC, because the user would have created the new optimizer based on the new model/parallelisms, it may or may not have the same param_groupings as before. But since we normalized all of the individual states to fqn, as long as they have the same FQNs (or less FQNs) in their new model, they can find the states and then group them according to the new optimizer's groupings.
So this appears to support resharding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This supports resharding.
A general comment- it seems like the amount of optimizer-specific knowledge we have to bake in here is small- namely that there is a concept of ordered lists of parameter groupings and a particular schema for it. This seems probably ok to me, even though we'd rather just have a clean dict from the optimizer in the first place and keep this code simpler. I'm assuming this is true for all of our optimizers though. If the optimizers vary and not all of them have param groupings, then this could be fragile. |
Yes, the implementation is based on all the built-in optimizers, SGD, Adam, Adamw, but should also support optimizers from Apex. However, one thing I'm not so sure is that if we will have issues with second-order optimizers. My guess is no as the second-order optimizers generally have more complicated states but the param group structure is the same. But I don't know if this is true for all the optimizers. cc., @wz337 |
|
||
def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]: | ||
def _raise_if_type_not_supported(v): | ||
if not isinstance(v, (torch.Tensor, int, float)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should other primitive types be ok to support? like double, complex, bool?
for param_group in state_dict[PG]: | ||
fqns = param_group.pop(PARAMS) | ||
for fqn in fqns: | ||
for k, v in param_group.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we have a guarantee from optimizer that within one param_group, there will not be duplicate param fqns?
can we also be sure that one param fqn will not show up in 2 param_groups?
is it good to assert f"{PG}.{fqn}.{k} not in ret
before assigning v?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is guaranteed. But we can definitely add an assert here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is guaranteed today, we have test cases for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. i dont think its too risky, but happy to let other DCP/optimizer experts weigh in
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aligned offline about the solution. LGTM.
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…zer_state_dict (#127384) Summary: Allow the optim_state_dict argument to be a positional argument. This make sense since this is a required argument and this will make the function signature the consistent as set_model_state_dict without causing BC issues. Pull Request resolved: #127384 Approved by: https://github.com/wz337 ghstack dependencies: #127070, #127071
…lattening when loading (#127071) Fixes #126595 **What does this PR do?** This PR unflattens the optimizer state_dict, similar to what TorchRec does. The current `get_optimizer_state_dict()` converts the parameter IDs to FQNs in order to avoid any conflict with different optimizers on different ranks. The current returned optimizer state_dict looks like the following one: ``` { "state": { "layer1.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor}, "layer2.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor}, }, "param_group": [ {"lr": 0.0, "betas": (0.9, 0.95), ..., "params": ["layer1.weight", "layer2.weight"]} ] } ``` While this can avoid the conflict and can support merging multiple optimizers use case (e.g., optimizer in backward), the current optimizer state_dict still cannot support MPMD (e.g., pipeline parallelism). The root cause is `param_group`. `param_group` cannot generate unique keys during saving -- DCP will flatten the dict but for `param_group`, DCP will get the keys like, `param_group.lr` or `param_group.params`. These keys will conflict when using pipeline parallelism. This PR flatten the optimizer state_dict to the one as the following one: ``` { "state.layer1.weight.step": 10, "state.layer2.weight.step": 10, "state.layer1.weight.exp_avg": SomeTensor, "state.layer2.weight.exp_avg": SomeTensor, "state.layer1.weight.exp_avg_sq": SomeTensor, "state.layer2.weight.exp_avg_sq": SomeTensor, "param_group.layer1.weight.lr" : 0.1, "param_group.layer2.weight.lr" : 0.1, "param_group.layer1.weight.betas" : (0.9, 0.95), "param_group.layer2.weight.betas" : (0.9, 0.95), } ``` This allows distributed state_dict (DSD) to support MPMD (e.g., pipeline parallelism). **Pros and Cons** *Pros* 1. Can support optimizer resharding (e.g., changing the parallelisms from 3D to 2D or changing the number of workers). 2. User don't need to manually add prefix to different optimizer. 3. Allow users to merge the optimizer states easily. One use case is loop-based pipeline parallelism. *Cons* 1. The implementation has a strong assumption of the structure of `param_groups` and its value. If the assumption changes or some customized optimizers do not meet the assumption, the implementations will be broken. 2. There will be extra values saved in the checkpoints. The assumption here is `param_group` generally contains scalars which are cheap to save. Pull Request resolved: #127071 Approved by: https://github.com/wconstab, https://github.com/wz337 ghstack dependencies: #127070 (cherry picked from commit bd868ee)
…zer_state_dict (#127384) Summary: Allow the optim_state_dict argument to be a positional argument. This make sense since this is a required argument and this will make the function signature the consistent as set_model_state_dict without causing BC issues. Pull Request resolved: #127384 Approved by: https://github.com/wz337 ghstack dependencies: #127070, #127071 (cherry picked from commit 8b4ad3a)
…itialized case (#127385) Fixes #124942 Summary: Allow DSD to support loading the regular optimizer state_dict and can be used when torch.distributed.is_initialized() is False. Pull Request resolved: #127385 Approved by: https://github.com/wz337 ghstack dependencies: #127070, #127071, #127384 (cherry picked from commit 64c581a)
…lattening when loading (pytorch#127071) Fixes pytorch#126595 **What does this PR do?** This PR unflattens the optimizer state_dict, similar to what TorchRec does. The current `get_optimizer_state_dict()` converts the parameter IDs to FQNs in order to avoid any conflict with different optimizers on different ranks. The current returned optimizer state_dict looks like the following one: ``` { "state": { "layer1.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor}, "layer2.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor}, }, "param_group": [ {"lr": 0.0, "betas": (0.9, 0.95), ..., "params": ["layer1.weight", "layer2.weight"]} ] } ``` While this can avoid the conflict and can support merging multiple optimizers use case (e.g., optimizer in backward), the current optimizer state_dict still cannot support MPMD (e.g., pipeline parallelism). The root cause is `param_group`. `param_group` cannot generate unique keys during saving -- DCP will flatten the dict but for `param_group`, DCP will get the keys like, `param_group.lr` or `param_group.params`. These keys will conflict when using pipeline parallelism. This PR flatten the optimizer state_dict to the one as the following one: ``` { "state.layer1.weight.step": 10, "state.layer2.weight.step": 10, "state.layer1.weight.exp_avg": SomeTensor, "state.layer2.weight.exp_avg": SomeTensor, "state.layer1.weight.exp_avg_sq": SomeTensor, "state.layer2.weight.exp_avg_sq": SomeTensor, "param_group.layer1.weight.lr" : 0.1, "param_group.layer2.weight.lr" : 0.1, "param_group.layer1.weight.betas" : (0.9, 0.95), "param_group.layer2.weight.betas" : (0.9, 0.95), } ``` This allows distributed state_dict (DSD) to support MPMD (e.g., pipeline parallelism). **Pros and Cons** *Pros* 1. Can support optimizer resharding (e.g., changing the parallelisms from 3D to 2D or changing the number of workers). 2. User don't need to manually add prefix to different optimizer. 3. Allow users to merge the optimizer states easily. One use case is loop-based pipeline parallelism. *Cons* 1. The implementation has a strong assumption of the structure of `param_groups` and its value. If the assumption changes or some customized optimizers do not meet the assumption, the implementations will be broken. 2. There will be extra values saved in the checkpoints. The assumption here is `param_group` generally contains scalars which are cheap to save. Pull Request resolved: pytorch#127071 Approved by: https://github.com/wconstab, https://github.com/wz337 ghstack dependencies: pytorch#127070
…zer_state_dict (pytorch#127384) Summary: Allow the optim_state_dict argument to be a positional argument. This make sense since this is a required argument and this will make the function signature the consistent as set_model_state_dict without causing BC issues. Pull Request resolved: pytorch#127384 Approved by: https://github.com/wz337 ghstack dependencies: pytorch#127070, pytorch#127071
…itialized case (pytorch#127385) Fixes pytorch#124942 Summary: Allow DSD to support loading the regular optimizer state_dict and can be used when torch.distributed.is_initialized() is False. Pull Request resolved: pytorch#127385 Approved by: https://github.com/wz337 ghstack dependencies: pytorch#127070, pytorch#127071, pytorch#127384
Stack from ghstack (oldest at bottom):
Fixes #126595
What does this PR do?
This PR unflattens the optimizer state_dict, similar to what TorchRec does. The current
get_optimizer_state_dict()
converts the parameter IDs to FQNs in order to avoid any conflict with different optimizers on different ranks. The current returned optimizer state_dict looks like the following one:While this can avoid the conflict and can support merging multiple optimizers use case (e.g., optimizer in backward), the current optimizer state_dict still cannot support MPMD (e.g., pipeline parallelism). The root cause is
param_group
.param_group
cannot generate unique keys during saving -- DCP will flatten the dict but forparam_group
, DCP will get the keys like,param_group.lr
orparam_group.params
. These keys will conflict when using pipeline parallelism.This PR flatten the optimizer state_dict to the one as the following one:
This allows distributed state_dict (DSD) to support MPMD (e.g., pipeline parallelism).
Pros and Cons
Pros
Cons
param_groups
and its value. If the assumption changes or some customized optimizers do not meet the assumption, the implementations will be broken.param_group
generally contains scalars which are cheap to save.cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC