[FlaxBert] Add ForCausalLM #16995
Conversation
The documentation is not available anymore as the PR was closed or merged.
Force-pushed from 9c9e49b to 62173c7.
```diff
  self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
  for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-     self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
+     self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
```
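For context, here is a minimal sketch of what an equivalence check like `assert_almost_equals` typically does; the exact helper in the test suite may differ, so treat the implementation below as an assumption rather than the repository's code:

```python
import numpy as np

def assert_almost_equals(fx_output, pt_output, tol):
    # Compare the Flax and PyTorch outputs elementwise and fail if the largest
    # absolute difference exceeds the tolerance (1e-5 vs. 4e-2 in the diff above).
    diff = np.abs(np.asarray(fx_output) - np.asarray(pt_output)).max()
    assert diff <= tol, f"Flax/PyTorch outputs differ by {diff:.2e} (tolerance {tol})"
```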
@ydshieh is 1e-5 now the default testing precision?
Found a bug in the FlaxBertModelTester and fixed it! Thresholds are now back to 1e-5 and passing (even with the randomly initialised decoder attention mask) :-)
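For readers unfamiliar with the tester setup, the sketch below shows roughly how randomly initialised decoder inputs and attention masks are produced in such a model tester; the shapes and names are assumptions for illustration, not the actual FlaxBertModelTester code:

```python
import numpy as np

# Illustrative shapes only.
batch_size, seq_length, vocab_size = 13, 7, 99

decoder_input_ids = np.random.randint(0, vocab_size, size=(batch_size, seq_length))
# Random 0/1 mask; a bug in how a mask like this was built and passed in the tester
# inflated the Flax-vs-PyTorch differences and originally motivated the looser 4e-2 threshold.
decoder_attention_mask = np.random.randint(0, 2, size=(batch_size, seq_length))
```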
Yes. So far, for anything higher than 1e-5, I was able to find some issues, either in the models or in the model tests.
Looks good to me - @sanchit-gandhi, could you check though which models don't pass with 1e-5, and ideally why?
Overall, 4e-2 is fine for me, though. cc @ydshieh, what do you think?
Cool, feel free to merge @sanchit-gandhi
* [FlaxBert] Add ForCausalLM
* make style
* fix output attentions
* Add RobertaForCausalLM
* remove comment
* fix fx-to-pt model loading
* remove comment
* add modeling tests
* add enc-dec model tests
* add big_bird
* add electra
* make style
* make repo-consitency
* add to docs
* remove roberta test
* quality
* amend cookiecutter
* fix attention_mask bug in flax bert model tester
* tighten pt-fx thresholds to 1e-5
* add 'copied from' statements
* amend 'copied from' statements
* amend 'copied from' statements
* quality
What does this PR do?
Adds cross-attention blocks to the following module classes:
Adds the following ForCausalLM model classes:
Adds the following model tests:
Note: FlaxBertForCausalLM is excluded from these tests due to the name mismatch with its PyTorch equivalent, BertLMHeadModel. It is implicitly tested through the FlaxRobertaForCausalLM model tests, as well as in the following encoder-decoder model tests:
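As a rough illustration of what the new classes enable, the sketch below uses FlaxBertForCausalLM standing in for the other added causal-LM classes; the checkpoint names are illustrative and the exact loading behaviour (e.g. warnings about freshly initialised head weights) is an assumption:

```python
from transformers import BertTokenizerFast, FlaxBertForCausalLM, FlaxEncoderDecoderModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# Standalone causal-LM head on top of the Flax BERT stack.
lm = FlaxBertForCausalLM.from_pretrained("bert-base-cased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
logits = lm(**inputs).logits  # (batch, sequence, vocab) next-token scores

# BERT-to-BERT encoder-decoder: the decoder side relies on the cross-attention
# blocks added to the Flax BERT modules in this PR.
enc_dec = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-cased", "bert-base-cased"
)
```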