t5: add conversion script for T5X to FLAX #16853
Thanks a mille @stefan-it ! @stancld this should be relevant to you as well.
@patil-suraj can you take a look? :-)
Also @stefan-it, do you have a pretrained T5X lying around somewhere for which I could try out the script? :-)
I have just tested it with v1.0 and there's a problem with the lm head:
So |
v1.0 checkpoints can also be converted now. Here are some checks (with a final evaluation on a downstream task).

Requirements (tensorstore is pinned to 0.1.13):

pip3 install git+https://github.com/google-research/t5x.git
pip3 install --upgrade tensorstore==0.1.13

Then download the v1.0 and v1.1 T5X checkpoints:

gsutil -o GSUtil:parallel_composite_upload_threshold=150M -m cp -r -n gs://t5-data/pretrained_models/t5x/t5_small .
gsutil -o GSUtil:parallel_composite_upload_threshold=150M -m cp -r -n gs://t5-data/pretrained_models/t5x/t5_1_1_small .

The Transformers configs can be downloaded from the model hub:

curl --silent https://huggingface.co/t5-small/resolve/main/config.json > config_1_0.json
curl --silent https://huggingface.co/google/t5-v1_1-small/resolve/main/config.json > config_1_1.json

The models can be converted via:

python3 convert_t5x_checkpoint_to_flax.py --t5x_checkpoint_path ./t5_small --config_name ./config_1_0.json --flax_dump_folder_path ./t5x_1_0_exported
python3 convert_t5x_checkpoint_to_flax.py --t5x_checkpoint_path ./t5_1_1_small --config_name ./config_1_1.json --flax_dump_folder_path ./t5x_1_1_exported

Then I ran a downstream evaluation with the original T5 models (from the model hub) and the converted ones on the summarization task, e.g.:

python examples/pytorch/summarization/run_summarization.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir ./t5_1_0_original \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate \
--num_train_epochs 1

and:

python examples/pytorch/summarization/run_summarization.py \
--model_name_or_path /mnt/transformers/src/transformers/models/t5/t5x_1_0_exported \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir ./t5_1_0_converted \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate \
--num_train_epochs 1

I then compared the training losses: they are identical for the original and the converted model.
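One way to make the loss comparison reproducible is to read the logged losses back from disk. The sketch below assumes that each run's output directory contains a `trainer_state.json` file with a `log_history` list whose training entries carry `step` and `loss` keys (the layout the Trainer writes when its state is saved). The helper names `load_train_losses` and `max_loss_gap` are hypothetical and are not part of the conversion script.

```python
import json
from pathlib import Path


def load_train_losses(output_dir):
    """Return {step: loss} read from <output_dir>/trainer_state.json.

    Assumption: the file has a "log_history" list; only entries that
    contain both "step" and "loss" (training logs) are kept, so
    evaluation entries such as "eval_loss" are ignored.
    """
    state = json.loads((Path(output_dir) / "trainer_state.json").read_text())
    return {
        entry["step"]: entry["loss"]
        for entry in state["log_history"]
        if "loss" in entry and "step" in entry
    }


def max_loss_gap(losses_a, losses_b):
    """Largest absolute loss difference over the steps logged in both runs."""
    shared_steps = sorted(set(losses_a) & set(losses_b))
    if not shared_steps:
        raise ValueError("the two runs share no logged steps")
    return max(abs(losses_a[step] - losses_b[step]) for step in shared_steps)
```

With the output directories used above, `max_loss_gap(load_train_losses("./t5_1_0_original"), load_train_losses("./t5_1_0_converted"))` would report the largest per-step gap; a value of zero corresponds to identical training losses.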
This is great, thanks a lot @stefan-it !
* t5: add conversion script for T5X to FLAX
* t5: make flake happy
* t5: add copyright message to t5x conversion script
* t5: fix lm head for v1.0 checkpoints
* Initial commit * Make some fixes * Make PT model full forward pass * Drop TF & Flax implementation, fix copies etc * Add Flax model and update some corresponding stuff * Drop some TF things * Update config and flax local attn * Add encoder_attention_type to config * . * Update docs * Do some cleansing * Fix some issues -> make style; add some docs * Fix position_bias + mask addition + Update tests * Fix repo consistency * Fix model consistency by removing flax operation over attn_mask * [WIP] Add PT TGlobal LongT5 * . * [WIP] Add flax tglobal model * [WIP] Update flax model to use the right attention type in the encoder * Fix flax tglobal model forward pass * Make the use of global_relative_attention_bias * Add test suites for TGlobal model * Fix minor bugs, clean code * Fix pt-flax equivalence though not convinced with correctness * Fix LocalAttn implementation to match the original impl. + update READMEs * Few updates * Update: [Flax] improve large model init and loading #16148 * Add ckpt conversion script accoring to #16853 + handle torch device placement * Minor updates to conversion script. * Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM * gpu support + dtype fix * Apply some suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * * Remove (de)parallelize stuff * Edit shape comments * Update README.md * make fix-copies * Remove caching logic for local & tglobal attention * Apply another batch of suggestions from code review * Add missing checkpoints * Format converting scripts * Drop (de)parallelize links from longT5 mdx * Fix converting script + revert config file change * Revert "Remove caching logic for local & tglobal attention" This reverts commit 2a61982. 
* Stash caching logic in Flax model * Make side relative bias used always * Drop caching logic in PT model * Return side bias as it was * Drop all remaining model parallel logic * Remove clamp statements * Move test files to the proper place * Update docs with new version of hf-doc-builder * Fix test imports * Make some minor improvements * Add missing checkpoints to docs * Make TGlobal model compatible with torch.onnx.export * Replace some np.ndarray with jnp.ndarray * Fix TGlobal for ONNX conversion + update docs * fix _make_global_fixed_block_ids and masked neg value * update flax model * style and quality * fix imports * remove load_tf_weights_in_longt5 from init and fix copies * add slow test for TGlobal model * typo fix * Drop obsolete is_parallelizable and one warning * Update __init__ files to fix repo-consistency * fix pipeline test * Fix some device placements * [wip]: Update tests -- need to generate summaries to update expected_summary * Fix quality * Update LongT5 model card * Update (slow) summarization tests * make style * rename checkpoitns * finish * fix flax tests Co-authored-by: phungvanduy <pvduy23@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: patil-suraj <surajp815@gmail.com>
Does this script support converting XL or XXL models?
model = T5ForConditionalGeneration.from_pretrained("/content/flan_t5x_xl_exported", from_flax=True)
# Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in
# directory /content/flan_t5x_xl_exported. /content/flan_t5x_xl_exported: |
@joytianya Currently, cross-platform loading is not supported when a large model is split into multiple files. This feature is planned, though; please see #19965.
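To tell in advance whether an exported folder will hit this limitation, one can check how the weights were written. The following is a minimal sketch using only the standard library. The single-file name `flax_model.msgpack` comes from the error message above. The index file name `flax_model.msgpack.index.json` is an assumption about how sharded Flax checkpoints are laid out, and the helper `flax_checkpoint_layout` is hypothetical.

```python
from pathlib import Path

# File name taken from the error message in this thread.
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
# Assumption: sharded Flax checkpoints are described by this index file.
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"


def flax_checkpoint_layout(folder):
    """Classify a dump folder as 'single', 'sharded', or 'missing'."""
    folder = Path(folder)
    if (folder / FLAX_WEIGHTS_NAME).is_file():
        return "single"
    if (folder / FLAX_WEIGHTS_INDEX_NAME).is_file():
        return "sharded"
    return "missing"
```

If the helper returns `"sharded"` for a folder such as `/content/flan_t5x_xl_exported`, loading it with `from_flax=True` would be expected to fail as described above until cross-platform loading of sharded checkpoints is supported.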
@stefan-it |
Hi,
this PR adds the (long-awaited) conversion script from T5X to HF FLAX, previously available in this GIST.
The conversion script converts models that were trained with T5X into FLAX models, so they can be used with Transformers.
The script was road-tested, and its performance was compared against the official T5 (v1.1) checkpoint (because T5 checkpoints can be converted into T5X checkpoints). More information can be found in this issue in the T5X upstream repo.
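When several checkpoints need converting, the invocation can be assembled programmatically. This is only a sketch: the three flags (`--t5x_checkpoint_path`, `--config_name`, `--flax_dump_folder_path`) and the script name are the ones used in the commands posted in this thread, while the helper `build_convert_command` and the `jobs` list are hypothetical.

```python
import shlex


def build_convert_command(t5x_checkpoint_path, config_name, flax_dump_folder_path,
                          script="convert_t5x_checkpoint_to_flax.py"):
    """Build the argv list for one conversion run (nothing is executed here)."""
    return [
        "python3", script,
        "--t5x_checkpoint_path", t5x_checkpoint_path,
        "--config_name", config_name,
        "--flax_dump_folder_path", flax_dump_folder_path,
    ]


# Hypothetical job list mirroring the v1.0 and v1.1 examples in this thread.
jobs = [
    ("./t5_small", "./config_1_0.json", "./t5x_1_0_exported"),
    ("./t5_1_1_small", "./config_1_1.json", "./t5x_1_1_exported"),
]

for job in jobs:
    # Print the shell-quoted command so it can be reviewed before running.
    print(shlex.join(build_convert_command(*job)))
```

Each list can be passed to `subprocess.run(cmd, check=True)` to actually run the conversion; keeping the command as a list avoids shell-quoting problems with unusual paths.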