Improve PT/TF equivalence test by ydshieh · Pull Request #16557 · huggingface/transformers · GitHub

Conversation

@ydshieh
Collaborator

@ydshieh ydshieh commented Apr 1, 2022

What does this PR do?

Improve PT/TF equivalence test.

To make the review a bit easier for you, I made some comments. Here is a summary of the changes:

  • test_pt_tf_model_equivalence in TensorFlow LED and CLIP are removed: the common one can handle it.
  • test_pt_tf_model_equivalence in TensorFlow LXMERT and ViTMAE are removed: we only need to overwrite
    • prepare_pt_inputs_from_tf_inputs for LXMERT
    • check_pt_tf_models for ViTMAE
  • Main changes in TFModelTesterMixin.test_pt_tf_model_equivalence
    • restructure the code into components, so they can be overwritten separately instead of overwriting the whole big block
    • move some ugly (temporary) logic blocks outside:
      • _make_attention_mask_non_null
      • _postprocessing_to_ignore_test_cases
    • About check_pt_tf_outputs:
      • it now can handle instances of ModelOutput (for CLIP model)
      • better failure message: print the tensor name where the large diff between PT/TF occurs, like output.hidden_states or output.text_model_output.attentions_1
    • A better way to handle the cases where PT/TF outputs have different keys: we try to test the output values for the common keys in both outputs.
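The recursive comparison with named failure messages can be sketched roughly like this (a condensed illustration of the idea, not the actual `TFModelTesterMixin` code; tolerances and the handling of `ModelOutput` instances are simplified here):

```python
import numpy as np

def check_pt_tf_outputs(pt_output, tf_output, name="output", tol=1e-5):
    # Recursively walk nested outputs, carrying the tensor name along
    # (e.g. output.hidden_states_1) so a failure points at the exact tensor.
    if isinstance(pt_output, (list, tuple)):
        for i, (pt_v, tf_v) in enumerate(zip(pt_output, tf_output)):
            check_pt_tf_outputs(pt_v, tf_v, name=f"{name}_{i}", tol=tol)
    elif isinstance(pt_output, dict):
        # only test the output values for the keys present in both outputs
        for key in sorted(set(pt_output) & set(tf_output)):
            check_pt_tf_outputs(pt_output[key], tf_output[key], name=f"{name}.{key}", tol=tol)
    else:
        max_diff = np.max(np.abs(np.asarray(pt_output) - np.asarray(tf_output)))
        assert max_diff <= tol, f"{name}: max difference {max_diff} is larger than {tol}"
```

A failing leaf then reports its full path, e.g. `output.attentions_1`, instead of a bare numeric mismatch.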

Once this PR is approved/merged:

  • Work on the same PT/TF equivalence test on the PT side (should be very quick)
  • Apply the same logic to the PT/Flax equivalence test, on both the Flax and PT sides.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 1, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator Author

This is no longer needed: the test in TF common can handle nested outputs, including instances of ModelOutput.

Collaborator Author

This was originally added to give TF-LED a strong test while the common version was still a loose test.

Now that the common test is (very) strong, we no longer need this test in TF-LED.

Collaborator Author

Can I add import torch here without is_torch_available or require_torch? This method will be called only inside test_pt_tf_model_equivalence, which is already decorated with is_pt_tf_cross_test.

Collaborator

That's just a marker that reads an env variable, so I think it should have the require_torch just in case, but I'm not sure if we are very consistent with that. @LysandreJik might know better.

Member

I don't think it really matters, as it is indeed already decorated with is_pt_tf_cross_test. We don't have a convention set, so feel free to choose the simplest approach.
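The point of the function-local import can be demonstrated with a toy helper (the module name below is deliberately fake, standing in for torch; this is a sketch of the pattern, not the actual test code):

```python
def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
    # Function-local import: the dependency is resolved only when the helper
    # actually runs (inside the test gated by the is_pt_tf_cross_test marker),
    # so merely importing the test module never requires it.
    import hypothetical_torch_stand_in  # fake module name, stands in for torch
    return {k: hypothetical_torch_stand_in.from_numpy(v) for k, v in tf_inputs_dict.items()}

# Defining the helper (and importing the module that contains it) costs
# nothing even when the dependency is missing; only calling it does.
```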

Comment on lines +497 to +500
Collaborator Author

This is the specific part for LXMERT test.

(It is possible to move this part to the common PT/TF test method. But I think it's fine/better to overwrite here.)

Comment on lines -528 to -532
Collaborator Author

Removed. The new version uses

elif tf_inputs_dict[key].dtype.is_floating:

I find it cleaner and more general.
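The dtype-based branch can be sketched with numpy standing in for TF tensors (a simplified illustration; the real code checks `tf_inputs_dict[key].dtype.is_floating` on actual TF tensors):

```python
import numpy as np

def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
    # Branch on dtype rather than on hard-coded key names: floating tensors
    # keep a float dtype, everything else (ids, masks) becomes int64.
    pt_inputs_dict = {}
    for key, value in tf_inputs_dict.items():
        arr = np.asarray(value)
        if np.issubdtype(arr.dtype, np.floating):  # numpy analogue of dtype.is_floating
            pt_inputs_dict[key] = arr.astype(np.float32)
        else:
            pt_inputs_dict[key] = arr.astype(np.int64)
    return pt_inputs_dict
```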

Comment on lines -534 to -546
Collaborator Author

@ydshieh ydshieh Apr 7, 2022

In the new version, this is handled in prepare_pt_inputs_from_tf_inputs.

if isinstance(value, dict):
    pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
elif isinstance(value, (list, tuple)):
    pt_inputs_dict[key] = (
        self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value
    )
Collaborator Author

In the new version, we only need to overwrite prepare_pt_inputs_from_tf_inputs, because that is the place with actual differences from the common version.

Collaborator Author

I prefer to call super() here, because the only difference is the added noise argument in the block above.
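The super() pattern can be sketched with toy classes (stand-ins, not the actual ViTMAE test code; the real override injects a deterministic noise tensor so PT and TF mask the same patches):

```python
import numpy as np

class CommonTesterSketch:
    def check_pt_tf_models(self, tf_inputs_dict):
        # stand-in for the shared PT/TF comparison logic
        return sorted(tf_inputs_dict)

class ViTMAELikeTesterSketch(CommonTesterSketch):
    def check_pt_tf_models(self, tf_inputs_dict):
        # the only model-specific difference: add a fixed-seed noise input so
        # the random patch masking agrees between frameworks, then defer to super()
        rng = np.random.default_rng(2)
        tf_inputs_dict = dict(tf_inputs_dict, noise=rng.uniform(size=(2, 4)))
        return super().check_pt_tf_models(tf_inputs_dict)
```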

Collaborator

SGTM!

Collaborator Author

We just need to overwrite check_pt_tf_models.

@ydshieh ydshieh changed the title [WIP] Improve pt tf equiv test Improve PT/TF equivalence test Apr 7, 2022
@ydshieh ydshieh marked this pull request as ready for review April 7, 2022 13:27
Collaborator Author

Not sure if we should test this argument. I think it is not worth it.

Collaborator

Not sure why it was added, but I agree it doesn't look useful.

Collaborator Author

It was added by me during the process: sometimes I passed the wrong arguments and got errors.

However, those arguments are unlikely to be used by anyone else (unless someone wants to change check_pt_tf_outputs).

Collaborator Author

Make the failure message more informative by adding the corresponding tensor name, like

output.hidden_states

Collaborator

@sgugger sgugger left a comment

Thanks for cleaning those. It's great we can remove some model-specific code to rely on the generic common tests!


Member

@gante gante left a comment

This is great, it makes writing tests for edge cases much easier 🚀

@ydshieh ydshieh force-pushed the improve_pt_tf_equiv_test branch from cdae60f to b703e6c Compare April 11, 2022 19:41
@ydshieh
Collaborator Author

ydshieh commented Apr 11, 2022

(just rebased on main; no real change since your last review)

@ydshieh
Collaborator Author

ydshieh commented Apr 11, 2022

Merging now. Don't hesitate to leave comments in any case :-)

@ydshieh ydshieh merged commit dce33f2 into huggingface:main Apr 11, 2022
@ydshieh ydshieh deleted the improve_pt_tf_equiv_test branch April 11, 2022 20:19
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* add error message

* Use names in the error message

* allow ModelOutput

* rename to check_pt_tf_outputs and move outside

* fix style

* skip past_key_values in a better way

* Add comments

* improve code for label/loss

* make the logic clear by moving the ignore keys out

* fix _postprocessing_to_ignore

* fix _postprocessing_to_ignore: create new outputs from the remaining fields

* ignore past_key_values in TFGPT2 models for now

* make check_pt_tf_outputs better regarding names

* move check_pt_tf_models outside

* rename methods

* remove test_pt_tf_model_equivalence in TFCLIPModelTest

* Reduce TFViTMAEModelTest.test_pt_tf_model_equivalence

* move prepare_pt_inputs_from_tf_inputs outside check_pt_tf_models

* Fix quality

* Clean-up TFLxmertModelTester.test_pt_tf_model_equivalence

* Fix quality

* fix

* fix style

* Clean-up TFLEDModelTest.test_pt_tf_model_equivalence

* Fix quality

* add docstring

* improve comment

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
