Closed
Labels
oncall: distributed
Description
🐛 Describe the bug
pytorch/torch/distributed/checkpoint/state_dict.py, lines 611 to 614 in 585dbfa:

```python
for param_group in optim.param_groups:
    if "lr" in param_group:
        lrs.append(param_group["lr"])
        param_group["lr"] = 0.0
```
When the original LR is a tensor, _init_optim_state() should preserve its tensor-ness. This probably doesn't matter for built-in PyTorch optimizers, but torchao's low-bit optimizers expect LR to be a tensor and raise an error otherwise. Related to pytorch/ao#1189
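To illustrate, here is a minimal sketch of the problem. The param group below is hypothetical and just stands in for a torchao low-bit optimizer, which keeps "lr" as a tensor:

```python
import torch

# Hypothetical param group standing in for a torchao low-bit optimizer,
# which stores "lr" as a tensor in its param_groups.
param_groups = [{"params": [], "lr": torch.tensor(1e-3)}]

# What the current _init_optim_state() loop effectively does:
lrs = []
for param_group in param_groups:
    if "lr" in param_group:
        lrs.append(param_group["lr"])
        param_group["lr"] = 0.0  # the tensor LR is replaced by a Python float

print(type(param_groups[0]["lr"]))  # <class 'float'>, tensor-ness is lost
```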
I propose to change L614 to

```python
param_group["lr"] = torch.tensor(0.0) if isinstance(param_group["lr"], torch.Tensor) else 0.0
```
cc: @awgu
Versions
2.6.0.dev20241029
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @chauhang @penguinwu