Removed torch.cuda.empty_cache from train loop. by FoamoftheSea · Pull Request #31530 · huggingface/transformers · GitHub

Conversation

@FoamoftheSea
Contributor

What does this PR do?

Removes the torch.cuda.empty_cache call that was added to the training loop in #28769.

This line caused the training slowdowns observed in issue #31372.

While this thread on the PyTorch forums recommends against using this function because it is slow, many commenters there still find it necessary to avoid OOMs in their training jobs. So it might be nice to have the option, but users can just add the call on their own if they're in a jam (see the sketch below).
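
For reference, a minimal sketch of how a user could re-add the behaviour themselves with a TrainerCallback (the callback name is my own, not part of the library):

```python
import torch
from transformers import TrainerCallback


class EmptyCudaCacheCallback(TrainerCallback):
    """Illustrative callback that clears the CUDA caching allocator after every optimizer step."""

    def on_step_end(self, args, state, control, **kwargs):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Usage sketch: trainer = Trainer(..., callbacks=[EmptyCudaCacheCallback()])
```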

Fixes #31372

@muellerzr @SunMarc

Member

@SunMarc SunMarc left a comment


Thanks again for your investigation @FoamoftheSea! LGTM!

@SunMarc SunMarc requested review from amyeroberts and muellerzr June 21, 2024 09:22
Contributor

@muellerzr muellerzr left a comment


Overall I agree with this. If users decide it isn't enough, the next step IMO would be a toggleable "after n steps" empty_cache() of some sort, to at least delay it and give users control.
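
As a rough illustration of that idea, a periodic version could look something like this (the callback name and default interval are made up, not an existing API):

```python
import torch
from transformers import TrainerCallback


class PeriodicEmptyCacheCallback(TrainerCallback):
    """Illustrative callback: call torch.cuda.empty_cache() only every `empty_cache_steps` steps."""

    def __init__(self, empty_cache_steps: int = 100):
        self.empty_cache_steps = empty_cache_steps

    def on_step_end(self, args, state, control, **kwargs):
        # state.global_step is the number of optimizer steps completed so far.
        if torch.cuda.is_available() and state.global_step % self.empty_cache_steps == 0:
            torch.cuda.empty_cache()
```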

Contributor

@amyeroberts amyeroberts left a comment


Thanks for the detailed PR description @FoamoftheSea and the fix ❤️

Agreed - let's remove it for now, and if we find that users need it we can think about smarter ways to reintroduce it.

@amyeroberts amyeroberts merged commit 8b7cd40 into huggingface:main Jun 21, 2024
@aliencaocao
Contributor

Does it make sense to add it as a TrainingArgument defaulting to False, with a tip to turn it on if VRAM usage is near the limit? It would be useful because many OOMs only happen after some unpredictable number of steps, and many users don't watch their jobs the whole way through before stepping away.

This change may also cause some behaviour differences, where hyperparameters/models that previously worked now OOM.

@amyeroberts
Contributor

@aliencaocao I'm not sure we necessarily want to actively monitor the memory and trigger a tip (I suspect this is more fiddly and flaky than expected, as you have to balance catching issues in time vs. not spamming, making sure the values are correct, etc.).

A flag which we can configure to clear the cache every n steps seems reasonable. Would you like to open a PR with a proposal and we can iterate from there?
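
For example, usage of such a flag might look like the sketch below (torch_empty_cache_steps is only an illustrative name here, not an option that exists at the time of this PR):

```python
from transformers import TrainingArguments

# Hypothetical argument name for illustration only; not in TrainingArguments as of this PR.
args = TrainingArguments(
    output_dir="out",
    torch_empty_cache_steps=200,  # clear the CUDA cache every 200 steps; None would disable it
)
```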

@aliencaocao
Contributor

sure i can do it



Development

Successfully merging this pull request may close these issues.

cuda.empty_cache in trainer.py slow down training
