Fix t5 shard on TPU Pods by agemagician · Pull Request #16527 · huggingface/transformers · GitHub

Conversation

@agemagician
Copy link
Contributor

The current script doesn't work properly on a TPU pod because the global batch is not divided correctly across hosts.
This pull request fixes the issue by dividing the global batch per host before it is sharded across each host's local devices.
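The fix described above can be sketched as follows. This is a minimal illustration of the idea, not the exact diff: the helper name and the explicit `process_index`/`process_count` arguments are assumptions for the example (in the real script these would come from `jax.process_index()` and `jax.process_count()`), and the host-local slice would then be passed to `flax.jax_utils`-style sharding across local devices.

```python
import numpy as np

def host_local_batch(global_batch: np.ndarray,
                     process_index: int,
                     process_count: int) -> np.ndarray:
    """Take this host's slice of the global batch *before* sharding it
    across the host's local devices (hypothetical helper)."""
    per_host = global_batch.shape[0] // process_count
    start = process_index * per_host
    return global_batch[start : start + per_host]

# Example: a global batch of 16 samples on a pod with 2 hosts.
batch = np.arange(16).reshape(16, 1)
host0 = host_local_batch(batch, process_index=0, process_count=2)  # samples 0..7
host1 = host_local_batch(batch, process_index=1, process_count=2)  # samples 8..15
```

Without this per-host division, every host would try to shard the full global batch across only its own devices, which is the mismatch the PR fixes.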

Fixes #16470

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Models:

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 31, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Copy link
Contributor

This looks good to me!

@patil-suraj @borisdayma - could you take a look here?

@borisdayma
Copy link
Contributor

Yes this approach works!

@borisdayma
Copy link
Contributor

Thinking about it, I think there could be some issues with the last batch, so we probably need to ensure that all batches have the same number of items and that this number is a multiple of the number of local devices.
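The concern above is that a trailing partial batch would not split evenly across a host's local devices. One common way to handle it is to trim the sample indices to a multiple of the local device count before batching; the helper below is an illustrative sketch of that idea, not code from the script.

```python
import numpy as np

def trim_to_device_multiple(samples_idx: np.ndarray,
                            local_device_count: int) -> np.ndarray:
    """Drop trailing samples so the host's batch splits evenly across
    its local devices (illustrative helper, names are assumptions)."""
    usable = (len(samples_idx) // local_device_count) * local_device_count
    return samples_idx[:usable]

# 10 samples on a host with 8 local devices: the last 2 are dropped.
idx = trim_to_device_multiple(np.arange(10), local_device_count=8)
```

The alternative is to pad the final batch instead of dropping samples, which matters more at evaluation time than during training.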

Copy link
Contributor

@patil-suraj patil-suraj left a comment


Thanks a lot for the PR, LGTM!

@patil-suraj patil-suraj merged commit 5e68675 into huggingface:main Apr 11, 2022
@patil-suraj
Copy link
Contributor

@borisdayma
This line already makes sure that all batches are of the same length.

train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
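For readers following along, a sketch of how a `generate_batch_splits`-style helper can guarantee equal-length batches: it drops any incomplete final batch before splitting. This is an illustration consistent with the behavior described in the comment, not necessarily the exact implementation in the script.

```python
import numpy as np

def generate_batch_splits(samples_idx: np.ndarray, batch_size: int):
    """Split sample indices into equal-length batches, dropping the
    remainder so every batch has exactly `batch_size` items."""
    samples_to_remove = len(samples_idx) % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    num_batches = len(samples_idx) // batch_size
    return np.split(samples_idx, num_batches)

# 10 samples with batch_size=4: two full batches, 2 samples dropped.
splits = generate_batch_splits(np.arange(10), batch_size=4)
```

Because every batch that reaches the sharding step has the same length, the per-host division from this PR always produces slices that shard cleanly.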

elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* Fix t5 shard on TPU Pods

The current script doesn't work properly on a TPU pod because the global batch is not divided correctly per host.
This pull request fixes this issue by dividing the global batch to each host before it is shared on each host.

* fix style

Co-authored-by: ahmed-elnaggar <ahmed.elnaggar@allianz.com>
