-
Notifications
You must be signed in to change notification settings - Fork 30.9k
Updated _load_pretrained_model_low_mem to check if keys are in the state_dict #16643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updated _load_pretrained_model_low_mem to check if keys are in the state_dict #16643
Conversation
I am wondering what is the correct place to add a test for this function |
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for tackling this, the fix could be a tiny bit better I believe.
@stas00 It looks like the whole low_cpu_mem_usage
is not tested at present? Maybe we can take care of tests in a separate PR for both a whole and a sharded checkpoint, so this can be merged fast for the RegNet PR?
src/transformers/modeling_utils.py
Outdated
if isinstance(getattr(submodule, param_name), torch.nn.Parameter): | ||
new_val = torch.nn.Parameter(new_val) | ||
setattr(submodule, param_name, new_val) | ||
if k in state_dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test should go above on line 2165 with a continue
if it's not True
, to avoid looking for the param when we don't need it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated. the only difference to your comment is setattr(submodule, param_name, new_val)
is after the check for the key
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is nothing on line 2165, are you sure you pushed your update? The goal is to avoid spending any time in this block (starting at submodule, param_name = find_submodule_and_param_name(model, k)
) when there is no need to.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies, updated. No need for ugly continue
when you can do everything with a positive conditional flow
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefilter?
keys_to_load = [k for k in loaded_state_dict_keys if k in state_dict]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it won't be the same if loaded_state_dict_keys
doesn't include all state_dict keys. I'm pretty sure it is right now, but it may change. Note this warning:
transformers/src/transformers/modeling_utils.py
Line 2121 in 10131af
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. |
it was a quick hack to enable an urgent use so it needs to be completed to do a full support, in which case not all keys from state_dict might be loaded.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only suggested the comprehension way as another way to avoid too much conditional nesting.
continue
is there for this exact reason and a functional programming tool
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You will have to put your continue
inside an if
statement. For me is the same, feel free to suggest the change that fits your coding style preference and I will happily change it. But, let's avoid unneeded nitpicking
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggested a simple alternative to deep conditional nesting here: #16643 (comment)
But I'm fine with the code the way it is now as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, what I meant is that prefiltering is the same as just iterating the loaded state_dict keys, that is the cleanest solution
Your plan works for me, Sylvain. I will work on the low mem test then today. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you for fixing this bug, @FrancescoSaverioZuppichini
What does this PR do?
This PR checks if any key is in the
state_dict
before attempting to load it. If we have multiple checkpoints, not all keys are in every checkpoint.TODO