Support input_embeds in torch exportable decoders by jackzhxng · Pull Request #39836 · huggingface/transformers · GitHub

Conversation

@jackzhxng
Contributor

@jackzhxng jackzhxng commented Aug 1, 2025

What does this PR do?

Allows specifying inputs_embeds in the TorchExportableModule classes in order to support exporting the text decoders of multimodal models.

Adds config and generation_config arguments to the constructors to support multimodal models, since TorchExportableModule wraps the nested text decoder model, which doesn't have the relevant config and generation config as attributes.

e.g. to export Voxtral's text decoder we need to do:

voxtral = AutoModel.from_pretrained( ... )
exportable_module = TorchExportableModuleForDecoderOnlyLM(
    model=voxtral.language_model,
    config=voxtral.config.text_config,
    generation_config=voxtral.generation_config,
)
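
Once wrapped, the module can be exported with embeddings instead of token ids. A minimal sketch of what that could look like with torch.export (the forward kwargs, shapes, and strict=False are illustrative assumptions, not necessarily the final API):

import torch

# Hedged sketch: export the wrapped text decoder on example inputs_embeds.
# `exportable_module` is the wrapper built above; shapes are placeholders.
hidden_size = voxtral.config.text_config.hidden_size
example_embeds = torch.zeros(1, 3, hidden_size)      # (batch, seq_len, hidden)
cache_position = torch.arange(3, dtype=torch.long)   # positions of the new tokens

exported_program = torch.export.export(
    exportable_module,
    args=(),
    kwargs={"inputs_embeds": example_embeds, "cache_position": cache_position},
    strict=False,
)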

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@echarlaix @michaelbenayoun @zucchini-nlp

@jackzhxng jackzhxng force-pushed the jz/multimodal-decoder branch 2 times, most recently from 49e7e97 to 89e912a Compare August 1, 2025 04:29
@jackzhxng jackzhxng force-pushed the jz/multimodal-decoder branch from 89e912a to 79e095a Compare August 1, 2025 04:31
@jackzhxng jackzhxng force-pushed the jz/multimodal-decoder branch from 51c518f to 62da12e Compare August 1, 2025 05:39
@jackzhxng jackzhxng changed the title Support input_embeds in torch exportable decoders [WIP] Support input_embeds in torch exportable decoders Aug 1, 2025
@jackzhxng jackzhxng marked this pull request as ready for review August 1, 2025 05:45
@github-actions github-actions bot requested review from MekkCyber and SunMarc August 1, 2025 05:45
@jackzhxng jackzhxng changed the title [WIP] Support input_embeds in torch exportable decoders Support input_embeds in torch exportable decoders Aug 1, 2025
Member

@zucchini-nlp zucchini-nlp left a comment

Hey @jackzhxng, I left a few comments below. Not sure this will work with vision multimodal models right now.

Comment on lines 59 to 64
if not config:
    config = model.config
if not generation_config:
    generation_config = model.generation_config

if not hasattr(config, "use_cache") or config.use_cache is False:
Member

I think we don't need to explicitly pass configs. In the case of multimodals, the text decoder config is available via model.config.get_text_config(), which returns the text decoder config for any type of model. And the needed generation config is usually the model's own generation config, not the LM's generation config.

So we can do

text_config = voxtral.config.get_text_config()
gen_config = voxtral.generation_config
# NOTE: model.language_model will not have an lm_head for all vision multimodal models,
# so we need to support exporting the whole model, but without multimodal inputs
TorchExportableModuleForDecoderOnlyLM(model=voxtral)

ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
position_ids = cache_position.unsqueeze(0)
Member

any reason to remove position_ids?

Contributor Author

@jackzhxng jackzhxng Aug 4, 2025

Seemed unnecessary since models like Llama already do this in the forward:

position_ids = cache_position.unsqueeze(0)
Should I add it back?
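
For reference, a minimal standalone illustration of that default (not the actual transformers source, just the pattern):

import torch

# When position_ids is not supplied, decoder models such as Llama derive it
# from cache_position inside their forward, so passing it explicitly is redundant.
cache_position = torch.arange(6)              # positions of the current tokens
position_ids = cache_position.unsqueeze(0)    # shape (1, seq_len), broadcast over the batch
print(position_ids.shape)                     # torch.Size([1, 6])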

Member

ah right! As long as it doesn't break export, I'm fine. Wanted to make sure it wasn't deleted accidentally

Contributor

Jack, are you sure all models do this internally? I remember position ids are used for cache updates, and I'm not sure all models do this. Let's make sure to verify this.

@jackzhxng
Contributor Author

Hi @zucchini-nlp, thanks for the review! We are only keeping the decoder portion of multimodal in transformers right now; the rest of the exportable modules we are keeping in Optimum ExecuTorch (huggingface/optimum-executorch#111) so we can iterate more quickly, since there is a lot of work going on around multimodal at the moment.

@jackzhxng jackzhxng force-pushed the jz/multimodal-decoder branch 2 times, most recently from a2c29fb to 8b06186 Compare August 5, 2025 00:02
@jackzhxng jackzhxng force-pushed the jz/multimodal-decoder branch from 8b06186 to 14610ed Compare August 5, 2025 00:05
@jackzhxng jackzhxng requested a review from zucchini-nlp August 5, 2025 00:05
Member

@zucchini-nlp zucchini-nlp left a comment

Thanks for iterating, overall LGTM! I would like someone from Optimum to review the PR as well before merging.

Comment on lines -409 to -588
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
Member

Not sure if TorchExportableModuleWithHybridCache is commonly used; we can make a small deprecation cycle.

Contributor Author

Thanks for the review! I believe it's just used for Gemma at the moment

@jackzhxng jackzhxng force-pushed the jz/multimodal-decoder branch from 7747302 to 540e187 Compare August 5, 2025 18:40
@github-actions
Contributor

github-actions bot commented Aug 5, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma, gemma2, gemma3, llama, olmo, phi3, qwen2, qwen3

self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
Contributor

why remove these?

Contributor

Ok, I guess you are relying on the generation config.
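
For context, a rough sketch of what deriving these values from the generation config could look like (the cache_config attribute layout here is an assumption for illustration, not necessarily what this PR does):

from transformers import GenerationConfig

# Hypothetical sketch only: read cache sizing from the generation config rather
# than from explicit constructor arguments. Assumes cache_config is a plain dict.
generation_config = GenerationConfig()  # placeholder; normally model.generation_config
cache_config = getattr(generation_config, "cache_config", None) or {}
max_batch_size = cache_config.get("batch_size", 1)
max_cache_len = cache_config.get("max_cache_len", 4096)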

Contributor

@kimishpatel kimishpatel left a comment

Looks good to me as well. Can we merge?

@jackzhxng
Contributor Author

@zucchini-nlp am I able to merge this myself or do I have to wait for someone from HF to merge it?

@zucchini-nlp
Member

Okay, let's merge it.

@zucchini-nlp zucchini-nlp enabled auto-merge (squash) August 7, 2025 08:39
@zucchini-nlp zucchini-nlp merged commit 6121e9e into huggingface:main Aug 7, 2025
19 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp
Member

The export tests are failing for several models after this PR. @jackzhxng will you have time to take a look? 👀

`pytest -k test_export tests/models/{below_models}`

    "gemma": 1,
    "gemma2": 1,
    "gemma3": 1,
    "llama": 1,
    "olmo": 1,
    "phi3": 1,
    "qwen2": 1,
    "qwen3": 1,
    "smolvlm": 1

@jackzhxng
Contributor Author

Yup, let me fix that

@jackzhxng
Contributor Author

jackzhxng commented Aug 13, 2025

@zucchini-nlp can I confirm that the error you are seeing is this?

tests/models/qwen3/test_modeling_qwen3.py:301:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/integrations/executorch.py:609: in generate
    result = exported_program.module().forward(
<eval_with_key>.36:398: in forward
    view_1 = torch.ops.aten.view.default(linear, [1, 6, -1, 128]);  linear = None
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <OpOverload(op='aten.view', overload='default')>, args = (tensor([[[ 0.1143,  0.1040,  0.0566,  ..., -0.0137, -0.0889,  0.2158]]],
       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>), [1, 6, -1, 128]), kwargs = {}

    def __call__(self, /, *args, **kwargs):
>       return self._op(*args, **kwargs)
E       RuntimeError: shape '[1, 6, -1, 128]' is invalid for input of size 2048
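
For what it's worth, the failing view reproduces in isolation: the tensor reaching the reshape has 1 * 1 * 2048 elements, which cannot be viewed as [1, 6, -1, 128] because 1 * 6 * 128 = 768 does not divide 2048 evenly (a reading of the traceback, not a confirmed root cause):

import torch

# Minimal reproduction of the RuntimeError from the traceback above.
x = torch.zeros(1, 1, 2048)
x.view(1, 6, -1, 128)  # RuntimeError: shape '[1, 6, -1, 128]' is invalid for input of size 2048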

Also is there a way for me to see the status / logs from these periodic slow tests?

@zucchini-nlp
Member

Yep, exactly. I am attaching the job links for all failed tests in case it helps. These tests are run every day and we usually get pinged internally by a bot when anything new starts failing. Not sure if we can add you there.

Running all possibly affected tests before merging a PR is fine IMO, and if anything on the export tests still fails after that, we will tag you. Same as for this PR :)

 "zucchini-nlp": {
        "gemma": {
            "single-gpu": [
                {
                    "test": "tests/models/gemma/test_modeling_gemma.py::GemmaIntegrationTest::test_export_static_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647381350"
                }
            ]
        },
        "gemma2": {
            "single-gpu": [
                {
                    "test": "tests/models/gemma2/test_modeling_gemma2.py::Gemma2IntegrationTest::test_export_static_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647381360"
                }
            ]
        },
        "gemma3": {
            "single-gpu": [
                {
                    "test": "tests/models/gemma3/test_modeling_gemma3.py::Gemma3IntegrationTest::test_export_text_only_with_hybrid_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647381377"
                }
            ]
        },
        "llama": {
            "single-gpu": [
                {
                    "test": "tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_export_static_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647382292"
                }
            ]
        },
        "olmo": {
            "single-gpu": [
                {
                    "test": "tests/models/olmo/test_modeling_olmo.py::OlmoIntegrationTest::test_export_static_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647379951"
                }
            ]
        },
        "phi3": {
            "single-gpu": [
                {
                    "test": "tests/models/phi3/test_modeling_phi3.py::Phi3IntegrationTest::test_export_static_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647380312"
                }
            ]
        },
        "qwen2": {
            "single-gpu": [
                {
                    "test": "tests/models/qwen2/test_modeling_qwen2.py::Qwen2IntegrationTest::test_export_static_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647380556"
                }
            ]
        },
        "qwen3": {
            "single-gpu": [
                {
                    "test": "tests/models/qwen3/test_modeling_qwen3.py::Qwen3IntegrationTest::test_export_static_cache",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647380763"
                }
            ]
        },
        "smolvlm": {
            "single-gpu": [
                {
                    "test": "tests/models/smolvlm/test_modeling_smolvlm.py::SmolVLMForConditionalGenerationIntegrationTest::test_export_smolvlm_text_decoder",
                    "commit": "6121e9e46c4fc4e5c91d9f927aef5490691850cf",
                    "pr_number": 39836,
                    "author": "jackzhxng",
                    "merged_by": "zucchini-nlp",
                    "job_link": "https://github.com/huggingface/transformers/actions/runs/16820781536/job/47647381225"
                }
            ]
        }
    },

@jackzhxng
Contributor Author

@zucchini-nlp fix here! #40261
