Add Doc Tests for Reformer PyTorch by hiromu166 · Pull Request #16565 · huggingface/transformers · GitHub

Conversation

@hiromu166
Contributor

@hiromu166 hiromu166 commented Apr 3, 2022

What does this PR do?

#16292

Fixes the doc tests in modeling_reformer.py for the following methods:

  • ReformerModelWithLMHead.forward
  • ReformerForMaskedLM.forward
  • ReformerForQuestionAnswering.forward
  • ReformerForSequenceClassification.forward

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten
@ydshieh
@patil-suraj

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 3, 2022

The documentation is not available anymore as the PR was closed or merged.

@hiromu166
Contributor Author

I found two errors in the ReformerForMaskedLM.forward example.

  1. config.is_decoder = True by default.
# before
model = ReformerForMaskedLM.from_pretrained("google/reformer-crime-and-punishment")
# AssertionError: If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention.

# after
model = ReformerForMaskedLM.from_pretrained("google/reformer-crime-and-punishment", is_decoder=False)
# It's OK.
  2. tokenizer.mask_token_id is None (a combined fix for both points is sketched after this list).
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
# TypeError: 'bool' object is not subscriptable
# Because tokenizer.mask_token_id is None.
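
A minimal sketch that combines both fixes; the "[MASK]" token string and the resize_token_embeddings call are illustrative assumptions, not part of this PR's final doctest.

from transformers import ReformerForMaskedLM, ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
# fix 2: ReformerTokenizer has no mask token by default, so register one
tokenizer.add_special_tokens({"mask_token": "[MASK]"})

# fix 1: load with is_decoder=False to get bi-directional self-attention
model = ReformerForMaskedLM.from_pretrained("google/reformer-crime-and-punishment", is_decoder=False)
# the newly added special token needs an embedding row
model.resize_token_embeddings(len(tokenizer))

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]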

@hiromu166
Contributor Author

This issue reports the same problems:
#10813

Contributor Author

Running model = ReformerModelWithLMHead.from_pretrained("hf-internal-testing/tiny-random-reformer") gives AssertionError: If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`.

Collaborator

For ReformerModelWithLMHead, we could use the checkpoint google/reformer-crime-and-punishment, and we won't have the issue regarding is_decoder.
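
A minimal sketch of that suggestion (just the loading step); it relies on the crime-and-punishment checkpoint's config already carrying is_decoder=True, so no override is needed:

from transformers import ReformerModelWithLMHead, ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
# this checkpoint is configured as a decoder, so ReformerModelWithLMHead loads without the assertion error
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")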

Comment on lines 2530 to 2531
Contributor Author

Printed loss instead of calling loss.backward(), because I got AssertionError: If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode.

Collaborator

@ydshieh ydshieh Apr 6, 2022

Maybe it is a good idea to set model.train() before loss = model(**inputs, labels=labels).loss?

That way users will know it is required :-)

cc @patrickvonplaten for comment.

Contributor Author

Thank you for the suggestion!
I tried setting model.train(), but got ValueError: If training, sequence length 11 has to be a multiple of least common multiple chunk_length 4. Please consider padding the input to a length of 12. at loss = model(**inputs, labels=labels).loss.
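
The constraint behind that ValueError is just divisibility: in training mode the sequence length has to be a multiple of the least-common-multiple chunk length. A small sketch of the padding arithmetic (plain Python, using the numbers from the error message):

import math

lcm_chunk_length = 4  # least common multiple chunk_length reported by the error
seq_len = 11          # tokenized length of the example sentence

# round the sequence length up to the next multiple of the chunk length
padded_len = math.ceil(seq_len / lcm_chunk_length) * lcm_chunk_length
print(padded_len)  # 12, the length the error message suggests padding to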

Contributor

@patrickvonplaten patrickvonplaten Apr 7, 2022

Yes indeed that's very Reformer specific, but nice that you caught it :-) Could we use a sequence length of 12 here?

Contributor Author

Thank you for the comment.
I tried a sequence length of 12, like this:

import torch
from transformers import ReformerTokenizer, ReformerForSequenceClassification

tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = ReformerForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-reformer", problem_type="multi_label_classification")

inputs = tokenizer("Hello, my dog is cute", max_length=12, padding="max_length", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]

# To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
num_labels = len(model.config.id2label)
model = ReformerForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-reformer", num_labels=num_labels)
model.train()

num_labels = len(model.config.id2label)
labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
    torch.float
)
loss = model(**inputs, labels=labels).loss
loss.backward()

However, this time I got ValueError: If training, make sure that config.axial_pos_shape factors: (4, 25) multiply to sequence length. Got prod((4, 25)) != sequence_length: 12. You might want to consider padding your sequence length to 100 or changing config.axial_pos_shape. at loss = model(**inputs, labels=labels).loss.
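
Training mode adds a second constraint on top of the chunk length: the product of config.axial_pos_shape has to equal the padded sequence length. A quick check of the numbers from the error message:

import math

axial_pos_shape = (4, 25)  # from the tiny-random-reformer config, per the error message
required_len = math.prod(axial_pos_shape)
print(required_len)  # 100 -> pad the inputs to max_length=100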

Contributor Author

Then, I tried a sequence length of 100.
It seems to work fine!

Contributor Author

@ydshieh @patrickvonplaten
I pushed this change. Could you please check that?

Collaborator

I will take a look. However, we will wait for Patrick's review anyway, as he has more knowledge of this model than I do (he is currently not available).

@hiromu166 hiromu166 changed the title [WIP] Add Doc Tests for Reformer PyTorch Add Doc Tests for Reformer PyTorch Apr 5, 2022
@ydshieh
Collaborator

ydshieh commented Apr 6, 2022

Hi @hiromu166, could you remind me why you need to overwrite the code sample in the model file, instead of just using add_code_sample_docstrings and providing the expected outputs and checkpoints?

(Sorry if you already mentioned this before!)

Contributor Author

The reason for using replace_return_docstrings here:

  • We need to load the model with is_decoder=True. (line: 2230)

Contributor Author

The reasons for using replace_return_docstrings here (see the sketch after this list):

  • ReformerTokenizer's mask_token is None, so we need to set it. (line: 2360)
  • len(tokenizer.tokenize("The capital of France is [MASK].")) is 16, but len(tokenizer.tokenize("The capital of France is Paris.")) is 17. So, we need to match the lengths. (line: 2378)
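
A hedged sketch of the length-matching trick (mirroring the code in this PR); it assumes the tokenizer has been given a mask token first, and simply truncates the label ids to the input length before masking non-[MASK] positions with -100:

import torch
from transformers import ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
tokenizer.add_special_tokens({"mask_token": "[MASK]"})

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]

# the two sentences tokenize to different lengths (16 vs. 17), so keep only as many
# label ids as there are input ids, then replace every non-[MASK] position with -100
labels = torch.where(
    inputs.input_ids == tokenizer.mask_token_id, labels[:, : inputs["input_ids"].shape[-1]], -100
)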

Collaborator

OK, Reformer is in fact an auto-regressive generative model, not really intended to be used as a masked LM model.
I saw that you already use hf-internal-testing/tiny-random-reformer for ReformerForMaskedLM, so we don't have the issue regarding is_decoder.

We still have the issue regarding the mask token, so it's good to use replace_return_docstrings! Thank you.

Contributor Author

@hiromu166 hiromu166 Apr 6, 2022

The reason for using replace_return_docstrings here:

  • I got an error when calling loss.backward(), so I replaced it with simply printing loss. (line: 2530)

Contributor Author

Update (see the sketch below):

  • added a pad_token so the inputs can be padded. (line: 2509)
  • padded to max_length=100 to avoid the error. (line: 2510)
  • called model.train() so that loss.backward() works. (line: 2526)
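
A rough sketch of how the three changes fit together (not the exact final doctest; the dummy label and the default two-label head are assumptions for illustration):

import torch
from transformers import ReformerForSequenceClassification, ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # 1) the tokenizer has no pad token by default

# 2) pad to 100 so that prod(config.axial_pos_shape) == sequence length during training
inputs = tokenizer("Hello, my dog is cute", max_length=100, padding="max_length", return_tensors="pt")

model = ReformerForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-reformer")
model.train()  # 3) Reformer only allows backprop when it is explicitly in training mode

labels = torch.tensor([0])  # dummy single-label target, just for the sketch
loss = model(**inputs, labels=labels).loss
loss.backward()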

@hiromu166
Contributor Author

Hi @ydshieh, I commented on the reasons for using replace_return_docstrings!
Please check them 🙏

@ydshieh
Collaborator

ydshieh commented Apr 8, 2022

Hi, @hiromu166

I think we can change back to

_CHECKPOINT_FOR_DOC = "google/reformer-crime-and-punishment"

and use this (_CHECKPOINT_FOR_DOC) for ReformerModelWithLMHead. If we do so, it should be possible to use add_code_sample_docstrings. Let me know if you still have problems doing so.

For ReformerForMaskedLM, we use https://huggingface.co/hf-internal-testing/tiny-random-reformer, and we can use replace_return_docstrings, since, as you point out, there is a mask-token issue.

Let's try to make these two work first :-)

@hiromu166
Contributor Author

Hi @ydshieh, thank you for the suggestion.

OK, I'll change _CHECKPOINT_FOR_DOC like this:

  • "google/reformer-crime-and-punishment" for ReformerModel, ReformerModelWithLMHead.
  • "hf-internal-testing/tiny-random-reformer" for ReformerForMaskedLM, ReformerForSequenceClassification, ReformerForQuestionAnswering, to avoid randomness and the other problems.

@ydshieh
Collaborator

ydshieh commented Apr 11, 2022

Hi @hiromu166, could you try to resolve the conflicts in utils/documentation_tests.txt 🙏 (you can also pull the latest main branch, rebase this PR on main, fix the conflict, then force-push).

(This is not necessary for my review, but the conflict needs to be fixed before merging.)

I will review this PR tomorrow :-)

@hiromu166 hiromu166 force-pushed the add_doctest_reformer_pt branch from 265fe36 to b2fef30 Compare April 11, 2022 23:40
@hiromu166
Contributor Author

Hi @ydshieh, I tried to resolve the conflict like this. Is it OK?

git fetch upstream
git rebase upstream/main
-- fix conflict --
git rebase --continue
git push -f origin add_doctest_reformer_pt

@ydshieh
Collaborator

ydshieh commented Apr 12, 2022

Hi @ydshieh, I tried to resolve the conflict like this. Is it OK?

git fetch upstream
git rebase upstream/main
-- fix conflict --
git rebase --continue
git push -f origin add_doctest_reformer_pt

Looks good! I will review this PR now

Collaborator

@ydshieh ydshieh left a comment

Thanks, @hiromu166, especially for the details that require some attention 💯

@patrickvonplaten Could you have a final look 🙏

(run locally -> tests pass)

Comment on lines +2494 to +2496
>>> # add pad_token
>>> tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # doctest: +IGNORE_RESULT
>>> inputs = tokenizer("Hello, my dog is cute", max_length=100, padding="max_length", return_tensors="pt")
Collaborator

This part is required for the next example block, not this current block, so it's a good reason to overwrite the code sample here.

>>> model = ReformerForSequenceClassification.from_pretrained(
... "hf-internal-testing/tiny-random-reformer", num_labels=num_labels
... )
>>> model.train() # doctest: +IGNORE_RESULT
Collaborator

@patrickvonplaten I am not sure if we should add this in doc.py. (Currently it is not added in doc.py.) Without it, we are not in training mode, which might be a bit misleading.

>>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
>>> # mask labels of non-[MASK] tokens
>>> labels = torch.where(
... inputs.input_ids == tokenizer.mask_token_id, labels[:, : inputs["input_ids"].shape[-1]], -100
Collaborator

Confirmed that the special treatment here is required to make the example run.
(It would be better to have a good way to handle this in doc.py in the future.)

Contributor

@patrickvonplaten patrickvonplaten left a comment

Looks good to me - great job!

@ydshieh
Collaborator

ydshieh commented Apr 12, 2022

PR Merged. Thank you again, @hiromu166 ❤️ !

@ydshieh ydshieh merged commit 1bac40d into huggingface:main Apr 12, 2022
@hiromu166
Contributor Author

I'm glad it was merged smoothly. Thank you for your cooperation!!

@patrickvonplaten
Contributor

Thanks a mille for your contribution @hiromu166 !

elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* start working

* fix: ReformerForQA doctest

* fix: ReformerModelWithLMHead doctest

* fix: ReformerModelForSC doctest

* fix: ReformerModelForMLM doctest

* add: documentation_tests.txt

* make fixup

* change: ReformerModelForSC doctest

* change: checkpoint