Description
"""
torchrun --standalone --nproc_per_node=2 repro_dcp_compile.py
"""
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4, 4)
        self.lin2 = nn.Linear(4, 4)
        self.register_buffer("buf", torch.randn((4,)), persistent=False)
        self.weight = nn.Parameter(torch.randn((4, 4)))


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    model = Model()
    model = torch.compile(model)
    sharded_sd = get_model_state_dict(model)
    set_model_state_dict(model, sharded_sd)
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/andgu/pytorch/repro_dcp_compile.py", line 36, in <module>
[rank0]:     set_model_state_dict(model, sharded_sd)
[rank0]:   File "/data/users/andgu/pytorch/torch/distributed/checkpoint/state_dict.py", line 853, in set_model_state_dict
[rank0]:     return _load_model_state_dict(model, model_state_dict, info)
[rank0]:   File "/data/users/andgu/pytorch/torch/distributed/checkpoint/state_dict.py", line 416, in _load_model_state_dict
[rank0]:     state_dict[fqn_with_prefix] = state_dict.pop(fqn)
[rank0]: KeyError: 'buf'
set_model_state_dict calls into _load_model_state_dict, which iterates over named_buffers(). For a compiled module, fqns and fqns_with_prefix always mismatch, so _load_model_state_dict tries to move each entry from the FQN without the prefix to the FQN with the prefix. However, this does not account for non-persistent buffers: named_buffers() yields them, but they do not exist in the state dict, so state_dict.pop(fqn) raises the KeyError above.
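The mismatch between named_buffers() and the state dict can be checked on its own, without torch.compile or a process group. The following is a minimal CPU-only sketch; ReducedModel is a hypothetical cut-down version of the repro's Model, kept only to show which buffer names are missing from state_dict().

```python
import torch
import torch.nn as nn


class ReducedModel(nn.Module):
    """Hypothetical cut-down version of the repro's Model."""

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4, 4)
        # persistent=False keeps the buffer out of state_dict()
        self.register_buffer("buf", torch.randn((4,)), persistent=False)


model = ReducedModel()

# named_buffers() yields every registered buffer, persistent or not
buffer_names = [name for name, _ in model.named_buffers()]

# state_dict() only contains parameters and persistent buffers
state_dict_keys = list(model.state_dict().keys())

# Buffers that a loop over named_buffers() will see but the state dict lacks
missing = [name for name in buffer_names if name not in state_dict_keys]
print(missing)
```

Any key in missing is one that a loop over named_buffers() will visit but cannot pop from the state dict.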
One solution could be just to continue if fqn not in state_dict.
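As a rough illustration of that suggestion, the sketch below mimics the key-renaming step in plain Python. The helper remap_to_prefixed_keys and the example key pairs are hypothetical and are not the actual code in state_dict.py; the _orig_mod prefix is the attribute name torch.compile uses for the wrapped module.

```python
from typing import Any, Dict, Iterable, Tuple


def remap_to_prefixed_keys(
    state_dict: Dict[str, Any],
    fqn_pairs: Iterable[Tuple[str, str]],
) -> Dict[str, Any]:
    """Rename keys in place from the unprefixed FQN to the prefixed FQN.

    Hypothetical helper: it only illustrates the proposed guard.
    """
    for fqn, fqn_with_prefix in fqn_pairs:
        if fqn == fqn_with_prefix:
            # Nothing to rename for uncompiled modules
            continue
        if fqn not in state_dict:
            # Proposed guard: non-persistent buffers are never in the
            # state dict, so skip them instead of raising KeyError
            continue
        state_dict[fqn_with_prefix] = state_dict.pop(fqn)
    return state_dict


# Example pairs as they might look for a compiled module (assumed values)
pairs = [
    ("lin1.weight", "_orig_mod.lin1.weight"),
    ("buf", "_orig_mod.buf"),  # non-persistent buffer, absent below
    ("weight", "_orig_mod.weight"),
]
sd = {"lin1.weight": 1, "weight": 2}
remapped = remap_to_prefixed_keys(sd, pairs)
print(sorted(remapped))
```

Without the fqn not in state_dict check, the "buf" pair would raise the same KeyError as in the traceback.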
cc @LucasLLC