Optimize DPO log probability calculation by retaining necessary cache, saving up to 30GB of memory (#1968) by SeungyounShin · Pull Request #1969 · huggingface/trl · GitHub

Conversation

@SeungyounShin (Contributor)

What does this PR do?

This PR addresses excessive GPU memory usage during the reference log probability precomputation in the DPO training loop. The per-batch log probs are already moved to the CPU, but the CUDA caching allocator keeps each batch's activations reserved, so memory grows as the loop runs and can lead to out-of-memory (OOM) errors. Adding torch.cuda.empty_cache() and self.accelerator.free_memory() after each batch releases this unnecessary cache and keeps memory consumption stable.

Fixes #1968

Motivation and Context

This change improves the memory efficiency of the reference log probability precomputation by releasing per-batch GPU cache that is no longer needed. Memory consumption stays stable instead of growing across the loop, leaving more headroom for larger batches and longer contexts; a sketch of the patched loop follows.
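For reference, here is a minimal sketch of the patched precompute loop, assuming a trainer object that exposes compute_reference_log_probs and an accelerator attribute as in the diff context below; the standalone function and its name are illustrative, not the verbatim upstream code:

```python
import torch
from tqdm import tqdm

def precompute_reference_log_probs(trainer, data_loader):
    # Hypothetical standalone version of the loop inside DPOTrainer.
    reference_chosen_logps, reference_rejected_logps = [], []
    for padded_batch in tqdm(data_loader, desc="Reference log probs"):
        chosen_logp, rejected_logp = trainer.compute_reference_log_probs(padded_batch)
        # Keep only the small per-example log probs, moved off the GPU.
        reference_chosen_logps.append(chosen_logp.cpu())
        reference_rejected_logps.append(rejected_logp.cpu())
        # Release the per-batch CUDA allocations that the caching allocator
        # would otherwise keep reserved (the fix this PR introduces).
        torch.cuda.empty_cache()
        trainer.accelerator.free_memory()
    return torch.cat(reference_chosen_logps), torch.cat(reference_rejected_logps)
```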

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is welcome to review this PR once the tests have passed. Feel free to tag members or contributors who may be interested in reviewing.

reference_chosen_logps.append(reference_chosen_logp.cpu())
reference_rejected_logps.append(reference_rejected_logp.cpu())

# Clear the now-unnecessary cache to avoid OOM
torch.cuda.empty_cache()
self.accelerator.free_memory()
@SeungyounShin (Contributor, Author) commented:

Without this line, GPU memory usage can climb to as much as 45GB when running a LLaMA-8B model on an 8K-context dataset:

[Screenshot 1: GPU memory growing toward 45GB without the cache clearing]

Simply adding this line reduces and stabilizes memory consumption, keeping it under control:

[Screenshot 2: GPU memory remaining stable with the cache clearing]
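To reproduce this kind of measurement, one can bracket the precompute loop with PyTorch's standard memory statistics; a sketch (the loop body is elided, only the torch.cuda calls are real API):

```python
import torch

torch.cuda.reset_peak_memory_stats()

# ... run the reference log-prob precompute loop here ...

gib = 2**30
print(f"peak allocated:    {torch.cuda.max_memory_allocated() / gib:.1f} GiB")
print(f"reserved (cached): {torch.cuda.memory_reserved() / gib:.1f} GiB")
```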

@SeungyounShin (Contributor, Author) added:

It seems that this cache occupies GPU memory throughout training, reducing the memory available and, in turn, limiting the context size that can be used for training.
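For context: PyTorch's CUDA caching allocator keeps freed blocks reserved for reuse rather than returning them to the driver, so freed per-batch activations still count against available memory until torch.cuda.empty_cache() hands them back. A small standalone illustration:

```python
import torch

x = torch.randn(4096, 4096, device="cuda")  # ~64 MiB tensor
del x
print(torch.cuda.memory_allocated())  # ~0: the tensor itself is freed
print(torch.cuda.memory_reserved())   # > 0: its block stays cached for reuse
torch.cuda.empty_cache()              # return cached blocks to the driver
print(torch.cuda.memory_reserved())   # back near 0
```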

@SeungyounShin SeungyounShin changed the title Fix issue with unnecessary cache clearing during log probability calculation (#1968) Fix issue with unnecessary cache clearing during DPO log probability calculation (#1968) Aug 26, 2024
@SeungyounShin SeungyounShin changed the title Fix issue with unnecessary cache clearing during DPO log probability calculation (#1968) Optimize DPO log probability calculation by retaining necessary cache, saving up to 30GB of memory (#1968) Aug 26, 2024
@kashif (Collaborator) left a comment:

nice catch!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif kashif merged commit 1e4fb80 into huggingface:main Aug 26, 2024
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
