KEMBAR78
Deprecated verbose parameter in LR schedulers by thomasjpfan · Pull Request #111302 · pytorch/pytorch · GitHub
Skip to content

Conversation

@thomasjpfan
Copy link
Contributor

@thomasjpfan thomasjpfan commented Oct 14, 2023

BC Breaking

As of this commit, the verbose parameter of LRScheduler is deprecated and will no longer trigger print statements during execution. Based on the discussion in #100847, having one-off print statements/logging smattered throughout PyTorch is not a future we want to move towards. Instead, we would prefer a consolidated logging system (usually used for debugging). If you would like to print the learning rate during execution, please use get_last_lr().

For example, instead of the following:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True)
for epoch in range(10):
    train(...)
    val_loss = validate(...)
    # Note that step should be called after validate()
    scheduler.step(val_loss)

Please instead have something like:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
    train(...)
    val_loss = validate(...)
    # Note that step should be called after validate()
    scheduler.step(val_loss)
	print(f"Epoch {epoch} has concluded with lr of {scheduler.get_last_lr()}")

Context

Fixes #100847

This PR follows the comment in #100847 (comment) by deprecating the verbose parameter and removing the print statements. Removing the print statements is technically BC breaking, so I would be okay with putting them back in.

To be less annoying, this PR raises a warning only when verbose is explicitly passed in.

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111302

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 32cc192 with merge base 2a271a3 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@thomasjpfan thomasjpfan force-pushed the deprecate_verbose_lr_scheduler branch from eeecaa2 to 3125239 Compare October 14, 2023 15:28
@janeyx99 janeyx99 added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module topic: deprecation topic category labels Oct 16, 2023
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay in review. Looks good!

@albanD
Copy link
Collaborator

albanD commented Nov 10, 2023

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 10, 2023
@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased deprecate_verbose_lr_scheduler onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout deprecate_verbose_lr_scheduler && git pull --rebase)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Fixes pytorch#100847

This PR follows the comment in pytorch#100847 (comment) by deprecating the `verbose` parameter and removing the print statements. Removing the print statements is technically BC breaking, so I would be okay with putting them back in.

To be less annoying, this PR raises a warning only when `verbose` is explicitly passed in.
Pull Request resolved: pytorch#111302
Approved by: https://github.com/albanD
@janeyx99 janeyx99 added topic: bc breaking topic category and removed topic: deprecation topic category labels Jan 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: optim topic: bc breaking topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

logger instead of print in lr_scheduler.py

5 participants