Make cache_config not mandatory #40316
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Converting to draft, I think there are other places in the code that need the same treatment.
I just noticed that the static cache config has been removed starting from #39106, and apparently we no longer need to pass anything to instantiate a Cache class except max_cache_len.
@gante are we going to get rid of cache config totally? In that case we can update export to require *kwargs for early initialization.
@zucchini-nlp cache config as a plain dict is still supported, as a means to fully parameterize the desired cache. However, due to all recent changes, I would expect bugs in its utilization 🤗
Considering
I think we can indeed remove the warnings; the defaults you gave are sound IMO.
Let's do the same in
+1, I'd expect it to fail as well, but miraculously the tests aren't throwing errors with StaticCache.
Took care of the comments above 🤗 |
Are we able to update the failing tests instead and specify a cache_config on the generation_config? I think it's good to explicitly specify max_cache_len, since it might be hard to debug issues later on if the user is unaware that there is a default value of 4096.
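If the tests were updated instead, the change could look roughly like the sketch below. This assumes the exportable path reads batch_size and max_cache_len from generation_config.cache_config (as in the diff later in this thread); the model id and values are purely illustrative.

```python
from transformers import AutoModelForCausalLM

# Illustrative only: make the cache sizing explicit in the test instead of
# relying on a hidden 4096 default.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model.generation_config.cache_config = {
    "batch_size": 1,        # export is traced for a single sequence here
    "max_cache_len": 1024,  # chosen for the test; application-specific in general
}
```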
Force-pushed from 76e1f18 to eb00c25
@jackzhxng if the
Force-pushed from cb7a379 to 1b8ba14
Using defaults LGTM as an easy-to-set-up export for beginner users. We can add documentation on how advanced users can configure the model cache and change the default values.
@jackzhxng does that sound good?
Force-pushed from 1b8ba14 to 1e6f647
LGTM as well, but let's wait for feedback from @jackzhxng before merging.
@jackzhxng, why don't you think the defaults are a good idea? IMO, since the exportable classes only take the model as an argument, it's a bit awkward to force the user to first set the cache_config in the generation_config of the model itself before exporting; I do like the defaults as well. Any particular reason you were against them?
cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
batch_size = cache_config.get("batch_size", 1)
max_cache_len = cache_config.get("max_cache_len", 4096)
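For context, the behaviour originally described in this PR (fall back to the old defaults but emit a warning so the values are easy to find) could be sketched as follows; the helper name and warning wording are made up for illustration.

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)


def _resolve_cache_config(generation_config):
    """Hypothetical helper: keep the legacy defaults, but make the fallback visible."""
    cache_config = generation_config.cache_config or {}
    if "batch_size" not in cache_config:
        logger.warning("`cache_config` has no `batch_size`; defaulting to 1.")
    if "max_cache_len" not in cache_config:
        logger.warning("`cache_config` has no `max_cache_len`; defaulting to 4096.")
    return cache_config.get("batch_size", 1), cache_config.get("max_cache_len", 4096)
```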
If we default to something, should we default to model.config.max_position_embeddings instead? I would prefer no default at all, to be honest: the maximum length is often highly application-specific, and a bad number is quite impactful (incorrect model behaviour and/or wasted memory).
For instance, 4096 is too large for some models (e.g., they will start generating gibberish after 2k tokens) and too small for some use cases (like processing documents).
model.config.max_position_embeddings is wayyyy too big for some models IMO, we'll waste so much memory we could OOM on cache init.
Many models (even super small ones) have a max_position_embeddings of more than 100k, e.g. https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
I agree, and that's why I would prefer to have no defaults 😢 I don't think there is a good value that works in 99% of cases.
A missing value with a good exception is trivial to fix. A bad default is a pain to find.
No default == crash is a no-no IMO 😉
We can check if the model has an EOT token -> it will stop generating by itself; if not, we use a good default based on the number of parameters, something simple like 7B -> 512 (since it's fast), 14B -> 256, etc., just so we don't spend too much time but still generate a bit!
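A rough illustration of that heuristic; the size buckets and helper name are entirely hypothetical, only sketching the "smaller default for bigger models" idea.

```python
def _default_max_cache_len(model) -> int:
    # Hypothetical buckets: keep export quick by shrinking the default cache
    # as the model gets bigger.
    num_params = model.num_parameters()  # available on any PreTrainedModel
    if num_params >= 14e9:
        return 256
    if num_params >= 7e9:
        return 512
    return 1024
```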
> If we default to something, should we default to model.config.max_position_embeddings instead? I would prefer no default at all, to be honest, the maximum length is often highly application-specific and a bad number is quite impactful (incorrect model behaviour and/or wasted memory)

To add onto this, not only is it quite impactful, it's also very tricky to debug the further down you go in the stack, especially when running a model in ExecuTorch. As @gante said, max_seq_len is one of, if not the, most important export variables, so I think it's good to clear up confusion about this value early on by erroring out instead of setting an arbitrary default value that is hard to find.
Yep, we need to either make the default value very, very easy to find, or error out. Weighing both in both contexts.
Alright, from this discussion it looks like more people are pro-erroring-out. WDYT about making them arguments of the exportable class, though, to simplify? Instead of having to set model.generation_config.cache_config, which is a bit awkward.
Following this discussion, I removed the defaults and added batch_size, max_cache_len, and device as args. In order to not break the current API, if those arguments are not passed, we check the cache_config for corresponding values, and error out if they are not found.
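Conceptually, the resolution order described here (explicit argument, then cache_config, then an error) can be sketched like this; the helper itself is illustrative, only the argument names come from the PR.

```python
def _resolve(name, explicit_value, cache_config):
    """Prefer the explicit constructor argument, then cache_config, otherwise error out."""
    if explicit_value is not None:
        return explicit_value
    if name in cache_config:
        return cache_config[name]
    raise ValueError(
        f"`{name}` was not passed to the exportable class and is not set in "
        "`generation_config.cache_config`; please provide it explicitly."
    )

# Usage inside the exportable class __init__ (sketch):
# batch_size = _resolve("batch_size", batch_size, cache_config)
# max_cache_len = _resolve("max_cache_len", max_cache_len, cache_config)
```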
Force-pushed from 1e6f647 to d5d2f66
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma, gemma2, gemma3, qwen2
Perfect thanks! I think this version makes everyone happy!
With recent changes introduced in #39836, the tests related to executorch.py, such as test_export_text_only_with_hybrid_cache, will fail if the model generation_config has no cache_config or if any of the required attributes is missing from cache_config. This is the case for models such as https://huggingface.co/google/gemma-3-4b-it, which is already used in the gemma3 model tests.

This PR aims to remove the requirement of cache_config by restoring the old default values, albeit while throwing a warning so that users are aware these attributes are missing from the config. This will ensure backwards compatibility for old models without having to edit their hub config.

cc @jackzhxng @zucchini-nlp as this relates to #39836