Support input_embeds in torch exportable decoders #39836
Conversation
Force-pushed from 49e7e97 to 89e912a
Force-pushed from 89e912a to 79e095a
Force-pushed from 51c518f to 62da12e
Force-pushed from d3aea48 to 8ea7821
Hey @jackzhxng, I left a few comments below. Not sure this will work with vision multimodal models right now.
```python
if not config:
    config = model.config
if not generation_config:
    generation_config = model.generation_config

if not hasattr(config, "use_cache") or config.use_cache is False:
```
I think we don't need to explicitly pass configs. In the case of multimodal models, the text decoder config is available via `model.config.get_text_config()`, which returns the text decoder config for any type of model. And the needed generation config is usually the model's own generation config, not the LM's generation config.

So we can do:

```python
text_config = model.config.get_text_config()
gen_config = model.generation_config

# NOTE: model.language_model will not have an lm_head for all vision multimodal models,
# so we need to support exporting the whole model, but without multimodal inputs
TorchExportableModuleForDecoderOnlyLM(model=voxtral)
```

```python
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
position_ids = cache_position.unsqueeze(0)
```
Any reason to remove `position_ids`?
Seemed unnecessary since models like Llama already do this in the forward:

```python
position_ids = cache_position.unsqueeze(0)
```
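For reference, a minimal sketch of that behavior, with a hypothetical helper name; it mirrors the pattern used in decoders such as `LlamaModel` when `position_ids` is not provided, but it is not copied from the modeling code:

```python
import torch

def derive_position_ids(cache_position: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: reuse cache_position with a batch dimension added,
    # which is what many decoder forwards do when position_ids is None.
    return cache_position.unsqueeze(0)

# Example: positions for a 3-token prefill starting at cache slot 0.
print(derive_position_ids(torch.arange(3)))  # tensor([[0, 1, 2]])
```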
Ah right! As long as it doesn't break export, I'm fine. Wanted to make sure it wasn't deleted accidentally.
Jack, are you sure all models do this internally? I remember position ids are used for cache updates, and I'm not sure all models do this. Let's make sure to verify this.
Hi @zucchini-nlp, thanks for the review! We are only keeping the decoder portion of multimodal in transformers right now; the rest of the exportable modules stay in Optimum ET (huggingface/optimum-executorch#111) so we can iterate more quickly, since there is a lot of work going on around multimodal at the moment.
Force-pushed from a2c29fb to 8b06186
Force-pushed from 8b06186 to 14610ed
Thanks for iterating, overall LGTM! I would like someone from Optimum to review the PR as well before merging.
```python
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
```
Not sure if `TorchExportableModuleWithHybridCache` is used commonly; we can do a small deprecation cycle.
Thanks for the review! I believe it's just used for Gemma at the moment
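For illustration only, a small deprecation cycle like the one mentioned above could look roughly like the sketch below. The class body is a simplified stand-in and the warning text is an assumption, not the PR's actual code:

```python
import warnings

class TorchExportableModuleWithHybridCache:
    # Simplified stand-in: only the deprecation shim is sketched, not the real export logic.
    def __init__(self, model, max_batch_size=None, max_cache_len=None):
        if max_batch_size is not None or max_cache_len is not None:
            warnings.warn(
                "`max_batch_size` and `max_cache_len` are deprecated; cache sizing is now "
                "derived from the model's generation config.",
                FutureWarning,
            )
        self.model = model
```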
Force-pushed from 7747302 to 540e187
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma, gemma2, gemma3, llama, olmo, phi3, qwen2, qwen3
```python
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
```
Why remove these?
Ok, I guess you are relying on the generation config.
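For illustration, deriving the cache sizing from the generation config might look roughly like this. The `cache_config` attribute and its field names are assumptions for the sketch, not necessarily what the PR implements:

```python
def cache_dims_from_generation_config(generation_config, default_batch_size=1, default_cache_len=4096):
    # Hypothetical helper; `cache_config` and its keys are assumptions for this sketch.
    cache_config = getattr(generation_config, "cache_config", None) or {}
    if not isinstance(cache_config, dict):
        cache_config = getattr(cache_config, "to_dict", lambda: vars(cache_config))()
    max_batch_size = cache_config.get("batch_size", default_batch_size)
    max_cache_len = cache_config.get("max_cache_len", default_cache_len)
    return max_batch_size, max_cache_len
```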
Looks good to me as well. Can we merge?
@zucchini-nlp am I able to merge this myself or do I have to wait for someone from HF to merge it?
Okay, let's merge it.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The export tests are failing for several models after this PR. @jackzhxng will you have time to take a look? 👀
Yup, let me fix that.
@zucchini-nlp can I confirm that the error you are seeing is this? Also, is there a way for me to see the status / logs from these periodic slow tests?
Yep, exactly. I am attaching the job links for all failed tests in case it helps. These tests are run every day and we usually get pinged internally by a bot when anything new starts failing. Not sure if we can add you there. Running all possibly affected tests before merging a PR is fine imo, and if anything fails on export tests even after that, then we will tag you. Same as for this PR :)
@zucchini-nlp fix here! #40261
What does this PR do?

Allows specifying `inputs_embeds` in `TorchExportableModule`s in order to support export of multimodal models' text decoders. Adds `config` and `generation_config` to the constructors to support multimodal models, since the `TorchExportableModule` will wrap the nested text decoder model, which doesn't have its config and generation config as attributes. E.g. for exporting Voxtral's text decoder we need to:
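A minimal sketch of that usage follows, assuming the Voxtral classes available in recent transformers releases; the checkpoint name, the `.language_model` attribute, and the exact constructor call are assumptions for illustration, not code copied from the PR:

```python
from transformers import VoxtralForConditionalGeneration
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Assumed checkpoint name for illustration.
voxtral = VoxtralForConditionalGeneration.from_pretrained("mistralai/Voxtral-Mini-3B-2507")

# The nested text decoder does not carry its own config/generation_config,
# so they are passed in explicitly (the new constructor arguments from this PR).
exportable = TorchExportableModuleForDecoderOnlyLM(
    voxtral.language_model,
    config=voxtral.config.get_text_config(),
    generation_config=voxtral.generation_config,
)
# The exported module can then be fed `inputs_embeds` produced by the multimodal front-end.
```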
@echarlaix @michaelbenayoun @zucchini-nlp