set eos_token_id to None to generate until max length by ydshieh · Pull Request #16989 · huggingface/transformers · GitHub


Collaborator

@ydshieh ydshieh commented Apr 28, 2022

What does this PR do?

Update check_encoder_decoder_model_generate so that generation always runs until max length.
Otherwise, this check

self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

might fail.

Remark

In generate(), we have

eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
if eos_token_id is None and hasattr(self.config, "decoder"):
    eos_token_id = self.config.decoder.eos_token_id

So I think the original "Generate until max length" logic in check_encoder_decoder_model_generate should be updated to account for this fallback as well. The case won't really come up in the tests, but in general the top-level config might still have an eos_token_id.
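
To make the fallback concrete, here is a minimal standalone sketch. It uses `types.SimpleNamespace` stand-ins instead of real transformers configs, and `resolve_eos_token_id` is a hypothetical helper name that only mirrors the three lines quoted above; it is not a function in the library, and the token id `2` is an assumed value.

```python
from types import SimpleNamespace


def resolve_eos_token_id(eos_token_id, config):
    """Hypothetical helper mirroring the fallback quoted from generate()."""
    eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
    if eos_token_id is None and hasattr(config, "decoder"):
        eos_token_id = config.decoder.eos_token_id
    return eos_token_id


# Stand-in configs (assumed values, for illustration only).
config = SimpleNamespace(eos_token_id=None, decoder=SimpleNamespace(eos_token_id=2))

# The top-level EOS is None, so the decoder's EOS is picked up.
from_decoder = resolve_eos_token_id(None, config)

# Clearing only the decoder's EOS is not enough if the top-level config has one.
config.decoder.eos_token_id = None
config.eos_token_id = 2
from_top_level = resolve_eos_token_id(None, config)

# Only when both are None is there no EOS token to stop on.
config.eos_token_id = None
no_eos = resolve_eos_token_id(None, config)
```

Under these assumptions, a test that wants to generate until max length would need both the top-level and the decoder `eos_token_id` to resolve to None.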

I have also left the corresponding Flax tests untouched for now.

This PR will fix

FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::Swin2BartModelTest::test_encoder_decoder_model_generate

tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py:280: in check_encoder_decoder_model_generate
    self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
AssertionError: torch.Size([13, 2]) != (13, 20)
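
One plausible reading of the `[13, 2]` shape is that every sequence in the batch emitted EOS on the first generated step, so generation stopped long before max length. The toy loop below sketches that behaviour in plain Python. It is not the transformers implementation: `toy_generate`, `always_eos`, and the token ids are hypothetical names and values chosen for illustration.

```python
def toy_generate(batch_size, max_length, next_token, eos_token_id,
                 start_token_id=0, pad_token_id=1):
    """Toy greedy loop: stop once every sequence has produced EOS."""
    sequences = [[start_token_id] for _ in range(batch_size)]
    finished = [False] * batch_size
    while len(sequences[0]) < max_length:
        for i, seq in enumerate(sequences):
            if finished[i]:
                # Finished sequences are padded so the batch stays rectangular.
                seq.append(pad_token_id)
                continue
            token = next_token(seq)
            seq.append(token)
            if eos_token_id is not None and token == eos_token_id:
                finished[i] = True
        if all(finished):
            break
    return sequences


def always_eos(seq):
    """Stand-in for a tiny random test model that keeps predicting token 2."""
    return 2


# With EOS set to 2, generation stops after one generated token.
early = toy_generate(batch_size=13, max_length=20, next_token=always_eos, eos_token_id=2)
early_shape = (len(early), len(early[0]))

# With EOS disabled, generation runs until max_length.
full = toy_generate(batch_size=13, max_length=20, next_token=always_eos, eos_token_id=None)
full_shape = (len(full), len(full[0]))
```

In this sketch `early_shape` is `(13, 2)` and `full_shape` is `(13, 20)`, matching the two sides of the failing assertion.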

@ydshieh ydshieh changed the title set eos_token_id to None to generate until max length in check_encode… set eos_token_id to None to generate until max length Apr 28, 2022

HuggingFaceDocBuilderDev commented Apr 28, 2022

The documentation is not available anymore as the PR was closed or merged.

Contributor

@patrickvonplaten patrickvonplaten left a comment


Thanks!

Member

@gante gante left a comment


👍

@ydshieh ydshieh merged commit 5af5735 into huggingface:main Apr 28, 2022
@ydshieh ydshieh deleted the fix_check_encoder_decoder_model_generate branch April 28, 2022 17:47
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

4 participants