KEMBAR78
Enable users to use their own loss functions + deal with prefetching for grad accum by muellerzr · Pull Request #34198 · huggingface/transformers · GitHub
Skip to content

Conversation

@muellerzr
Copy link
Contributor

@muellerzr muellerzr commented Oct 16, 2024

What does this PR do?

In conjunction with #34191, this PR solves the other half of what's needed:

  1. Letting users pass in their own loss functions directly to the Trainer via compute_loss
  2. Prefetching the first gradient_accumulation_steps worth of data each complete step and marking how many samples were seen (num_items_in_batch), which can be passed to a loss function if it takes in num_items_seen (name TBD)

A bit of feedback needed we need to coordinate:

  • Should it be called num_items_in_batch and then passed through to the loss functions as such? Or is there a better name we can think of

Fixes huggingface/trl#2175

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LysandreJik @ArthurZucker

@muellerzr muellerzr marked this pull request as ready for review October 16, 2024 17:29
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, IMO a regression test on the grad norms could be fairly nice!

Comment on lines 2463 to 2472
self.state.num_input_tokens_seen += (
torch.sum(
self.accelerator.gather(
torch.tensor(
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
)
)
)
.cpu()
.item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make this more readable!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clean did this one 🫠

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can split in 3-4 lines 🎐

Comment on lines 3644 to 3645
if (self.label_smoother is not None or self.compute_loss is not None) and "labels" in inputs:
labels = inputs.pop("labels")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmmm if people don't pass a loss, we won't use the model's default?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will, it stays in inputs and gets passed to the models forward()

@muellerzr
Copy link
Contributor Author

muellerzr commented Oct 17, 2024

A bit more context, full fine-tuning does NOT SEEM TO BE IMPACTED BY THIS (when padding). I am looking into how this directly affects TRL, however things are not as bad as they may seem.

(Below is an example CausalLM result comparing grad accum 4, bs 8 vs bs 32 both before and after this fix)

image

# For now we don't support object detection
try:
num_items_in_batch = sum(
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already quickly discussed this with Zach, so this is a more general questions to other reviewers:

Would this line be work for all the different task types we support? Specifically, can we always skip the first item in the sequence, i.e. is the [..., 1:] part valid?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For casual auto regressive models it works but won't work in other ones

Comment on lines 2463 to 2472
self.state.num_input_tokens_seen += (
torch.sum(
self.accelerator.gather(
torch.tensor(
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
)
)
)
.cpu()
.item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can split in 3-4 lines 🎐

@muellerzr muellerzr changed the title [DRAFT] Enable users to use their own loss functions + deal with prefetching for grad accum Enable users to use their own loss functions + deal with prefetching for grad accum Oct 17, 2024
Copy link
Contributor

@danielhanchen danielhanchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a denominator change in the test case

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to merge!

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
@muellerzr muellerzr merged commit 6ba31a8 into main Oct 17, 2024
25 of 26 checks passed
@muellerzr muellerzr deleted the muellerzr-fix-loss-calc branch October 17, 2024 21:01
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Oct 21, 2024
…for grad accum (huggingface#34198)

* bookmark

* Bookmark

* Bookmark

* Actually implement

* Pass in kwarg explicitly

* Adjust for if we do or don't have labels

* Bookmark fix for od

* bookmark

* Fin

* closer

* Negate accelerate grad accum div

* Fixup not training long enough

* Add in compute_loss to take full model output

* Document

* compute_loss -> compute_loss_fn

* Add a test

* Refactor

* Refactor

* Uncomment tests

* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Cemberk pushed a commit to ROCm/transformers that referenced this pull request Nov 26, 2024
…for grad accum (huggingface#34198)

* bookmark

* Bookmark

* Bookmark

* Actually implement

* Pass in kwarg explicitly

* Adjust for if we do or don't have labels

* Bookmark fix for od

* bookmark

* Fin

* closer

* Negate accelerate grad accum div

* Fixup not training long enough

* Add in compute_loss to take full model output

* Document

* compute_loss -> compute_loss_fn

* Add a test

* Refactor

* Refactor

* Uncomment tests

* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…for grad accum (huggingface#34198)

* bookmark

* Bookmark

* Bookmark

* Actually implement

* Pass in kwarg explicitly

* Adjust for if we do or don't have labels

* Bookmark fix for od

* bookmark

* Fin

* closer

* Negate accelerate grad accum div

* Fixup not training long enough

* Add in compute_loss to take full model output

* Document

* compute_loss -> compute_loss_fn

* Add a test

* Refactor

* Refactor

* Uncomment tests

* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
@qgallouedec qgallouedec mentioned this pull request Dec 29, 2024
5 tasks
This was referenced Apr 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Gradient accumulation yields worse results than the equivalent batch size

6 participants