Provide option to pass module instance to _load_state_dict_pre_hooks. #62070
Conversation
We have a custom Tensor (https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67) which doesn't show up in the module's state_dict. This was resolved by using _register_state_dict_hook (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196) to parse and add custom tensors to the state_dict. However, at load time _register_load_state_dict_pre_hook (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272) does not pass in the module instance, so a ShardedTensor in the state_dict cannot be appropriately added to the module at load time. To resolve this, this PR enhances the hook to support two variations: one that passes in the module instance (for the problem described above), and the previous signature, kept for backward-compatibility reasons. Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/)
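For orientation, a minimal sketch of how the two variations might be used, assuming the `with_module` registration flag and module-first argument order that emerge in the review discussion below (not the exact merged API):

```python
import torch.nn as nn

def legacy_hook(state_dict, prefix, local_metadata, strict,
                missing_keys, unexpected_keys, error_msgs):
    # Previous signature, kept for backward compatibility: no module access.
    pass

def module_aware_hook(module, state_dict, prefix, local_metadata, strict,
                      missing_keys, unexpected_keys, error_msgs):
    # With the module instance in hand, a custom tensor (e.g. a ShardedTensor)
    # found in state_dict can be reattached to the module during loading.
    pass

m = nn.Linear(4, 4)
m._register_load_state_dict_pre_hook(legacy_hook)
m._register_load_state_dict_pre_hook(module_aware_hook, with_module=True)
m.load_state_dict(m.state_dict())
```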
💊 CI failures summary and remediations: as of commit 57a8430 (more details on the Dr. CI page), 1 new failure was recognized by patterns; it does not appear to be due to an upstream breakage.
torch/nn/modules/module.py (outdated):

```python
n_params = len(signature(hook).parameters)
if n_params == 8:
    # Pass in module instance as well.
    hook(state_dict, prefix, local_metadata, strict, missing_keys,
         unexpected_keys, error_msgs, self)
```
This is a bit hacky; another option I had in mind was to add a parameter to _register_load_state_dict_pre_hook, something like with_module, and store that option with the hook. Then at load time, based on that option, we either pass the module in to the hook or we don't. A sketch of that alternative follows below.
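For concreteness, here is a minimal sketch of that alternative. It assumes the `(hook, with_module)` storage shown in the diff that follows; the uniform dispatch inside `_load_from_state_dict` and passing the module first are assumptions, not the exact merged code.

```python
from torch.utils import hooks

class Module:  # sketch of the relevant pieces of torch.nn.Module
    def _register_load_state_dict_pre_hook(self, hook, with_module=False):
        handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
        # Store the flag alongside the hook instead of inspecting its signature.
        self._load_state_dict_pre_hooks[handle.id] = (hook, with_module)
        return handle

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        for hook, with_module in self._load_state_dict_pre_hooks.values():
            args = (state_dict, prefix, local_metadata, strict,
                    missing_keys, unexpected_keys, error_msgs)
            if with_module:
                hook(self, *args)  # module-aware variant
            else:
                hook(*args)        # legacy variant, unchanged for BC
```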
…_pre_hooks." We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/) [ghstack-poisoned]
…_pre_hooks." We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/) [ghstack-poisoned]
Pull Request resolved: #62070 We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. ghstack-source-id: 134344366 Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/)
torch/nn/modules/module.py (outdated):

```diff
     """
     handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
-    self._load_state_dict_pre_hooks[handle.id] = hook
+    self._load_state_dict_pre_hooks[handle.id] = (hook, with_module)
```
can't you use partial here to just update the hook functions?
Thinking about this a bit more, the problem is that this is a custom user hook and we don't have an easy way to know the name of its last parameter (unless we use inspect). If we use partial, we'd probably need to do something like hook = partial(hook, <last_param_name>=self)
Well, you're writing a brand new API, you can make the module param the very first positional argument ;)
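A sketch of the reviewer's suggestion, assuming the `with_module` flag from the diff above; with the module bound as the first positional argument, every stored hook can be called uniformly at load time:

```python
import functools
from torch.utils import hooks

def _register_load_state_dict_pre_hook(self, hook, with_module=False):
    handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
    # Binding the module up front sidesteps needing the parameter's name:
    # partial(hook, self) works regardless of what the hook calls it.
    self._load_state_dict_pre_hooks[handle.id] = (
        functools.partial(hook, self) if with_module else hook
    )
    return handle
```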
…_pre_hooks." We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/) [ghstack-poisoned]
Pull Request resolved: #62070 We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. ghstack-source-id: 134381635 Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/)
Change looks good! A couple of minor comments on testing below.
…_pre_hooks." We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/) [ghstack-poisoned]
Pull Request resolved: #62070 We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. ghstack-source-id: 134541391 Differential Revision: [D29867142](https://our.internmc.facebook.com/intern/diff/D29867142/)
LGTM!
This pull request has been merged in cac4aa7.