Updated _load_pretrained_model_low_mem to check if keys are in the state_dict by FrancescoSaverioZuppichini · Pull Request #16643 · huggingface/transformers · GitHub

Conversation

@FrancescoSaverioZuppichini
Contributor

What does this PR do?

This PR checks if any key is in the state_dict before attempting to load it. If we have multiple checkpoints, not all keys are in every checkpoint.
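
For context, a minimal runnable sketch of the failure mode this guards against (the shard contents below are hypothetical; the real loop lives in _load_pretrained_model_low_mem):

    import torch

    # A sharded checkpoint: each shard's state_dict holds only a subset of the
    # model's keys, so blindly indexing state_dict[k] for every expected key
    # raises KeyError for keys that live in another shard.
    loaded_state_dict_keys = ["embed.weight", "head.weight"]  # all expected keys
    state_dict = {"embed.weight": torch.zeros(4, 4)}  # this shard has only one

    for k in loaded_state_dict_keys:
        if k in state_dict:  # the fix: only load keys present in this shard
            print(f"loading {k}")
        # without the check, state_dict[k] would raise KeyError on "head.weight"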

TODO

  • tests

@FrancescoSaverioZuppichini
Contributor Author

I am wondering what the correct place is to add a test for this function.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 7, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger left a comment

Thanks for tackling this; the fix could be a tiny bit better, I believe.

@stas00 It looks like the whole low_cpu_mem_usage path is not tested at present? Maybe we can take care of tests in a separate PR, for both a whole and a sharded checkpoint, so this can be merged fast for the RegNet PR?

    if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
        new_val = torch.nn.Parameter(new_val)
    setattr(submodule, param_name, new_val)
    if k in state_dict:
Collaborator

This test should go above on line 2165 with a continue if it's not True, to avoid looking for the param when we don't need it.
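
That is, something like this shape (a sketch only; find_submodule_and_param_name and the surrounding loop are the ones from the function under review, with the rest of the body elided):

    for k in loaded_state_dict_keys:
        if k not in state_dict:
            continue  # bail out before any submodule lookup
        submodule, param_name = find_submodule_and_param_name(model, k)
        ...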

Contributor Author

Updated. The only difference from your comment is that setattr(submodule, param_name, new_val) comes after the check for the key.

Collaborator

There is nothing on line 2165; are you sure you pushed your update? The goal is to avoid spending any time in this block (starting at submodule, param_name = find_submodule_and_param_name(model, k)) when there is no need to.

Contributor Author

@FrancescoSaverioZuppichini Apr 7, 2022

Apologies, updated. There is no need for an ugly continue when you can do everything with a positive conditional flow.
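
For reference, a sketch of the positive-conditional shape being described (helper and variable names as they appear earlier in this thread; dtype handling elided):

    for k in loaded_state_dict_keys:
        if k in state_dict:  # positive check instead of a continue
            submodule, param_name = find_submodule_and_param_name(model, k)
            if submodule is not None:
                new_val = state_dict[k]
                if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
                    new_val = torch.nn.Parameter(new_val)
                setattr(submodule, param_name, new_val)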

Contributor

Prefilter?

    keys_to_load = [k for k in loaded_state_dict_keys if k in state_dict]

Contributor

@stas00 Apr 7, 2022

It won't be the same if loaded_state_dict_keys doesn't include all state_dict keys. I'm pretty sure it does right now, but that may change. Note this warning:

Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.

It was a quick hack to enable an urgent use case, so it needs to be completed to provide full support, in which case not all keys from state_dict might be loaded.

Contributor

I only suggested the comprehension way as another way to avoid too much conditional nesting.

continue is there for this exact reason, and the comprehension is a functional programming tool.

Contributor Author

@FrancescoSaverioZuppichini Apr 7, 2022

You will have to put your continue inside an if statement. For me it is the same; feel free to suggest the change that fits your coding style preference and I will happily make it. But let's avoid unneeded nitpicking.

Contributor

I suggested a simple alternative to deep conditional nesting here: #16643 (comment)

But I'm fine with the code the way it is now as well.

Contributor Author

Sure, what I meant is that prefiltering is the same as just iterating over the loaded state_dict keys, which is the cleanest solution.

@stas00
Contributor

stas00 commented Apr 7, 2022

Your plan works for me, Sylvain. I will work on the low-mem test today, then.

Collaborator

@sgugger left a comment

Thanks!

Contributor

@stas00 left a comment

LGTM, thank you for fixing this bug, @FrancescoSaverioZuppichini

@FrancescoSaverioZuppichini deleted the _load_pretrained_model_low_mem branch April 7, 2022 18:48