propagate "attention_mask" dtype for "use_past" in OnnxConfig.generate_dummy_inputs by arampacha · Pull Request #17105 · huggingface/transformers

Conversation

@arampacha
Contributor

What does this PR do?

Fixes #16538

The mask_dtype is propagated to the torch.ones() call that produces the "attention_mask" for past_key_values in generate_dummy_inputs. This ensures the input datatype expected by the ONNX model matches the default "attention_mask" dtype.
The fix is applied to all configs where this pattern was used (see the sketch below).
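
For illustration, a minimal sketch of the pattern being fixed (shapes and the `common_inputs` name are illustrative; see the diff for the exact change):

```python
import torch

# Dummy shapes for illustration only.
batch, seq_len, past_key_values_length = 2, 8, 3
common_inputs = {"attention_mask": torch.ones(batch, seq_len, dtype=torch.int64)}

# Before the fix, torch.ones(...) defaulted to float32, so concatenating it
# onto the int64 mask promoted the whole mask to float32, and the export then
# declared a float attention_mask input instead of the expected int64.
# The fix reads the dtype from the existing mask and propagates it:
mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
    [
        common_inputs["attention_mask"],
        torch.ones(batch, past_key_values_length, dtype=mask_dtype),
    ],
    dim=1,
)
assert common_inputs["attention_mask"].dtype == torch.int64
```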

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

The existing tests use the *OnnxConfig.generate_dummy_inputs method to produce the inputs passed to session.run(...). For this reason the issue was not caught during testing: the same inputs are used for both export and testing. I'm not sure whether specific tests for the input datatypes are required.
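
If such a dtype check were added, it could look roughly like this (a hypothetical sketch, not an existing test; the model path is an assumption):

```python
import onnxruntime

# Hypothetical check that the exported graph declares integer inputs.
onnx_model_path = "model.onnx"  # assumed path to a causal-lm-with-past export
session = onnxruntime.InferenceSession(onnx_model_path)
declared = {inp.name: inp.type for inp in session.get_inputs()}

# Both the token ids and the mask should be int64, matching what
# generate_dummy_inputs (and real tokenizers) produce by default.
assert declared["input_ids"] == "tensor(int64)"
assert declared["attention_mask"] == "tensor(int64)"
```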

Here is a notebook for verifying the fix works as expected.

Who can review?

@lewtun

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 5, 2022

The documentation is not available anymore as the PR was closed or merged.

@lewtun lewtun requested a review from michaelbenayoun May 6, 2022 10:42
@lewtun lewtun (Member) left a comment

Thanks for improving all these types for the attention masks @arampacha!

It looks good to me, so gently pinging @patil-suraj and @sgugger for their perspective :)

For context for the reviewers: these changes ensure that the data types of input_ids and attention_mask are the same (i.e. ints) when these models are exported to ONNX.
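
To make the "i.e. ints" concrete, tokenizers return int64 tensors for both fields by default (a quick illustrative check; gpt2 is just an example checkpoint):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
encoded = tokenizer("hello world", return_tensors="pt")

# Both default to int64, so the exported graph must declare the
# attention_mask input as int64 as well.
print(encoded["input_ids"].dtype)       # torch.int64
print(encoded["attention_mask"].dtype)  # torch.int64
```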

@sgugger sgugger (Collaborator) left a comment

Thanks for fixing!

@patil-suraj patil-suraj (Contributor) left a comment

Thanks for the fix!

@sgugger sgugger merged commit 0645b07 into huggingface:main May 11, 2022
@arampacha arampacha deleted the causal-lm-with-past-onnx-config branch May 11, 2022 11:54
ArthurZucker pushed a commit to ArthurZucker/transformers that referenced this pull request May 12, 2022
…e_dummy_inputs (huggingface#17105)

* propagate attention_mask dtype

* fixup&style
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…e_dummy_inputs (huggingface#17105)

* propagate attention_mask dtype

* fixup&style


Development

Successfully merging this pull request may close these issues.

ONNX causal-lm-with-past conversion: attention_mask dtype changed
