KEMBAR78
Remove deprecate method and attirbute in `LRScheduler` by zeshengzong · Pull Request #147301 · pytorch/pytorch · GitHub
Skip to content

Conversation

@zeshengzong
Copy link
Contributor

@zeshengzong zeshengzong commented Feb 17, 2025

Following #99270 suggestion, remove deprecate method LRScheduler.print_lr


BC-breaking note

LRScheduler.print_lr() along with the verbose kwarg to the LRScheduler constructor has been deprecated since release 2.2. Please use LRScheduler.get_last_lr() to access the learning rate instead.

print_lr and verbose were confusing, not properly documented and were little used, as described in #99270, so we deprecated them in 2.2. Now, we complete the deprecation by removing them completely. To access and print the learning rate of a LRScheduler:

In 2.6.0

optim = ...
lrsched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, verbose=True)
// lrsched will internally call print_lr

In 2.7.0

optim = ...
lrsched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim)
print(lrsched.get_last_lr())

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147301

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit b3e60f4 with merge base 3ca1a25 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zeshengzong zeshengzong marked this pull request as ready for review February 17, 2025 06:49
@vadimkantorov
Copy link
Contributor

I wonder if self.verbose field there can also be removed, it's passed quite a long time since its deprecation

@janeyx99
Copy link
Contributor

Thanks for following up on this--yes let's remove everything associated with this deprecation, including the verbose param

@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 18, 2025
@zeshengzong zeshengzong changed the title Remove deprecate method LRScheduler.print_lr Remove deprecate method and attirbute in LRScheduler Feb 19, 2025
@zeshengzong
Copy link
Contributor Author

Should remove epoch in step method as well?

def step(self, epoch: Optional[int] = None):
"""Perform a step."""
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
warnings.warn(
"Seems like `optimizer.step()` has been overridden after learning rate scheduler "
"initialization. Please, make sure to call `optimizer.step()` before "
"`lr_scheduler.step()`. See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
)
# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif not getattr(self.optimizer, "_opt_called", False):
warnings.warn(
"Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
"In PyTorch 1.1.0 and later, you should call them in the opposite order: "
"`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
"will result in PyTorch skipping the first value of the learning rate schedule. "
"See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
)
self._step_count += 1
with _enable_get_lr_call(self):
if epoch is None:
self.last_epoch += 1
values = self.get_lr()
else:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
self.last_epoch = epoch
if hasattr(self, "_get_closed_form_lr"):
values = cast(list[float], self._get_closed_form_lr())
else:
values = self.get_lr()
for param_group, lr in zip(self.optimizer.param_groups, values):
if isinstance(param_group["lr"], Tensor):
param_group["lr"].fill_(lr)
else:
param_group["lr"] = lr
self._last_lr: list[float] = [
group["lr"] for group in self.optimizer.param_groups
]

@zeshengzong
Copy link
Contributor Author

@janeyx99 Hi, please check changes when available, thanks!

@janeyx99
Copy link
Contributor

janeyx99 commented Mar 3, 2025

Should remove epoch in step method as well?

no this one is harder to remove

Copy link
Contributor

@janeyx99 janeyx99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@janeyx99
Copy link
Contributor

janeyx99 commented Mar 3, 2025

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 3, 2025
@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased opt/deprecate/LRScheduler onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout opt/deprecate/LRScheduler && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the opt/deprecate/LRScheduler branch from f3a2951 to b3e60f4 Compare March 3, 2025 21:36
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 3, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@zeshengzong
Copy link
Contributor Author

Seems need help skip BC Lint check, thanks! @janeyx99

@janeyx99 janeyx99 added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label Mar 4, 2025
@zeshengzong
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 5, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@zeshengzong zeshengzong deleted the opt/deprecate/LRScheduler branch March 5, 2025 06:34
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 11, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 11, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 11, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 11, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 11, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 11, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 12, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
clayne added a commit to clayne/SimpleTuner that referenced this pull request May 12, 2025
* PyTorch 2.7 removed the deprecated argument entirely:
  pytorch/pytorch#147301
rfdougherty added a commit to rfdougherty/k-diffusion that referenced this pull request Jul 21, 2025
The verbose flag was deprecated in torch 2.2 and removed in 2.7. See:
pytorch/pytorch#147301
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: optim suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants