[FlaxBert] Add ForCausalLM by sanchit-gandhi · Pull Request #16995 · huggingface/transformers · GitHub

Conversation

sanchit-gandhi (Contributor) commented Apr 28, 2022

What does this PR do?

Adds cross-attention blocks to the following module classes:

  • FlaxBertModule
  • FlaxRobertaModule (in part through copying FlaxBertModule)
  • FlaxBigBirdModule (in part through copying FlaxBertModule)
  • FlaxElectraModule (in part through copying FlaxBertModule)

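The cross-attention layers are only active when the configuration flags the module as a decoder with cross-attention enabled. A minimal sketch (using an ordinary bert-base-uncased checkpoint purely for illustration; this is not code from the PR) of exercising the new blocks in FlaxBertModel:

```python
import jax.numpy as jnp
from transformers import BertConfig, FlaxBertModel

# is_decoder=True + add_cross_attention=True instantiate the new cross-attention
# blocks; the freshly initialised cross-attention weights will trigger a warning.
config = BertConfig.from_pretrained(
    "bert-base-uncased", is_decoder=True, add_cross_attention=True
)
decoder = FlaxBertModel.from_pretrained("bert-base-uncased", config=config)

# Dummy encoder states, just to show the extended call signature.
encoder_hidden_states = jnp.ones((1, 8, config.hidden_size))
outputs = decoder(
    input_ids=jnp.ones((1, 4), dtype="i4"),
    encoder_hidden_states=encoder_hidden_states,
)
print(outputs.last_hidden_state.shape)  # (1, 4, hidden_size)
```
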
Adds the following ForCausalLM model classes:

  • FlaxBertForCausalLM
  • FlaxRobertaForCausalLM (in part through copying FlaxBertForCausalLM)
  • FlaxBigBirdForCausalLM (in part through copying FlaxBertForCausalLM)
  • FlaxElectraForCausalLM (in part through copying FlaxBertForCausalLM)

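A minimal usage sketch for the new standalone head (shown with a vanilla bert-base-uncased checkpoint, so the predictions themselves are untrained; not code from the PR):

```python
import jax.numpy as jnp
from transformers import BertTokenizerFast, FlaxBertForCausalLM

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertForCausalLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, my dog is", return_tensors="np")
outputs = model(**inputs)

# Logits over the vocabulary at every position; the next-token prediction is
# read off the final position.
next_token_id = int(jnp.argmax(outputs.logits[:, -1, :], axis=-1)[0])
print(tokenizer.decode([next_token_id]))
```
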
Adds the following model tests:

  • FlaxRobertaForCausalLM
  • FlaxBigBirdForCausalLM
  • FlaxElectraForCausalLM

Note: FlaxBertForCausalLM is excluded from the list above because its name does not match that of its PyTorch equivalent, BertLMHeadModel. It is tested implicitly through the FlaxRobertaForCausalLM model tests, as well as through the following encoder-decoder model tests:

  • Bert-2-Bert (encoder-decoder)
  • Wav2Vec2-2-Bert (speech encoder-decoder)

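Those encoder-decoder tests exercise the new cross-attention support end to end. A hedged sketch (checkpoint names and inputs are illustrative only) of building a Bert-2-Bert model in Flax on top of this PR:

```python
from transformers import BertTokenizerFast, FlaxEncoderDecoderModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# The decoder config is set up with is_decoder=True and add_cross_attention=True,
# which routes it through the new FlaxBertForCausalLM.
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

enc = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="np")
dec = tokenizer("A fast fox leaps over a sleepy dog.", return_tensors="np")

outputs = model(
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    decoder_input_ids=dec.input_ids,
    decoder_attention_mask=dec.attention_mask,
)
print(outputs.logits.shape)  # (batch_size, decoder_sequence_length, vocab_size)
```
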
HuggingFaceDocBuilderDev commented Apr 28, 2022

The documentation is not available anymore as the PR was closed or merged.

sanchit-gandhi force-pushed the FlaxBertForCausalLM branch from 9c9e49b to 62173c7 on May 2, 2022 at 11:07
sanchit-gandhi marked this pull request as ready for review on May 2, 2022 at 12:49
sanchit-gandhi changed the title from "[WIP] [FlaxBert] Add ForCausalLM" to "[FlaxBert] Add ForCausalLM" on May 2, 2022
         self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
         for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-            self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
+            self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
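The tolerance under discussion bounds the element-wise difference between the Flax and PyTorch outputs. A rough sketch of what such a check boils down to (the helper name mirrors the snippet above, but this is not the actual tester code):

```python
import numpy as np

def assert_almost_equals(fx_output, pt_output, tol):
    # Maximum absolute element-wise difference between the Flax and PyTorch
    # outputs must stay within the tolerance (1e-5 vs 4e-2 in this thread).
    diff = np.abs(np.asarray(fx_output) - np.asarray(pt_output)).max()
    assert diff <= tol, f"Flax/PyTorch outputs differ by {diff:.2e} (tol={tol})"
```
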
Contributor

@ydshieh is 1e-5 now the default testing precision?

Contributor Author

Found a bug in the FlaxBertModelTester and fixed it! Thresholds are now back to 1e-5 and passing (even with the randomly initialised decoder attention mask) :-)

Collaborator

Yes. So far, for anything higher than 1e-5, I was able to find some issues, either in the models or in the model tests.

patrickvonplaten (Contributor) left a comment

Looks good to me - @sanchit-gandhi could you check though which models don't pass with 1e-5 and ideally why?

Overall 4e-2 is fine for me though cc @ydshieh what do you think?

ydshieh (Collaborator) commented May 2, 2022

> Looks good to me - @sanchit-gandhi could you check though which models don't pass with 1e-5 and ideally why?
>
> Overall 4e-2 is fine for me though cc @ydshieh what do you think?

Keeping 1e-5 is much better, because so far I have always been able to find some issues whenever a threshold is higher than 1e-5 (well, sometimes it took quite some time to figure out).

patrickvonplaten (Contributor) left a comment

Cool, feel free to merge @sanchit-gandhi

sanchit-gandhi merged commit cd9274d into huggingface:main on May 3, 2022
sanchit-gandhi deleted the FlaxBertForCausalLM branch on May 3, 2022 at 09:26
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request May 3, 2022
* [FlaxBert] Add ForCausalLM

* make style

* fix output attentions

* Add RobertaForCausalLM

* remove comment

* fix fx-to-pt model loading

* remove comment

* add modeling tests

* add enc-dec model tests

* add big_bird

* add electra

* make style

* make repo-consitency

* add to docs

* remove roberta test

* quality

* amend cookiecutter

* fix attention_mask bug in flax bert model tester

* tighten pt-fx thresholds to 1e-5

* add 'copied from' statements

* amend 'copied from' statements

* amend 'copied from' statements

* quality
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022