t5: add conversion script for T5X to FLAX by stefan-it · Pull Request #16853 · huggingface/transformers · GitHub

Conversation

@stefan-it
Collaborator

@stefan-it stefan-it commented Apr 20, 2022

Hi,

this PR adds the (long-awaited) conversion script from T5X to HF Flax, previously available in this Gist.

This conversion script allows converting models trained with T5X into Flax models, so they can be used with Transformers.

The script was road-tested and its output was compared against the official T5 (v1.1) checkpoint (T5 checkpoints can be converted into T5X checkpoints). More information can be found in this issue in the upstream T5X repo.
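After conversion, the exported folder loads like any other Flax checkpoint. A minimal sketch (the folder name matches the conversion example further down; the prompt is only illustrative):

from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

# Load the converted dump folder like a regular Flax checkpoint;
# converted T5X models reuse the matching T5 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("./t5x_1_0_exported")

inputs = tokenizer("translate English to German: Hello, my dog is cute.", return_tensors="np")
# generate() returns the generated token ids in .sequences
outputs = model.generate(**inputs, max_length=32).sequences
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))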

@stefan-it
Collaborator Author

/cc @patrickvonplaten @patil-suraj 🤗

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 20, 2022

The documentation is not available anymore as the PR was closed or merged.

Contributor

@patrickvonplaten patrickvonplaten left a comment


Thanks a million @stefan-it! @stancld this should be relevant to you as well.

@patil-suraj can you take a look? :-)

Also @stefan-it, do you have a pretrained T5X lying around somewhere for which I could try out the script? :-)

@stefan-it
Collaborator Author

I have just tested it with v1.0 and there's a problem with the lm head:

1.0: dict_keys(['decoder_norm', 'layers_0', 'layers_1', 'layers_2', 'layers_3', 'layers_4', 'layers_5', 'relpos_bias'])
1.1: dict_keys(['decoder_norm', 'layers_0', 'layers_1', 'layers_2', 'layers_3', 'layers_4', 'layers_5', 'layers_6', 'layers_7', 'logits_dense', 'relpos_bias'])

So logits_dense is missing from the 1.0 checkpoints and the conversion script can't handle it yet. I will try to find a solution and post a short conversion-pipeline guideline soon.
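One way to handle the missing head is to fall back to the tied embedding. A hedged sketch of the idea (not necessarily the exact patch that landed; t5x_params and flax_params are placeholder names for the two state dicts):

decoder = t5x_params["target"]["decoder"]

if "logits_dense" in decoder:
    # v1.1-style checkpoint: a separate, untied LM head exists.
    flax_params["lm_head"]["kernel"] = decoder["logits_dense"]["kernel"]
else:
    # v1.0-style checkpoint: input and output embeddings are tied, so the
    # shared token embedding doubles as the LM head (tie_word_embeddings=True).
    flax_params["shared"]["embedding"] = t5x_params["target"]["token_embedder"]["embedding"]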

stancld added a commit to stancld/transformers that referenced this pull request Apr 21, 2022
@stefan-it
Collaborator Author

stefan-it commented Apr 21, 2022

v1.0 checkpoints can also be converted now. Here are some checks (with a final evaluation on a downstream task):

Requirements:

pip3 install git+https://github.com/google-research/t5x.git
pip3 install --upgrade tensorstore==0.1.13

The pinned tensorstore version fixes a strange zarr error when loading the checkpoints.

Then download the 1.0 and 1.1 T5X checkpoints:

gsutil -o GSUtil:parallel_composite_upload_threshold=150M -m cp -r -n gs://t5-data/pretrained_models/t5x/t5_small .
gsutil -o GSUtil:parallel_composite_upload_threshold=150M -m cp -r -n gs://t5-data/pretrained_models/t5x/t5_1_1_small .

Transformers configs can be downloaded from the Model Hub:

curl --silent https://huggingface.co/t5-small/resolve/main/config.json > config_1_0.json
curl --silent https://huggingface.co/google/t5-v1_1-small/resolve/main/config.json > config_1_1.json

Models can be converted via:

python3 convert_t5x_checkpoint_to_flax.py --t5x_checkpoint_path ./t5_small --config_name ./config_1_0.json --flax_dump_folder_path ./t5x_1_0_exported

python3 convert_t5x_checkpoint_to_flax.py --t5x_checkpoint_path ./t5_1_1_small --config_name ./config_1_1.json --flax_dump_folder_path ./t5x_1_1_exported

Then I ran downstream evaluation on the summarization task with both the original T5 models (from the Model Hub) and the converted ones, e.g.:

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ./t5_1_0_original \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate \
    --num_train_epochs 1

and:

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path /mnt/transformers/src/transformers/models/t5/t5x_1_0_exported \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ./t5_1_0_converted \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate \
    --num_train_epochs 1

Then I compared the training losses: they are identical for the original and the converted model.
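A quicker sanity check than a full fine-tuning run is to compare the logits of the two models on the same input. A minimal sketch (paths match the conversion example above):

import numpy as np
from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
original = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
converted = FlaxT5ForConditionalGeneration.from_pretrained("./t5x_1_0_exported")

inputs = tokenizer("summarize: The quick brown fox jumps over the lazy dog.", return_tensors="np")
# A single decoder step is enough to compare the two parameter sets.
decoder_input_ids = np.array([[original.config.decoder_start_token_id]])

logits_original = original(**inputs, decoder_input_ids=decoder_input_ids).logits
logits_converted = converted(**inputs, decoder_input_ids=decoder_input_ids).logits
print("max abs diff:", np.abs(logits_original - logits_converted).max())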

@stefan-it
Collaborator Author

Hi @stancld, thanks for adding the LongT5 variant 🤗 Could you also add the patch for v1.0 checkpoints from this commit:

4f36d42

Would be awesome :)

Contributor

@patil-suraj patil-suraj left a comment


This is great, thanks a lot @stefan-it !

@patil-suraj patil-suraj merged commit cb7e166 into huggingface:main Apr 21, 2022
patil-suraj pushed a commit to stancld/transformers that referenced this pull request May 26, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* t5: add conversion script for T5X to FLAX

* t5: make flake happy

* t5: add copyright message to t5x conversion script

* t5: fix lm head for v1.0 checkpoints
patrickvonplaten added a commit that referenced this pull request Jun 13, 2022
* Initial commit

* Make some fixes

* Make PT model full forward pass

* Drop TF & Flax implementation, fix copies etc

* Add Flax model and update some corresponding stuff

* Drop some TF things

* Update config and flax local attn

* Add encoder_attention_type to config

* .

* Update docs

* Do some cleansing

* Fix some issues -> make style; add some docs

* Fix position_bias + mask addition + Update tests

* Fix repo consistency

* Fix model consistency by removing flax operation over attn_mask

* [WIP] Add PT TGlobal LongT5

* .

* [WIP] Add flax tglobal model

* [WIP] Update flax model to use the right attention type in the encoder

* Fix flax tglobal model forward pass

* Make use of global_relative_attention_bias

* Add test suites for TGlobal model

* Fix minor bugs, clean code

* Fix pt-flax equivalence though not convinced with correctness

* Fix LocalAttn implementation to match the original impl. + update READMEs

* Few updates

* Update: [Flax] improve large model init and loading #16148

* Add ckpt conversion script according to #16853 + handle torch device placement

* Minor updates to conversion script.

* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM

* gpu support + dtype fix

* Apply some suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* * Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies

* Remove caching logic for local & tglobal attention

* Apply another batch of suggestions from code review

* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx

* Fix converting script + revert config file change

* Revert "Remove caching logic for local & tglobal attention"

This reverts commit 2a61982.

* Stash caching logic in Flax model

* Make side relative bias used always

* Drop caching logic in PT model

* Return side bias as it was

* Drop all remaining model parallel logic

* Remove clamp statements

* Move test files to the proper place

* Update docs with new version of hf-doc-builder

* Fix test imports

* Make some minor improvements

* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray

* Fix TGlobal for ONNX conversion + update docs

* fix _make_global_fixed_block_ids and masked neg value

* update flax model

* style and quality

* fix imports

* remove load_tf_weights_in_longt5 from init and fix copies

* add slow test for TGlobal model

* typo fix

* Drop obsolete is_parallelizable and one warning

* Update __init__ files to fix repo-consistency

* fix pipeline test

* Fix some device placements

* [wip]: Update tests -- need to generate summaries to update expected_summary

* Fix quality

* Update LongT5 model card

* Update (slow) summarization tests

* make style

* rename checkpoints

* finish

* fix flax tests

Co-authored-by: phungvanduy <pvduy23@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jun 16, 2022
@joytianya

Does this script support converting XL or XXL models?

@joytianya

joytianya commented Dec 1, 2022

Does this script support converting XL or XXL models?
I generated the files in /content/flan_t5x_xl_exported (listing below), then tried to load the model with the following code and got an error. How do I solve it?

model = T5ForConditionalGeneration.from_pretrained("/content/flan_t5x_xl_exported", from_flax=True)
# Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in 
# directory /content/flan_t5x_xl_exported.

/content/flan_t5x_xl_exported:

model-00001-of-00002.msgpack
model-00002-of-00002.msgpack
model.msgpack.index.json
config.json

@stancld
Contributor

stancld commented Dec 1, 2022

@joytianya Currently, you cannot use cross-platform loading when the large model is split into multiple files. But this feature is planned soon -- please see #19965
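Until that lands, one possible workaround is to load the sharded checkpoint natively in Flax and write out a PyTorch copy (a hedged sketch, assuming your transformers version can load sharded Flax checkpoints and the model fits into RAM):

from transformers import FlaxT5ForConditionalGeneration, T5Config, T5ForConditionalGeneration
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

path = "/content/flan_t5x_xl_exported"
flax_model = FlaxT5ForConditionalGeneration.from_pretrained(path)  # handles the .msgpack shards
pt_model = T5ForConditionalGeneration(T5Config.from_pretrained(path))
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
pt_model.save_pretrained(path)  # writes pytorch_model.bin next to the Flax shards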

@joytianya

joytianya commented Dec 1, 2022

@stefan-it
@stancld
Does the script support converting T5X checkpoints into PyTorch?
If not, is there another solution?
