Fix TF_MASKED_LM_SAMPLE by ydshieh · Pull Request #16698 · huggingface/transformers

Conversation

@ydshieh
Collaborator

@ydshieh ydshieh commented Apr 11, 2022

What does this PR do?

Fix TF_MASKED_LM_SAMPLE: there is currently a dimension issue with mask_token_index and predicted_token_id, which makes the PT and TF masked LM code samples give different results:

PT: paris
TF: p a r i s

See below for details.

(This is related to #16523)
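
To see the shape issue in isolation, here is a minimal sketch on toy ids (hypothetical values, not the real BERT vocabulary):

import tensorflow as tf

mask_token_id = 103  # hypothetical id
input_ids = tf.constant([[5, 7, mask_token_id, 9]])

# main: `tf.where` on the 2-D comparison returns (row, col) pairs,
# and `[0][1]` strips both dimensions down to a 0-d scalar
print(tf.where(input_ids == mask_token_id)[0][1])  # tf.Tensor(2, shape=(), dtype=int64)

# this PR: `tf.where` on the first row alone keeps the row dimension
print(tf.where((input_ids == mask_token_id)[0]))  # tf.Tensor([[2]], shape=(1, 1), dtype=int64)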

PT_MASKED_LM_SAMPLE

from transformers import BertTokenizer, BertForMaskedLM
import torch

mask = "[MASK]",
checkpoint = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(f"{checkpoint}")
model = BertForMaskedLM.from_pretrained(f"{checkpoint}")

inputs = tokenizer(f"The capital of France is {mask}.", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# retrieve index of {mask}
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
expected_output = tokenizer.decode(predicted_token_id)


print(mask_token_index)  # tensor([6]): row dimension from `nonzero()`
print(predicted_token_id)  # tensor([3000])
print(expected_output)  # paris
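
For reference, the PyTorch side keeps the row dimension because `nonzero(as_tuple=True)` returns one 1-D index tensor per dimension; a minimal sketch on a toy mask (hypothetical values):

import torch

flags = torch.tensor([False, True, False])
# one index tensor per dimension, so a single match still has shape (1,)
print(flags.nonzero(as_tuple=True))     # (tensor([1]),)
print(flags.nonzero(as_tuple=True)[0])  # tensor([1])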

TF_MASKED_LM_SAMPLE (on main)

from transformers import BertTokenizer, TFBertForMaskedLM
import tensorflow as tf

mask = "[MASK]"
checkpoint = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = TFBertForMaskedLM.from_pretrained(checkpoint)

inputs = tokenizer(f"The capital of France is {mask}.", return_tensors="tf")
logits = model(**inputs).logits

# retrieve index of {mask}
mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
expected_output = tokenizer.decode(predicted_token_id)

print(mask_token_index)  # tf.Tensor(6, shape=(), dtype=int64): no row dimension
print(predicted_token_id)  # tf.Tensor(3000, shape=(), dtype=int64)
print(tokenizer.decode(predicted_token_id))  # p a r i s (not good)
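
The mangled output comes from the decoding step: with the scalar index above, tokenizer.decode receives a single int rather than a sequence of ids. Using the id reported above:

print(tokenizer.decode([3000]))  # paris
print(tokenizer.decode(3000))    # p a r i s  (scalar path, per the output above)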

TF_MASKED_LM_SAMPLE (this PR)

# (imports, model and `inputs` as in the sample above)

# retrieve index of {mask}
mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
expected_output = tokenizer.decode(predicted_token_id)

print(mask_token_index)  # tf.Tensor([[6]], shape=(1, 1), dtype=int64): with row dimension
print(predicted_token_id)  # tf.Tensor([3000], shape=(1,), dtype=int64)
print(tokenizer.decode(predicted_token_id))  # paris
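
For completeness, `tf.gather_nd` with indices of shape (num_masks, 1) selects one full logit row per masked position, which is what keeps predicted_token_id 1-D; a toy sketch with made-up shapes and values:

import tensorflow as tf

logits_0 = tf.constant([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # (seq_len, vocab_size), toy values
mask_token_index = tf.constant([[1]], dtype=tf.int64)  # one masked position, row dimension kept

selected_logits = tf.gather_nd(logits_0, indices=mask_token_index)
print(selected_logits.shape)  # (1, 2): one logit row per masked position
print(tf.math.argmax(selected_logits, axis=-1))  # tf.Tensor([0], shape=(1,), dtype=int64)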

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 11, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh marked this pull request as ready for review April 11, 2022 11:42
Collaborator

@sgugger sgugger left a comment


LGTM, thanks for fixing!

@ydshieh ydshieh merged commit 40618ec into huggingface:main Apr 11, 2022
@ydshieh ydshieh deleted the fix_tf_masked_lm_code_sample branch April 11, 2022 16:19
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>