[T5 Tokenizer] Model has no fixed position ids - there is no hardcoded max length by patrickvonplaten · Pull Request #16990 · huggingface/transformers · GitHub

Conversation

@patrickvonplaten
Contributor


What does this PR do?

Fixes #16986

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 28, 2022

The documentation is not available anymore as the PR was closed or merged.

)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 512))
self.assertEqual(batch.input_ids.shape, (2, 8001))
Contributor Author


@sgugger note that while this IMO fixes a bug (T5 has no fixed max length), it might break backwards compatibility in some edge cases. T5 is used a lot, but I still think it's better to correct it here.

Collaborator

@sgugger sgugger left a comment


Oh, that's a serious change if a user forgot to set a max_length. I understand it fixes a bug, but still would like @LysandreJik 's take on it as well.
Thanks for the PR in any case!

@patrickvonplaten
Contributor Author

Oh, that's a serious change if a user forgot to set a max_length. I understand it fixes a bug, but still would like @LysandreJik 's take on it as well. Thanks for the PR in any case!

Agree! We should at least put some ❗ mark in this PR stating that this change could lead to unexpected behavior (e.g. OOM) if max_length is not defined.

@LysandreJik
Member

That is definitely a breaking change we want to avoid, IMO. This is likely to break user pipelines with OOM errors or an inconsistent number of tokens generated. I'd advocate against this change, and would push to:

  • Document that while the limit is set to 512, T5 can handle longer lengths, and encourage users to define their own max lengths (see the sketch after this list)
  • Document that this limit will be removed in v5
  • Update the warning just for T5 (see below)
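
For instance, a user-defined max length could simply be passed at encode time, so the hardcoded 512 default is never relied upon (a minimal sketch only; the value 1024 is arbitrary):

from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("t5-base")

# An explicit max_length makes the user's choice the source of truth.
ids = tok("some long document " * 200, max_length=1024, truncation=True).input_ids
print(len(ids))  # at most 1024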
Updating the warning just for T5

You can override this method (which lives in tokenization_utils_base.py) in tokenization_t5.py and tokenization_t5_fast.py:

def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
    """
    Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
    corresponding model

    Args:
        ids (`List[str]`): The ids produced by the tokenization
        max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set)
        verbose (`bool`): Whether or not to print more information and warnings.
    """
    if max_length is None and len(ids) > self.model_max_length and verbose:
        if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
            logger.warning(
                "Token indices sequence length is longer than the specified maximum sequence length "
                f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model "
                "will result in indexing errors"
            )
        self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True

I wouldn't recommend skipping the warning altogether as it still gives important information regarding why the text was eventually truncated or padded. But updating the message makes sense:

    def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
        """
        Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
        corresponding model

        Args:
            ids (`List[str]`): The ids produced by the tokenization
            max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set)
            verbose (`bool`): Whether or not to print more information and warnings.

        """
        if max_length is None and len(ids) > self.model_max_length and verbose:
            if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
                logger.warning(
-                    "Token indices sequence length is longer than the specified maximum sequence length "
-                    f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model "
-                    "will result in indexing errors"
+                    "The T5 model has no maximum length, but a maximum length is still set for backwards compatibility "
+                    "purposes. To take advantage of the full capabilities of the model, we recommend setting a "
+                    "max_length manually."
                )
            self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True

@patrickvonplaten
Contributor Author

patrickvonplaten commented May 2, 2022

Okay, I took some time to think about it - it's really not easy. I agree with @LysandreJik that the previous change (while correct) is too strong, as it might break quite a few pipelines.

To begin with, note that model_max_length or max_length is only relevant if truncation=True is set, so for all other cases this bug is not relevant.
Now, the problem is that by default T5 should not have a set maximum length.
However, it is completely reasonable for people to set their own maximum length. To me this means the following: if a user instantiates the T5 tokenizer with model_max_length or passes max_length when encoding/padding, then these values should always be the true max length values, and in this case the (incorrectly) hard-coded max length values can be discarded.
Only if a user neither passes max_length when encoding/padding nor defines model_max_length at init should we fall back to the (incorrect) hard-coded max length values until v5.
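
Put differently, the intended precedence could be sketched like this (an illustration only; resolve_max_length and hardcoded_default are hypothetical names, not the actual implementation):

from typing import Optional

def resolve_max_length(
    call_max_length: Optional[int],        # max_length passed when encoding/padding
    init_model_max_length: Optional[int],  # model_max_length passed at tokenizer init
    hardcoded_default: int = 512,          # legacy hardcoded value, kept until v5
) -> int:
    # 1. An explicit max_length at call time always wins.
    if call_max_length is not None:
        return call_max_length
    # 2. Otherwise, a model_max_length set at init is the source of truth.
    if init_model_max_length is not None:
        return init_model_max_length
    # 3. Only if neither was set do we fall back to the legacy hardcoded value.
    return hardcoded_default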

In this PR two things are changed; 2.) can be considered a small breaking change, but to me it's really a bug correction.

  1. If the T5 tokenizer is instantiated without a custom model_max_length and one of the identifiers for which model_max_length is hardcoded is used, the following warning appears:
This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.

Previously no warning appeared. Note that this warning appears every time at init; however, it can be disabled as described above, and it's also good to warn the user about upcoming changes this way.

  2. If the T5 tokenizer is instantiated with a model_max_length, this model_max_length always counts, even if it's longer than the hardcoded one. This means the following snippet:
#!/usr/bin/env python3
from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("t5-base", model_max_length=600)

out = tok(100 * "hello there is a", padding="longest", truncation=True).input_ids
print(len(out))

does not throw a warning (since the user defines model_max_length) and prints a length of 600 (not 512). <- this behavior is different from how it was before.
My rationale for changing this is the following:

  • T5's hardcoded model max lengths are wrong; I'm fine with using them if no model_max_length is defined and no max_length is passed
  • But if a user already passes a model_max_length <- then this should be the only source of truth. E.g. in the example above, 600 should be the max length and not 512.

To be crystal clear, 2.) changes the behavior - e.g. run the code snippet before/after the PR - but it's really a bug correction here IMO.
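
For illustration, the warning from 1.) is tied to the legacy default only; passing an explicit model_max_length at init avoids it and makes that value the single source of truth (a sketch of the behavior described above, not new API):

from transformers import T5TokenizerFast

# No model_max_length given: the hardcoded 512 is kept for backwards
# compatibility and the init-time warning from 1.) is emitted once.
tok_legacy = T5TokenizerFast.from_pretrained("t5-base")

# Explicit model_max_length: no warning, and this value (600) is used
# instead of the hardcoded 512 when truncating/padding.
tok = T5TokenizerFast.from_pretrained("t5-base", model_max_length=600)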

Collaborator

@sgugger sgugger left a comment


LGTM, thanks a lot for iterating on this and making it more backward compatible. Your proposed solution looks great!

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Member

@LysandreJik LysandreJik left a comment


Your solution looks good to me. Thanks for working on it @patrickvonplaten, LGTM.

if init_max_model_length is not None and init_max_model_length != max_model_length:
    return init_max_model_length
elif init_max_model_length is None:
    logger.warning(
Member


This could be a warnings.warn(..., FutureWarning) so that it is correctly displayed as a deprecation warning for users.
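
A hedged sketch of what that could look like in place of the logger.warning call (illustration only; the message text is abbreviated from the warning quoted earlier, and max_model_length here stands in for the value used in the surrounding code):

import warnings

max_model_length = 512  # placeholder for the legacy hardcoded value

warnings.warn(
    f"This tokenizer was incorrectly instantiated with a model max length of {max_model_length} "
    "which will be corrected in Transformers v5.",
    FutureWarning,
)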

@patrickvonplaten
Contributor Author

Failure is unrelated

@patrickvonplaten patrickvonplaten merged commit 31616b8 into huggingface:main May 2, 2022
@patrickvonplaten patrickvonplaten deleted the fix_t5_tok_warning branch May 2, 2022 19:27
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request May 3, 2022
[T5 Tokenizer] Model has no fixed position ids - there is no hardcoded max length (huggingface#16990)

* [T5 Tokenizer] Model has no fixed position ids - there is no hardcoded max length

* [T5 Tokenizer] Model has no fixed position ids - there is no hardcoded max length

* correct t5 tokenizer

* correct t5 tokenizer

* fix test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* finish

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
[T5 Tokenizer] Model has no fixed position ids - there is no hardcoded max length (huggingface#16990)

* [T5 Tokenizer] Model has no fixed position ids - there is no hardcoded max length

* [T5 Tokenizer] Model has no fixed position ids - there is no hardcoded max length

* correct t5 tokenizer

* correct t5 tokenizer

* fix test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* finish

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>


Development

Successfully merging this pull request may close these issues.

Warning tells you you will get indexing errors in T5 for going beyond max length

4 participants