KEMBAR78
Fix decoding score comparison when using logits processors or warpers by bryant1410 · Pull Request #10638 · huggingface/transformers · GitHub
Skip to content

Conversation

@bryant1410
Copy link
Contributor

@bryant1410 bryant1410 commented Mar 11, 2021

When doing beam search or other decoding search strategies, the logit scores are normalized (with log_softmax) so the comparisons between the beams (hypotheses) are meaningful. However, the logit processors or warpers may change the scores, and thus may not be normalized anymore.

For example, say you have a beam size of 2. During beam search at some point, beam A is better than B (higher score). You use prefix_allowed_tokens_fn, which in turn through a logit processor narrows down the options of the next tokens to only one. Then masks out all tokens with -inf but one. The score vector may look like [-inf, ..., -2.13, ..., -inf]. This is output and now the scores are not normalized anymore. This filter is not applied to B. Now beam search selects B, which actually keeping the hypothesis A meant having the same probability since the normalized vector should have been [-inf, ..., 0, ..., -inf]. In that case, hypothesis A would have been kept (and that's what actually should happen). This erroneous behavior can happen with any logit processor that doesn't normalize its output, which I see it's often the case.

So that's why I moved the log_softmax to after the logit processor/warper application. I also checked if any logit processor needed the normalization for its input. It doesn't seem to be the case (though I'm not 100% sure). They can still individually apply a normalization if they need to. Maybe the documentation could be changed, by the way:

scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

I feel I should tag @patrickvonplaten, @patil-suraj

@bryant1410
Copy link
Contributor Author

The failing test is test_90_generation_from_short_input, which generates "have you ever heard of sam harris? he is an american singer and songwriter. have you heard of him?" instead of "have you ever been to a sam club? it's a great club in the south." or "have you ever heard of sam harris? he's an american singer, songwriter, and actor.".

I honestly don't know what's the expected behavior there, so not sure if it's flaky or not. The weird thing is that this test seems to be greedy search, not beam search.

@bryant1410
Copy link
Contributor Author

Actually, I just looked more closely and the failing test does use beam search (the beam size is specified in the config). This is an example of something that changes since it uses a NoRepeatNGramLogitsProcessor, a MinLengthLogitsProcessor, and a ForcedEOSTokenLogitsProcessor.

@github-actions
Copy link
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@bryant1410
Copy link
Contributor Author

I'm gonna address it, it's been in my mind. Please don't mark it as stale!

@LysandreJik LysandreJik added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Apr 14, 2021
@LysandreJik
Copy link
Member

I've added the WIP label so that the stale bot doesn't close it!

@bryant1410
Copy link
Contributor Author

@patrickvonplaten sorry for the big delay.

I changed the normalization to be a logit warper now. What do you think of it, and its documentation?

Also, what if we set a deprecation for it? And take advantage of some breaking change in the future and make it the default?

@bryant1410
Copy link
Contributor Author

The failing tests are flaky, right?

@patrickvonplaten
Copy link
Contributor

Could we add one tests for the new logits processor as well? :-)

@bryant1410 bryant1410 force-pushed the fix-beam-search-log-softmax branch from 3122f99 to 066b52c Compare February 1, 2022 19:03
@bryant1410 bryant1410 force-pushed the fix-beam-search-log-softmax branch from 35a072a to b9bd8a1 Compare March 28, 2022 20:23
@bryant1410
Copy link
Contributor Author

@patrickvonplaten can you remove the WIP label? This should be done.

Also, the latest time a test failed, it seemed to be flaky. It should be good to go 🚀

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@bryant1410
Copy link
Contributor Author

@patrickvonplaten friendly reminder on this!

@bryant1410
Copy link
Contributor Author

Also, should we add a flag in generate so this logit processor gets added to the list? Such as renormalize_logits.

@patrickvonplaten
Copy link
Contributor

patrickvonplaten commented Apr 6, 2022

PR looks good to go for me - thanks @bryant1410. Yes indeed could you maybe add a flag renormalize_logits to generate()?

@bryant1410
Copy link
Contributor Author

PR looks good to go for me - thanks @bryant1410. Yes indeed could you maybe add a flag renormalize_logits to generate()?

Okay, @patrickvonplaten I did this change.

What do you think about also making renormalize_logits=True in the future? So then adding some deprecation or warning that this value is gonna change? Or that it should be set to False to keep BC?

@bryant1410
Copy link
Contributor Author

Oh, and btw, note I also applied it to the warpers (so it's applied to both the processors and warpers).

@bryant1410
Copy link
Contributor Author

Should the attribute be added to the configs such that the following can be applied?

renormalize_logits if renormalize_logits is not None else self.config.renormalize_logits

@patrickvonplaten
Copy link
Contributor

Should the attribute be added to the configs such that the following can be applied?

renormalize_logits if renormalize_logits is not None else self.config.renormalize_logits

No need for this I think since it's quite a specific logit processor

@patrickvonplaten
Copy link
Contributor

@bryant1410, could you also update RAG's generate method to incorporate you changes? The test currently fails with
TypeError: _get_logits_processor() missing 1 required positional argument: 'renormalize_logits'

It should be easy to adapt here:

pre_processor = self._get_logits_processor(

@bryant1410
Copy link
Contributor Author

@bryant1410, could you also update RAG's generate method to incorporate you changes? The test currently fails with TypeError: _get_logits_processor() missing 1 required positional argument: 'renormalize_logits'

It should be easy to adapt here:

pre_processor = self._get_logits_processor(

Done. What about this?

What do you think about also making renormalize_logits=True in the future? So then adding some deprecation or warning that this value is gonna change? Or that it should be set to False to keep BC?

@bryant1410 bryant1410 changed the title Fix beam search when using logits processors Fix decoding when using logits processors Apr 7, 2022
@bryant1410 bryant1410 changed the title Fix decoding when using logits processors Fix decoding when using logits processors or warpers Apr 7, 2022
@bryant1410 bryant1410 changed the title Fix decoding when using logits processors or warpers Fix score normalization issues when decoding with logits processors or warpers Apr 7, 2022
@bryant1410 bryant1410 changed the title Fix score normalization issues when decoding with logits processors or warpers Fix decoding score comparison when using logits processors or warpers Apr 7, 2022
@patrickvonplaten patrickvonplaten requested a review from gante April 12, 2022 15:10
@patrickvonplaten
Copy link
Contributor

Good for merge for me! Let's see what @gante says

@bryant1410
Copy link
Contributor Author

Good for merge for me! Let's see what @gante says

Okay! What about the comment/idea on making it renormalize_logits=True in the future? So then adding some deprecation or warning that this value is gonna change?

@patrickvonplaten
Copy link
Contributor

Good for merge for me! Let's see what @gante says

Okay! What about the comment/idea on making it renormalize_logits=True in the future? So then adding some deprecation or warning that this value is gonna change?

Don't really think that's possible due to backwards breaking changes tbh

@bryant1410
Copy link
Contributor Author

Don't really think that's possible due to backwards breaking changes tbh

I understand. However, eventually, the breaking change is gonna happen because of some accumulated "debt" that gets big enough, after many different fixes or wanted features. Like it happens in other libraries. It could happen after some major version change (e.g., v5), which it's a great opportunity to change a lot of desired changes that are breaking.

One approach to track this is to deprecate the value and say when it's gonna be changed (e.g., v5). It could be with a warning, some comment in the docstring, or maybe just a doc that tracks down which is gonna be changed. I guess what I'm saying is to add this change to that list (is it worth it, in your opinion?). BTW, do you have in this repo such a list of things that are eventually gonna be changed (maybe implicitly tracked in various comments)?

What are your thoughts? Maybe you think differently?

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rationale is sensible and I'm in favor of the approved changes 👍

To ensure this change stays future-proof, I'd like to discuss an additional change. The new logit processor, when it exists in the list of logit processors to be applied, must be the last one. Should we raise an exception when it isn't? (e.g. it has to be the last one in this list, when it exists) cc @patrickvonplaten

@bryant1410
Copy link
Contributor Author

To ensure this change stays future-proof, I'd like to discuss an additional change. The new logit processor, when it exists in the list of logit processors to be applied, must be the last one. Should we raise an exception when it isn't? (e.g. it has to be the last one in this list, when it exists) cc @patrickvonplaten

Makes sense to me. However, what if the user wants to do something custom, by manually adding this processor logit somewhere? If we add a check and an exception, then the user would face it in this custom scenario. Or maybe it's a bit far-fetched?

@gante
Copy link
Member

gante commented Apr 12, 2022

Makes sense to me. However, what if the user wants to do something custom, by manually adding this processor logit somewhere? If we add a check and an exception, then the user would face it in this custom scenario. Or maybe it's a bit far-fetched?

Uhmm I see. We can go with the low effort, low cost, and low consequence alternative (see the following suggestion)

bryant1410 and others added 2 commits April 12, 2022 17:38
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@gante
Copy link
Member

gante commented Apr 13, 2022

@bryant1410 regarding the renormalize_logits default value, I've added it to a v5 wishlist, to discuss internally when we decide to do the next major change :)

Since there are no other outstanding requests and CI is green, I'm merging the PR 💪

@gante gante merged commit f7196f2 into huggingface:main Apr 13, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…huggingface#10638)

* Normalize using a logits warper

* Add a flag in `generate` to support the logit renormalization

* Add in RAG
@bryant1410 bryant1410 deleted the fix-beam-search-log-softmax branch October 26, 2022 18:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants