New features for CodeParrot training script by loubnabnl · Pull Request #16851 · huggingface/transformers · GitHub

Conversation

@loubnabnl
Contributor

This PR adds some features to the CodeParrot training script.

  • Add TFLOPs to the logging
  • Use Accelerate checkpointing and tracking for Weights & Biases and TensorBoard
  • Fix gradient accumulation for DDP (Fix nlp_example accelerate#106)
  • Scale the loss appropriately for the last batch
  • Fix a typo in the README

cc @lvwerra @LysandreJik

@lvwerra lvwerra self-requested a review April 20, 2022 11:01
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 20, 2022

The documentation is not available anymore as the PR was closed or merged.

Member

@lvwerra lvwerra left a comment


Thanks @loubnabnl, looks pretty clean already! I left a few minor comments, mainly to make the code a bit more concise.

Regarding the saving of the state: it is great that we can now save it, but what I think is currently missing is a mechanism to restart the script from a saved state. I don't think we need to do much; you can probably follow the example here:
https://github.com/huggingface/accelerate/blob/main/examples/complete_nlp_example.py
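
For context, a minimal sketch of what resuming could look like, loosely following the linked complete_nlp_example.py; the --resume_from_checkpoint argument and the step_{N} folder naming are illustrative assumptions, not something prescribed by this review:

import os

# Sketch only: restore a state written by accelerator.save_state() and skip the
# batches that were already consumed before the checkpoint.
if args.resume_from_checkpoint is not None:
    # Restores model, optimizer, lr scheduler and RNG states.
    accelerator.load_state(args.resume_from_checkpoint)
    # Recover the global step from the checkpoint folder name, e.g. ".../step_5000".
    resume_step = int(os.path.basename(args.resume_from_checkpoint).split("_")[-1])
else:
    resume_step = 0

for step, batch in enumerate(train_dataloader, start=1):
    if step <= resume_step:
        continue  # fast-forward past batches seen before the checkpoint
    ...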

Comment on lines 228 to 239
elapsed_time_per_iteration = time.time() - t_start
checkpoint_factor = 4 if args.gradient_checkpointing else 3
batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
factor = (
    24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
)
flops_per_iteration = factor * (
    1.0
    + (args.seq_length / (6.0 * config_model.n_embd))
    + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
)
tflops = flops_per_iteration / (elapsed_time_per_iteration * accelerator.state.num_processes * (10**12))
Member


Could we move that to a dedicated function, e.g. compute_tflops(elapsed_time, accelerator, args)? It would be nice if the main training loop stayed concise, to make it clearer what's going on.
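
For illustration, the helper could look like the sketch below, which simply lifts the quoted block out of the loop; it assumes model and tokenizer remain reachable from the enclosing scope, as they are in the script:

def compute_tflops(elapsed_time, accelerator, args):
    # TFLOPs estimate per process (see the formula discussion further down).
    config_model = accelerator.unwrap_model(model).config
    checkpoint_factor = 4 if args.gradient_checkpointing else 3
    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
    flops_per_iteration = factor * (
        1.0
        + (args.seq_length / (6.0 * config_model.n_embd))
        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
    )
    return flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))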

accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
accelerator.save_state(args.save_dir)
Member


What exactly does save_state save? A bunch of files? Maybe we could put them in a folder, e.g. args.save_dir + "/state/".

Contributor Author


save_state saves a bunch of files (model, optimizer, ...). I'm now saving them in folders corresponding to the steps, so that training can be resumed from these steps later.

Contributor Author


And since save_state already saves the model in the step folder, I now use save_pretrained on the unwrapped model only for the last checkpoint, saving it in args.save_dir so it can be loaded directly from there later.
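
Roughly, the layout described above could look like this sketch; the step_{step} folder name is an illustrative choice and the save_checkpoint_steps argument name is assumed here:

import os

# During training: periodically save the full state (model, optimizer,
# scheduler, RNG) into a per-step folder, so training can be resumed from it.
if step % args.save_checkpoint_steps == 0:
    accelerator.save_state(os.path.join(args.save_dir, f"step_{step}"))

# After training: save only the final model weights directly in args.save_dir,
# so they can later be loaded with from_pretrained.
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)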

accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
accelerator.save_state(args.save_dir)
Member


Same as above

Member

@lvwerra lvwerra left a comment


Only two minor comments and then it is good to go! 🚀



def compute_tflops(elapsed_time, accelerator, args):
    config_model = accelerator.unwrap_model(model).config
Member


Minor thing: can you add a link to the formula here, either the BigScience repo or the paper itself? That way somebody can find out where that black magic formula actually comes from :)

Contributor Author


Done!
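
For reference, the estimate in the quoted code appears to correspond to the FLOPs-per-iteration approximation from the Megatron-LM paper (Narayanan et al., 2021); with B the global batch size, s the sequence length, l the number of layers, h the hidden size (n_embd) and V the vocabulary size:

% FLOPs per iteration; c = 4 with gradient checkpointing (activation
% recomputation), c = 3 without, matching checkpoint_factor in the code.
\[
F = 24\,c\,B\,s\,l\,h^{2}\left(1 + \frac{s}{6h} + \frac{V}{16\,l\,h}\right)
\]
% Logged throughput: FLOPs per iteration divided by the wall-clock time per
% iteration, the number of processes, and 10^{12} to express TFLOPs per process.
\[
\text{TFLOPs} = \frac{F}{\text{elapsed\_time} \times \text{num\_processes} \times 10^{12}}
\]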

loubnabnl and others added 2 commits April 21, 2022 17:18
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
@lvwerra lvwerra merged commit d918413 into huggingface:main Apr 21, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* add tflops logging and fix grad accumulation

* add accelerate tracking and checkpointing

* scale loss of last batch correctly

* fix typo

* compress loss computation

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* add resume from checkpoint argument

* add load_state accelerate from checkpoint, register lr scheduler and add tflops function

* reformat code

* reformat code

* add condition on path for resume checkpoint

* combine if conditions

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* add source for tflops formula

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>