-
Notifications
You must be signed in to change notification settings - Fork 30.9k
[modeling utils] revamp from_pretrained(..., low_cpu_mem_usage=True)
+ tests
#16657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
tests/test_modeling_common.py
Outdated
import threading # noqa | ||
|
||
|
||
class CPUMemoryTracker: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should I put CPUMemoryTracker in testing_utils.py before I polish this up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think my question is moot.
Unfortunately, I think I have to discard the measuring test (leaving the functional one in place), since measuring cpu memory is super fickle - the test works well on my desktop but fails on CI.
I tried another version with tracemalloc
but it doesn't work well either.
If re-run the same tests in a loop I get different numbers due to memory being cached - gc.collect()
doesn't seem to help.
I may try some more tomorrow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood. It looked like a great tool though! Too bad CPU usage can't be measured without being so finnicky.
The documentation is not available anymore as the PR was closed or merged. |
Hmm, so trying to write a test that shows the memory saving proved to be a puzzle. Getting inconsistent results between my desktop and the CI. That was using I think I may try The results are very peculiar:
but at least that explains why my memory tracking wasn't showing the saving consistently since I was using a 0.5GB model for the test. So what I'm doing is:
update: The culprit proved be that my original low_cpu_mem code was not able to handle models with a custom prefix in its keys like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for adding those tests!
tests/test_modeling_common.py
Outdated
import threading # noqa | ||
|
||
|
||
class CPUMemoryTracker: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood. It looked like a great tool though! Too bad CPU usage can't be measured without being so finnicky.
src/transformers/modeling_utils.py
Outdated
return submodule, split_key[0] | ||
|
||
|
||
def move_model_to_meta(model, loaded_state_dict_keys, start_prefix): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original _load_pretrained_model_low_mem
hack got split into 2 functions, one that moves the model to meta and another that replaces the specific keys on meta to loaded state_dict keys.
that way I was able to integrate this functionality into the normal complex code of checking the keys and everything else.
from_pretrained(..., low_cpu_mem_usage)
+ tests
from_pretrained(..., low_cpu_mem_usage)
+ testsfrom_pretrained(..., low_cpu_mem_usage=True)
+ tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR has moved in a totally different direction as the original intent, touching the code method of Transformers in a way that is hard to read in a git diff
. Although the changes are welcome, I'm not sure we can catch any regression in this highly sensible code in the current format.
Could we temporarily revert the refactor and first merge this PR with the test. Then have a PR that refactors the missing key part in a function as you did without code changes, and finally do the code changes in a third PR?
The quality test will not work as the original implementation doesn't work with bert or any other model with its custom I totally hear you about the complexity and that the PR is difficult to review in several places. So I propose this plan:
does that sound OK? |
That sounds right, thanks for understanding! |
Step 1 is ready: #16706 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Since it's core, would need to have @LysandreJik and @patrickvonplaten look at this as well, to make sure we don't break anything.
return retrieved_modules | ||
|
||
@staticmethod | ||
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this method is called in src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
- might be nice to change it the standard one now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch, Patrick
It's all modular now, so if you agree we can add a convenience wrapper:
@staticmethod
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
Before it gets called we do:
1. save which state_dict keys we have
2. drop state_dict before model is created, since the latter takes 1x model size memory
Here then we continue:
3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
4. load state_dict 2nd time
5. replace the params/buffers from the state_dict
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
"""
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file)
error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix)
return error_msgs
which restores the original function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and if so, how can I test src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went ahead and added it, so just need to test that conversion script once I know how.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, maybe it's a bit overkill to test the script since the model is huge and it's just a conversion script which are not tested anyways 😅 I'd be fine with just changing the function and "trusting" that it works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't test conversion scripts. (and the conversion script shouldn'tuse a private method from modeling_utils, missed that in the review...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It probably indicates a need for a low memory usage model update from state_dict functionality. Perhaps once it's exercised some more we can make it a public util function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for working on this! Super useful feature.
Left two nits, but I'd also be ok with merging this anyways as there are not too important
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this looks great to me. Thanks for refactoring this, @stas00, the new method-based approach is very clean.
…` + tests (huggingface#16657) * add low_cpu_mem_usage tests * wip: revamping * wip * install /usr/bin/time * wip * cleanup * cleanup * cleanup * cleanup * cleanup * fix assert * put the wrapper back * cleanup; switch to bert-base-cased * Trigger CI * Trigger CI
The initial
from_pretrained(..., low_cpu_mem_usage=True)
implementation was a quick hack to enable loading gptj models on low CPU memory setups. It didn't work with all models.This PR takes it one step further. It revamps the implementation to support many features it wasn't supporting by revamping the implementation and delegating all the work to the normal
from_pretrained
code path except the final step ofstate_dict
=> model param overwrite.This PR:
low_cpu_mem_usage=True
from_pretrained(mname, low_cpu_mem_usage=True)
works with sharded and non-sharded checkpointlow_cpu_mem_usage=True
uses less memory.The low cpu memory usage code path is still not 100% complete feature-wise, but it's getting there. Though I'm contemplating a different approach to solving the issue of low cpu memory. That is by introducing several new
from_pretrained
args that should allow loading the model and/orstate_dict
directly on GPU for single GPU or DDP. But that's for another PR.@sgugger, @LysandreJik, @patrickvonplaten