Fix param and buffer mapping for state_dict when there are state_dict hooks by yushangdi · Pull Request #137609 · pytorch/pytorch · GitHub

Conversation

@yushangdi
Contributor

@yushangdi yushangdi commented Oct 9, 2024

Resolve #137540

Summary:

We might get different state_dict and named_parameters() results when the module has registered custom state_dict hooks.
For the exported program's state_dict, we want the state_dict to reflect the actual module hierarchy at runtime, and it might be different from the model's state_dict() output if the model has state_dict hooks.
To do weight swapping, one needs to either re-export or turn off the hooks when saving the model's state_dict().
Previously, ExportedProgram used nn.Module's state_dict() method to populate its own state_dict, but this doesn't work for some models (e.g. llama3_3_vision) because ExportedProgram's state_dict and an nn.Module's state_dict have some subtle semantic differences.

nn.Module's state_dict is about how the state should be serialized, and it reflects the structure of the original user model code. In contrast, export specializes on a “run” of a model, and its state_dict needs to reflect the runtime module hierarchy.

One example where these two are different is TorchTune's Llama3_2_vision text decoder. Here, a FusionLayer is added as a local optimization and it is not part of the "static model definition". At runtime, we have mod.layers[3].layer.sa_norm.scale.

But in nn.Module's state_dict, the authors of the model added a state_dict hook to remove the "layer" in mod.state_dict() to reflect the static model definition, so we have mod.state_dict()["layers.3.sa_norm.scale"].
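
For illustration, here is a toy sketch (hypothetical Wrapper/Norm modules, not the actual TorchTune code) of how such a state_dict hook makes named_parameters() and state_dict() disagree on FQNs:

```
import torch
import torch.nn as nn


class Norm(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(4))


class Wrapper(nn.Module):
    # Toy stand-in for FusionLayer: wraps a submodule behind an extra "layer" attribute.
    def __init__(self):
        super().__init__()
        self.layer = Norm()
        # Hook that strips the extra "layer." segment so state_dict() matches the
        # static model definition rather than the runtime hierarchy.
        self._register_state_dict_hook(self._strip_layer)

    @staticmethod
    def _strip_layer(module, state_dict, prefix, local_metadata):
        for key in list(state_dict.keys()):
            if key.startswith(prefix + "layer."):
                state_dict[prefix + key[len(prefix + "layer."):]] = state_dict.pop(key)
        return state_dict


m = Wrapper()
print([name for name, _ in m.named_parameters()])  # ['layer.scale']  <- runtime hierarchy
print(list(m.state_dict().keys()))                 # ['scale']        <- hook-rewritten keys
```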
In this Diff, we change ExportedProgram to populate its state_dict using named_parameters() and named_buffers() instead. So in ExportedProgram's state_dict, we have "layers.3.layer.sa_norm.scale", which reflects the runtime module hierarchy.
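
A simplified sketch of gathering the state this way (gather_ep_state_dict is a hypothetical helper name, not the exact code in this Diff):

```
import torch


def gather_ep_state_dict(mod: torch.nn.Module) -> dict:
    # Build the state dict from named_parameters()/named_buffers() rather than
    # state_dict(), so registered state_dict hooks are bypassed and the FQNs
    # follow the runtime module hierarchy (e.g. "layers.3.layer.sa_norm.scale").
    state_dict = {}
    state_dict.update(dict(mod.named_parameters(remove_duplicate=False)))
    state_dict.update(dict(mod.named_buffers(remove_duplicate=False)))
    return state_dict
```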

Now one problem this presents is weight swapping. Since ExportedProgram's state and the model's state are not the same anymore, the weight-swapping procedure also needs to change slightly.

In internal Ads and RecSys model deployments, weight swapping is where they have one model that is currently deployed and serving traffic, and they want to swap out the weights with newly trained model weights without having to redo the whole exporting/lowering process and create a new artifact. So they would move the deployed model’s pointer to the state dict over to the new state dict. Because of this, it was previously a requirement that the FQNs match between the exported and the eager model’s state dict.

The new ExportedProgram's state dict still supports weight swapping, but the state_dict to be swapped needs to be obtained from torch.export.exported_program instead of model.state_dict() if the model has state_dict hooks.
The new requirement is that the FQNs match between the exported program’s state dict and the state_dict obtained under the _disabled_load_state_dict_hooks(M) context manager. One benefit of having this new API is that we are now in full control within export of gathering and updating the model state.
If a model doesn't have any state_dict hooks, one can still use model.state_dict() for weight swapping, so this is backward compatible.
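
A minimal sketch of the new weight-swapping flow (model and example_args are placeholders; the context manager name and import path follow this summary and may differ after the rename suggested in review):

```
import torch
from torch.export import export
from torch._export.utils import _disabled_load_state_dict_hooks

ep = export(model, example_args)  # example_args: a tuple of example inputs (placeholder)

# After retraining, gather a state dict whose FQNs match ep.state_dict()
# by disabling the model's state_dict hooks while reading it.
with _disabled_load_state_dict_hooks(model):
    new_weights = model.state_dict()

# Swap the exported program's state over to the new weights
# (the exact update mechanism depends on the serving/deployment stack).
ep.state_dict.update(new_weights)
```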

Test Plan:

buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_export_for_training_with_state_dict_hooks

Differential Revision: D64080561

@pytorch-bot

pytorch-bot bot commented Oct 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137609

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 0468fa1 with merge base 93bbc8a:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

Contributor

It looks like for the FQN we store the post-processed state_dict key, not necessarily the path where the attribute exists. From the discussion I'm not 100% sure what the semantics are, and whether this breaks unflattening. Do you know if export() + unflatten() works for the test case you added?

cc: @angelayi

Contributor Author

@yushangdi yushangdi Oct 9, 2024

Discussed in chat; summarizing the discussion here:

we can modify the verifier to match against the state_dict without the hooks

  • add a test for unflattening
  • add a context manager, so that within export the module’s state_dict hooks are removed. Then in the verifiers we are matching against the state_dict without hooks. Effectively, we ignore state_dict hooks in export.

One caveat is that, for any model with state_dict hooks, one won’t be able to interchange between exported_program.state_dict() and mod.state_dict().

exported_program.state_dict() still works with itself, but you can’t load a model’s state_dict into the exported_program, or vice versa.

Contributor

I think we could add a context manager around this call that removes/puts back the state dict hooks on mod upon enter/exit; that way we don't have to make the nn.Module changes.
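
Something like this rough sketch, for reference (not the code that eventually landed; it only touches the top-level module's private _state_dict_hooks for brevity):

```
import contextlib

import torch


@contextlib.contextmanager
def _remove_state_dict_hooks(mod: torch.nn.Module):
    # Detach the module's registered state_dict hooks on enter and put them
    # back on exit, so state_dict() reflects the raw runtime hierarchy inside
    # the with-block.
    saved = dict(mod._state_dict_hooks)
    mod._state_dict_hooks.clear()
    try:
        yield mod
    finally:
        mod._state_dict_hooks.update(saved)
```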

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

Contributor

Hmm what's the intended usage for this function?

Contributor Author

Hmm what's the intended usage for this function?

This is for weight swapping.

At time T1: you're running export, getting an exported program.
At time T2: some serving service serves an artifact from the exported program.
At time T3: a recurring training job has just finished and updates the stored model state.
At time T4: the serving service picks up the same compiled artifact with the new state that was just updated.

This is used to store the new model state at time T3.

something like:

ep = export(model)
d = exported_program_state_dict(model)
# update ep's state_dict with d.

Contributor

@pianpwk pianpwk Oct 10, 2024

Ah, what I had in mind with this was to wrap it around some broad chunk of code in export (maybe _export_func), so that anyone working on export who doesn't know about this issue can just call state_dict(). That way we also don't have to do the manual construction from named_parameters + named_buffers. But I'm happy with the chunk of code in _trace.py.

Contributor

Maybe this test case could match what we were seeing before, with more modules? Like if a user is trying to remove a layer from the state dict.

Contributor Author

Maybe this test case could match what we were seeing before, with more modules? Like if a user is trying to remove a layer from the state dict.

fixed now.

Contributor

let's keep this private for now, and not in this file... maybe in utils?

Contributor Author

let's keep this private for now, and not in this file... maybe in utils?

moved to utils now.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

@yushangdi yushangdi requested review from angelayi and pianpwk October 10, 2024 16:07
Contributor

@angelayi angelayi left a comment

thanks for pushing this through!

Contributor

nit:

Suggested change
- def _disabled_load_state_dict_hooks(mod: torch.nn.Module):
+ def _disable_load_state_dict_hooks(mod: torch.nn.Module):

Comment on lines 923 to 929
Contributor

nit: I don't think we need this function; we can just tell people to directly use the disable hook?

Contributor Author

sure, removed now.

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 10, 2024
yushangdi added a commit to yushangdi/pytorch that referenced this pull request Oct 10, 2024
… hooks (pytorch#137609)

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64080561

@facebook-github-bot
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Development

Successfully merging this pull request may close these issues.

export_for_training regression on Llama3_2_vision text decoder
