torch.distributed.checkpoint.state_dict._init_optim_state() should respect tensor-ness of lr #139575

@gau-nernst

🐛 Describe the bug

In `_init_optim_state()` (torch/distributed/checkpoint/state_dict.py), the LR of each param group is saved and then temporarily replaced with a plain float `0.0`:

```python
for param_group in optim.param_groups:
    if "lr" in param_group:
        lrs.append(param_group["lr"])
        param_group["lr"] = 0.0
```

When the original LR is a tensor, `_init_optim_state()` should preserve that tensor-ness. This probably doesn't matter for built-in PyTorch optimizers, but torchao's low-bit optimizers expect the LR to be a tensor and raise an error otherwise. Related to pytorch/ao#1189.
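
To make the failure mode concrete, here is a minimal, hypothetical stand-in for a tensor-LR optimizer (not torchao's actual implementation); its `step()` calls tensor methods on `lr`, so overwriting the LR with a plain float, as the snippet above does, raises an `AttributeError`:

```python
import torch

class _TensorLRSGD(torch.optim.Optimizer):
    """Hypothetical stand-in for a low-bit-style optimizer: step() assumes
    ``lr`` is a tensor (e.g. so LR changes don't force recompilation)."""

    def __init__(self, params, lr: torch.Tensor):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is not None:
                    # Calls a tensor method on lr: breaks if lr is a plain float.
                    p.add_(p.grad * -lr.to(p.dtype))

model = torch.nn.Linear(4, 4)
opt = _TensorLRSGD(model.parameters(), lr=torch.tensor(1e-2))
model(torch.randn(2, 4)).sum().backward()

opt.param_groups[0]["lr"] = 0.0  # what the snippet above does today
opt.step()  # AttributeError: 'float' object has no attribute 'to'
```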

I propose to change L614 to:

```python
param_group["lr"] = torch.tensor(0.0) if isinstance(param_group["lr"], torch.Tensor) else 0.0
```

cc: @awgu

Versions

2.6.0.dev20241029

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @chauhang @penguinwu

Labels

oncall: distributed