[DSD] Support flattening the optimizer state_dict when saving and unflattening when loading by fegin · Pull Request #127071 · pytorch/pytorch · GitHub

Conversation

@fegin
Contributor

@fegin fegin commented May 24, 2024

Stack from ghstack (oldest at bottom):

Fixes #126595

What does this PR do?
This PR flattens the optimizer state_dict when saving and unflattens it when loading, similar to what TorchRec does. The current get_optimizer_state_dict() converts parameter IDs to FQNs in order to avoid conflicts between different optimizers on different ranks. The optimizer state_dict it currently returns looks like the following:

{
    "state": {
          "layer1.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor},
          "layer2.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor},
    },
    "param_group": [
         {"lr": 0.0, "betas": (0.9, 0.95), ..., "params": ["layer1.weight", "layer2.weight"]}
    ]
}

While this avoids the conflict and supports the use case of merging multiple optimizers (e.g., optimizer in backward), the current optimizer state_dict still cannot support MPMD (e.g., pipeline parallelism). The root cause is param_group: param_group cannot generate unique keys during saving -- DCP will flatten the dict, but for param_group it produces keys like param_group.lr or param_group.params, and these keys conflict across ranks when using pipeline parallelism.

This PR flattens the optimizer state_dict into the following form:

{
    "state.layer1.weight.step": 10,
    "state.layer2.weight.step": 10,
    "state.layer1.weight.exp_avg": SomeTensor,
    "state.layer2.weight.exp_avg": SomeTensor,    
    "state.layer1.weight.exp_avg_sq": SomeTensor,
    "state.layer2.weight.exp_avg_sq": SomeTensor,        
    "param_group.layer1.weight.lr" : 0.1,
    "param_group.layer2.weight.lr" : 0.1,
    "param_group.layer1.weight.betas" : (0.9, 0.95),
    "param_group.layer2.weight.betas" : (0.9, 0.95),    
}

This allows distributed state_dict (DSD) to support MPMD (e.g., pipeline parallelism).
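For illustration, here is a minimal sketch of the flattening scheme described above. The names (flatten_osd, STATE, PG, PARAMS) are illustrative rather than the actual DSD helpers, and the key prefixes follow the example above:

```python
# Illustrative key prefixes, matching the example above.
STATE = "state"
PG = "param_group"
PARAMS = "params"


def flatten_osd(osd: dict) -> dict:
    """Flatten an FQN-keyed optimizer state_dict into "<prefix>.<fqn>.<key>" entries."""
    flat = {}
    # Per-parameter optimizer state, e.g. "state.layer1.weight.exp_avg".
    for fqn, param_state in osd[STATE].items():
        for k, v in param_state.items():
            flat[f"{STATE}.{fqn}.{k}"] = v
    # Hyperparameters are replicated per parameter, e.g. "param_group.layer1.weight.lr",
    # which is what makes every key unique across pipeline stages.
    for group in osd[PG]:
        group = dict(group)  # avoid mutating the caller's param_group
        fqns = group.pop(PARAMS)
        for fqn in fqns:
            for k, v in group.items():
                flat[f"{PG}.{fqn}.{k}"] = v
    return flat
```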

Pros and Cons
Pros

  1. Supports optimizer resharding (e.g., changing the parallelism from 3D to 2D or changing the number of workers).
  2. Users don't need to manually add prefixes to different optimizers.
  3. Allows users to merge optimizer states easily; one use case is loop-based pipeline parallelism.

Cons

  1. The implementation makes a strong assumption about the structure of param_groups and their values. If that assumption changes, or a customized optimizer does not meet it, the implementation will break.
  2. Extra values are saved in the checkpoints. The assumption here is that param_group generally contains scalars, which are cheap to save.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented May 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127071

Note: Links to docs will display an error until the docs builds have been completed.


✅ No Failures

As of commit 36f0ed9 with merge base a60b06b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: distributed_checkpoint and oncall: distributed labels May 24, 2024
fegin added a commit that referenced this pull request May 24, 2024
Summary:
This allows DSD to support MPMD (e.g., pipeline parallelism).

ghstack-source-id: d12038a
Pull Request resolved: #127071
@fegin fegin marked this pull request as draft May 24, 2024 07:16
fqns = param_group.pop(PARAMS)
for fqn in fqns:
    for k, v in param_group.items():
        ret[f"{PG}.{fqn}.{k}"] = v
Contributor

it looks like this flattening would not preserve the original param-group groupings. It is a true flattening. Is that correct?

I'm not sure if that's good or bad -- does that mean that we make the optimizer states generic enough that resharding would be possible? (But I'm not sure how we'd reconstruct the original param_group groupings on the load side.)

Contributor

OK, reading the 'unflattening' side again, it makes sense. IIUC, because the user would have created the new optimizer based on the new model/parallelisms, it may or may not have the same param groupings as before. But since we normalized all of the individual states to FQNs, as long as the new model has the same FQNs (or fewer), they can find the states and then group them according to the new optimizer's groupings.

So this appears to support resharding?

Contributor Author

This supports resharding.
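For illustration, a minimal sketch of the load-side regrouping discussed in this thread, assuming the flattened key scheme from the PR description; the function and argument names are hypothetical, not the actual DSD API:

```python
def unflatten_osd(flat: dict, new_param_groups: list) -> dict:
    """Rebuild an FQN-keyed optimizer state_dict, grouping entries according to
    the *new* optimizer's param_groups (each a dict with a "params" list of FQNs)."""
    osd = {"state": {}, "param_group": []}
    for group in new_param_groups:
        rebuilt = {"params": list(group["params"])}
        for fqn in group["params"]:
            # Collect per-parameter state, e.g. "state.<fqn>.exp_avg".
            state_prefix = f"state.{fqn}."
            per_param = {
                k[len(state_prefix):]: v
                for k, v in flat.items()
                if k.startswith(state_prefix)
            }
            if per_param:
                osd["state"][fqn] = per_param
            # Recover hyperparameters, e.g. "param_group.<fqn>.lr". They are
            # assumed identical for every FQN in the group, so the last one wins.
            hp_prefix = f"param_group.{fqn}."
            for k, v in flat.items():
                if k.startswith(hp_prefix):
                    rebuilt[k[len(hp_prefix):]] = v
        osd["param_group"].append(rebuilt)
    return osd
```

Because everything is keyed by FQN, FQNs that no longer exist in the new model are simply ignored, and the new optimizer's own groupings determine how the states are regrouped.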

@wconstab
Contributor

A general comment- it seems like the amount of optimizer-specific knowledge we have to bake in here is small- namely that there is a concept of ordered lists of parameter groupings and a particular schema for it. This seems probably ok to me, even though we'd rather just have a clean dict from the optimizer in the first place and keep this code simpler.

I'm assuming this is true for all of our optimizers though. If the optimizers vary and not all of them have param groupings, then this could be fragile.

@fegin
Contributor Author

fegin commented May 24, 2024

> A general comment- it seems like the amount of optimizer-specific knowledge we have to bake in here is small- namely that there is a concept of ordered lists of parameter groupings and a particular schema for it. This seems probably ok to me, even though we'd rather just have a clean dict from the optimizer in the first place and keep this code simpler.
>
> I'm assuming this is true for all of our optimizers though. If the optimizers vary and not all of them have param groupings, then this could be fragile.

Yes, the implementation is based on all the built-in optimizers (SGD, Adam, AdamW), but it should also support optimizers from Apex. However, one thing I'm not sure about is whether we will have issues with second-order optimizers. My guess is no, since second-order optimizers generally have more complicated states but the same param_group structure. But I don't know whether this is true for all optimizers. cc., @wz337

[ghstack-poisoned]
fegin added a commit that referenced this pull request May 24, 2024
Summary:
This allows DSD to support MPMD (e.g., pipeline parallelism).

ghstack-source-id: 72b559a
Pull Request resolved: #127071
@fegin fegin requested review from LucasLLC and wz337 May 24, 2024 16:42
@fegin fegin changed the title [RFC][DSD] Unflatten the optimizer state_dict [RFC][DSD] Support flattening the optimizer state_dict when saving and unflattening when loading May 24, 2024

def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]:
    def _raise_if_type_not_supported(v):
        if not isinstance(v, (torch.Tensor, int, float)):
Contributor

should other primitive types be ok to support? like double, complex, bool?
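For reference, a hypothetical broadening along the lines suggested here (not necessarily what the PR merged; the exception type below is an assumption):

```python
import torch

# Hypothetical: accept a wider set of scalar-like primitives in param_group values.
# Note that bool is already covered by int (it is a subclass), and Python floats
# are already double precision.
_FLATTENABLE_TYPES = (torch.Tensor, int, float, complex, str)


def _raise_if_type_not_supported(v):
    if not isinstance(v, _FLATTENABLE_TYPES):
        raise NotImplementedError(
            f"Flattening optimizer state_dict does not support values of type {type(v)}."
        )
```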

for param_group in state_dict[PG]:
    fqns = param_group.pop(PARAMS)
    for fqn in fqns:
        for k, v in param_group.items():
Contributor

do we have a guarantee from optimizer that within one param_group, there will not be duplicate param fqns?

can we also be sure that one param fqn will not show up in 2 param_groups?

is it good to assert `f"{PG}.{fqn}.{k}" not in ret` before assigning `v`?

Contributor Author

Yes, it is guaranteed. But we can definitely add an assert here.
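For illustration, a sketch of the suggested assertion (the helper name is hypothetical):

```python
def _set_flattened(ret: dict, flat_key: str, value) -> None:
    # Guard against duplicate flattened keys (e.g. the same FQN appearing in
    # two param_groups) before assigning the value.
    assert flat_key not in ret, f"Duplicated flattened key: {flat_key}"
    ret[flat_key] = value
```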

Contributor

This is guaranteed today, we have test cases for it.

Contributor
@wconstab wconstab left a comment

lgtm. I don't think it's too risky, but happy to let other DCP/optimizer experts weigh in.

@fegin fegin marked this pull request as ready for review May 28, 2024 17:15
[ghstack-poisoned]
[ghstack-poisoned]
@fegin fegin requested a review from janeyx99 May 29, 2024 18:20
Contributor
@wz337 wz337 left a comment

Aligned offline about the solution. LGTM.

@fegin fegin changed the title [RFC][DSD] Support flattening the optimizer state_dict when saving and unflattening when loading [DSD] Support flattening the optimizer state_dict when saving and unflattening when loading May 31, 2024
@fegin
Contributor Author

fegin commented May 31, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request May 31, 2024
…zer_state_dict (#127384)

Summary:
Allow the optim_state_dict argument to be a positional argument. This makes sense since it is a required argument, and it makes the function signature consistent with set_model_state_dict without causing BC issues.

Pull Request resolved: #127384
Approved by: https://github.com/wz337
ghstack dependencies: #127070, #127071
pytorchmergebot pushed a commit that referenced this pull request May 31, 2024
…itialized case (#127385)

Fixes #124942

Summary:
Allow DSD to support loading the regular optimizer state_dict, and allow it to be used when torch.distributed.is_initialized() is False.

Pull Request resolved: #127385
Approved by: https://github.com/wz337
ghstack dependencies: #127070, #127071, #127384
bigfootjon pushed a commit that referenced this pull request Jun 5, 2024
…lattening when loading (#127071)

Fixes #126595

Pull Request resolved: #127071
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: #127070

(cherry picked from commit bd868ee)
bigfootjon pushed a commit that referenced this pull request Jun 5, 2024
…zer_state_dict (#127384)


Pull Request resolved: #127384
Approved by: https://github.com/wz337
ghstack dependencies: #127070, #127071

(cherry picked from commit 8b4ad3a)
bigfootjon pushed a commit that referenced this pull request Jun 5, 2024
…itialized case (#127385)

Fixes #124942


Pull Request resolved: #127385
Approved by: https://github.com/wz337
ghstack dependencies: #127070, #127071, #127384

(cherry picked from commit 64c581a)
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…lattening when loading (pytorch#127071)

Fixes pytorch#126595

Pull Request resolved: pytorch#127071
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: pytorch#127070
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…zer_state_dict (pytorch#127384)


Pull Request resolved: pytorch#127384
Approved by: https://github.com/wz337
ghstack dependencies: pytorch#127070, pytorch#127071
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…itialized case (pytorch#127385)

Fixes pytorch#124942


Pull Request resolved: pytorch#127385
Approved by: https://github.com/wz337
ghstack dependencies: pytorch#127070, pytorch#127071, pytorch#127384
@github-actions github-actions bot deleted the gh/fegin/245/head branch July 1, 2024 02:01