Provide option to pass module instance to _load_state_dict_pre_hooks. by pritamdamania87 · Pull Request #62070 · pytorch/pytorch · GitHub

@pritamdamania87 pritamdamania87 commented Jul 23, 2021

Stack from ghstack:

We have a custom Tensor:
https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67,
which doesn't show up in state_dict for the module. This was resolved by
using the _register_state_dict_hook:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196
to parse and add custom tensors to state_dict.

However, at load time, _register_load_state_dict_pre_hook:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272
does not pass in the module instance. As a result, a ShardedTensor in the
state_dict cannot be appropriately added to a module at load time.

To resolve this issue, this PR enhances the hook to support two
variations: one that passes in the module instance (for the problem described
above) and the previous one, kept for backward-compatibility (BC) reasons.

Differential Revision: D29867142

@facebook-github-bot

facebook-github-bot commented Jul 23, 2021

💊 CI failures summary and remediations

As of commit 57a8430 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Run tests"

Jul 28 21:04:39   test_AvgPool3d_backward_after_cat_dim1_device_xla (__main__.TestNNDeviceTypeXLA) ... skip (0.002s)
Jul 28 21:04:39   test_BatchNorm_empty_xla (__main__.TestNNDeviceTypeXLA) ... ok (0.153s)
Jul 28 21:04:39   test_Bilinear_empty_xla (__main__.TestNNDeviceTypeXLA) ... skip (0.003s)
Jul 28 21:04:39   test_CTCLoss_cudnn_xla (__main__.TestNNDeviceTypeXLA) ... skip (0.003s)
Jul 28 21:04:40   test_CTCLoss_empty_target_xla (__main__.TestNNDeviceTypeXLA) ... ok (0.836s)
Jul 28 21:06:51   test_Conv2d_backward_depthwise_xla_float64 (__main__.TestNNDeviceTypeXLA) ... 2021-07-28 21:06:51.834510: E tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
Jul 28 21:06:51 ********************************
Jul 28 21:06:51 Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Jul 28 21:06:51 Compiling module SyncTensorsGraph.30789
Jul 28 21:06:51 ********************************
Jul 28 21:21:23 2021-07-28 21:21:23.696091: E tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
Jul 28 21:21:23 ********************************
Jul 28 21:21:23 Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Jul 28 21:21:23 Compiling module SyncTensorsGraph.35441
Jul 28 21:21:23 ********************************


Too long with no output (exceeded 1h30m0s): context deadline exceeded


1 job timed out:

  • pytorch_xla_linux_bionic_py3_6_clang9_test

ci.pytorch.org: 1 failed



pritamdamania87 pushed a commit that referenced this pull request Jul 23, 2021
ghstack-source-id: 134152233
Pull Request resolved: #62070
    n_params = len(signature(hook).parameters)
    if n_params == 8:
        # Pass in module instance as well.
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, self)
Contributor Author

This is a bit hacky. Another option I had in mind was to add an option to _register_load_state_dict_pre_hook, something like with_module, and store that option with the hook. Then, at load time, we would either pass the module to the hook or not, based on that option.
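The parameter-counting dispatch in the diff above can be sketched in plain Python. This is only an illustration under stated assumptions: `ToyModule`, `register_pre_hook`, and `run_pre_hooks` are hypothetical names, not the PyTorch API, and the load-time arguments are placeholders. It also shows one reason the approach may be considered hacky: a hook declared with `*args` reports a single parameter, so it never receives the module.

```python
from inspect import signature


class ToyModule:
    """Hypothetical stand-in for a module; not torch.nn.Module."""

    def __init__(self):
        self._pre_hooks = []

    def register_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def run_pre_hooks(self, state_dict, prefix=""):
        # Placeholder values for the seven load-time arguments.
        args = (state_dict, prefix, {}, True, [], [], [])
        for hook in self._pre_hooks:
            n_params = len(signature(hook).parameters)
            if n_params == len(args) + 1:
                # Eight declared parameters: pass the module instance as well.
                hook(*args, self)
            else:
                hook(*args)


def legacy_hook(state_dict, prefix, local_metadata, strict,
                missing_keys, unexpected_keys, error_msgs):
    state_dict["legacy_called"] = True


def module_hook(state_dict, prefix, local_metadata, strict,
                missing_keys, unexpected_keys, error_msgs, module):
    state_dict["module_seen"] = module


def varargs_hook(*args):
    # signature() sees one parameter here, so the module is not passed.
    args[0]["varargs_argc"] = len(args)


m = ToyModule()
for h in (legacy_hook, module_hook, varargs_hook):
    m.register_pre_hook(h)
sd = {}
m.run_pre_hooks(sd)
```

Storing an explicit flag next to the hook avoids guessing intent from the hook's signature.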

pritamdamania87 pushed a commit that referenced this pull request Jul 26, 2021
Pull Request resolved: #62070

ghstack-source-id: 134344366
         """
         handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
-        self._load_state_dict_pre_hooks[handle.id] = hook
+        self._load_state_dict_pre_hooks[handle.id] = (hook, with_module)
Collaborator

Can't you use partial here to just update the hook functions?

Contributor Author

Thinking about this a bit more, the problem is that this is a custom user hook and we don't have an easy way to know the name of the last parameter (unless we use inspect). If we use partial, we would probably need to do something like hook = partial(hook, <last_param_name>=self).
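To make the concern concrete, here is a rough sketch of binding the module by keyword when it is the hook's last parameter. The helper `bind_module_by_keyword` and the hook `user_hook` are hypothetical names for illustration; the point is that the parameter name has to be discovered with `inspect` before `functools.partial` can bind it.

```python
from functools import partial
from inspect import signature


def user_hook(state_dict, prefix, local_metadata, strict,
              missing_keys, unexpected_keys, error_msgs, owner):
    # The user picked the name "owner"; the framework cannot know that up front.
    state_dict[prefix + "owner"] = owner


def bind_module_by_keyword(hook, module):
    # Look up the name of the last declared parameter, then bind it by keyword.
    last_param_name = list(signature(hook).parameters)[-1]
    return partial(hook, **{last_param_name: module})


module = object()
bound = bind_module_by_keyword(user_hook, module)
sd = {}
bound(sd, "layer.", {}, True, [], [], [])
```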

Collaborator

Well, you're writing a brand new API, you can make the module param the very first positional argument ;)
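A minimal sketch of that suggestion, assuming the module is the first positional argument of the new-style hook. `ToyModule`, `register_load_pre_hook`, and `run_pre_hooks` are hypothetical names rather than the code merged in this PR; only the `with_module` option name comes from the discussion above.

```python
from functools import partial


class ToyModule:
    """Hypothetical stand-in for a module; not torch.nn.Module."""

    def __init__(self):
        self._pre_hooks = []

    def register_load_pre_hook(self, hook, with_module=False):
        if with_module:
            # Bind the module positionally; no parameter-name lookup is needed.
            hook = partial(hook, self)
        self._pre_hooks.append(hook)

    def run_pre_hooks(self, state_dict, prefix=""):
        # Every stored hook is called with the same seven placeholder arguments.
        for hook in self._pre_hooks:
            hook(state_dict, prefix, {}, True, [], [], [])


def attach_custom_value(module, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):
    # New-style hook: move an entry from the state_dict onto the module.
    key = prefix + "custom"
    if key in state_dict:
        module.custom = state_dict.pop(key)


def legacy_hook(state_dict, prefix, local_metadata, strict,
                missing_keys, unexpected_keys, error_msgs):
    # Old-style hook: unchanged seven-argument signature.
    state_dict["legacy_called"] = True


m = ToyModule()
m.register_load_pre_hook(attach_custom_value, with_module=True)
m.register_load_pre_hook(legacy_hook)
sd = {"custom": [1, 2, 3]}
m.run_pre_hooks(sd)
```

With this layout the call site stays identical for both kinds of hook, and existing seven-argument hooks keep working unchanged.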

pritamdamania87 pushed a commit that referenced this pull request Jul 27, 2021
Pull Request resolved: #62070

ghstack-source-id: 134381635

@jbschlosser jbschlosser left a comment

Change looks good! A couple of minor comments on testing below.

pritamdamania87 pushed a commit that referenced this pull request Jul 28, 2021
Pull Request resolved: #62070

ghstack-source-id: 134541391

@jbschlosser jbschlosser left a comment

LGTM!

@facebook-github-bot

This pull request has been merged in cac4aa7.


5 participants