Improve PT/TF equivalence test #16557
Conversation
The documentation is not available anymore as the PR was closed or merged.
tests/clip/test_modeling_tf_clip.py
Outdated
No need for this anymore - the test in TF common can handle nested outputs, including instances of `ModelOutput`.
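For context, a minimal sketch of how a recursive check can walk such nested outputs; the helper name matches this PR, but the body and tolerance are illustrative, not the exact implementation:

```python
import numpy as np
from transformers.utils import ModelOutput

def check_pt_tf_outputs(tf_out, pt_out, name="outputs", tol=1e-5):
    if isinstance(tf_out, ModelOutput):
        # Recurse into named fields, e.g. outputs.text_model_output.
        for key in tf_out.keys():
            check_pt_tf_outputs(tf_out[key], pt_out[key], name=f"{name}.{key}", tol=tol)
    elif isinstance(tf_out, (list, tuple)):
        # Recurse into positional elements, e.g. outputs.attentions_1.
        for i, (t, p) in enumerate(zip(tf_out, pt_out)):
            check_pt_tf_outputs(t, p, name=f"{name}_{i}", tol=tol)
    else:
        # Leaf tensors: compare values and report the offending name.
        max_diff = np.amax(np.abs(tf_out.numpy() - pt_out.detach().numpy()))
        assert max_diff <= tol, f"{name}: max difference {max_diff} > {tol}"
```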
tests/led/test_modeling_tf_led.py
Outdated
This was done before to give TF-LED a strong test while the common version was still a loose one.
Now that the common test is (very) strong, we no longer need this test in the TF-LED test file.
Can I add import torch here without is_torch_available or require_torch? This method will be called only inside test_pt_tf_model_equivalence, which is already decorated with is_pt_tf_cross_test.
That's just a marker that reads an env variable, so I think it should have the require_torch just in case, but I'm not sure if we are very consistent with that. @LysandreJik might know better.
I don't think it really matters as it is indeed already decorated with `is_pt_tf_cross_test`. We don't have a convention set, so feel free to choose the simplest approach.
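As a sketch, the pattern under discussion looks like this (method bodies elided; only the import placement matters here):

```python
from transformers.testing_utils import is_pt_tf_cross_test

class TFModelTesterMixin:
    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        # Local import with no require_torch guard: this helper is only
        # reached from the decorated cross test below, so torch is
        # guaranteed to be installed whenever it actually runs.
        import torch

        ...

    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
        ...
```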
This is the specific part for LXMERT test.
(It is possible to move this part to the common PT/TF test method. But I think it's fine/better to overwrite here.)
Removed. The new version uses
`elif tf_inputs_dict[key].dtype.is_floating:`
which I find cleaner and more general.
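Roughly, the conversion then reads as follows (a hedged sketch of the loop, not the verbatim test code):

```python
import torch

def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
    pt_inputs_dict = {}
    for key, value in tf_inputs_dict.items():
        if tf_inputs_dict[key].dtype.is_floating:
            # Any floating tensor (pixel values, visual features, ...)
            # is converted as-is instead of being special-cased by name.
            pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
        else:
            # Integer tensors such as input ids and attention masks.
            pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.long)
    return pt_inputs_dict
```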
In the new version, this is handled in prepare_pt_inputs_from_tf_inputs.
if isinstance(value, dict):
    pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
elif isinstance(value, (list, tuple)):
    pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
In the new version, we only need to overwrite prepare_pt_inputs_from_tf_inputs, because that is the place with actual differences from the common version.
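In other words, the model-specific test can shrink to a skeleton like this (hypothetical shape; the import path, class bases, and the LXMERT-specific handling marked by the comment are assumptions, not the exact code):

```python
import unittest

from .test_modeling_tf_common import TFModelTesterMixin

class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        # Only the input preparation differs from the common version;
        # running the models and comparing outputs is inherited as-is.
        pt_inputs_dict = {}
        ...  # LXMERT-specific handling (e.g. visual_feats, visual_pos)
        return pt_inputs_dict
```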
I prefer to call super() here, because the difference is only about adding a noise argument in the block above.
SGTM!
We just need to overwrite check_pt_tf_models.
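Concretely, that override could look roughly like this (a sketch combining the two comments above; the seed and shape bookkeeping are illustrative):

```python
import unittest

import numpy as np
import tensorflow as tf

from .test_modeling_tf_common import TFModelTesterMixin

class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
        # Fix the random masking so PT and TF drop the same patches.
        np.random.seed(2)
        num_patches = int((pt_model.config.image_size // pt_model.config.patch_size) ** 2)
        noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
        # The common check rebuilds the PT inputs from the TF ones, so
        # attaching noise here feeds the same values to both frameworks.
        tf_inputs_dict["noise"] = tf.constant(noise)
        super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
```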
tests/test_modeling_tf_common.py
Outdated
not sure if we should test this argument. I think it is not worth it.
Not sure why it was added, but it doesn't look useful, I agree.
It was added by me during the process: sometimes I passed the wrong arguments and got errors.
However, those arguments are unlikely to be used by anyone else (unless someone wants to change `check_pt_tf_outputs`).
tests/test_modeling_tf_common.py
Outdated
Make the failure message more informative by adding the corresponding tensor name, like `output.hidden_states`.
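I.e., the leaf comparison carries the attribute path in its assertion (sketch; the helper name is hypothetical):

```python
import numpy as np

def check_tensor_close(tf_output, pt_output, name, tol=1e-5):
    # "name" is the attribute path of the tensor, e.g. "output.hidden_states",
    # so a failure points directly at the mismatching output.
    max_diff = np.amax(np.abs(tf_output.numpy() - pt_output.detach().numpy()))
    assert max_diff <= tol, f"{name}: max difference {max_diff} exceeds {tol}"
```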
Thanks for cleaning those. It's great we can remove some model-specific code to rely on the generic common tests!
This is great, it makes writing tests for edge cases much easier 🚀
(Just rebased on main - no real change since your last review.)
Merging now. Don't hesitate to leave comments in any case :-)
* add error message
* Use names in the error message
* allow ModelOutput
* rename to check_pt_tf_outputs and move outside
* fix style
* skip past_key_values in a better way
* Add comments
* improve code for label/loss
* make the logic clear by moving the ignore keys out
* fix _postprocessing_to_ignore
* fix _postprocessing_to_ignore: create new outputs from the remaining fields
* ignore past_key_values in TFGPT2 models for now
* make check_pt_tf_outputs better regarding names
* move check_pt_tf_models outside
* rename methods
* remove test_pt_tf_model_equivalence in TFCLIPModelTest
* Reduce TFViTMAEModelTest.test_pt_tf_model_equivalence
* move prepare_pt_inputs_from_tf_inputs outside check_pt_tf_models
* Fix quality
* Clean-up TFLxmertModelTester.test_pt_tf_model_equivalence
* Fix quality
* fix
* fix style
* Clean-up TFLEDModelTest.test_pt_tf_model_equivalence
* Fix quality
* add docstring
* improve comment

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
What does this PR do?
Improve PT/TF equivalence test.
To make the review a bit easier for you, I made some comments. Here is a summary of the changes:
- `test_pt_tf_model_equivalence` in TensorFlow `LED` and `CLIP` are removed: the common one can handle it.
- `test_pt_tf_model_equivalence` in TensorFlow `LXMERT` and `ViTMAE` are removed: we only need to overwrite
  - `prepare_pt_inputs_from_tf_inputs` for `LXMERT`
  - `check_pt_tf_models` for `ViTMAE`
- `TFModelTesterMixin.test_pt_tf_model_equivalence`:
  - `_make_attention_mask_non_null`
  - `_postprocessing_to_ignore_test_cases`
  - `check_pt_tf_outputs`:
    - handles `ModelOutput` (for CLIP model)
    - error messages use tensor names, like `output.hidden_states` or `output.text_model_output.attentions_1`

Once this PR is approved/merged: