[DCP] `set_model_state_dict` errors on compiled module with non-persistent buffer · Issue #122792 · pytorch/pytorch · GitHub
@awgu

Description

"""
torchrun --standalone --nproc_per_node=2 repro_dcp_compile.py
"""
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4, 4)
        self.lin2 = nn.Linear(4, 4)
        self.register_buffer("buf", torch.randn((4,)), persistent=False)
        self.weight = nn.Parameter(torch.randn((4, 4)))


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)

    model = Model()
    model = torch.compile(model)

    sharded_sd = get_model_state_dict(model)
    set_model_state_dict(model, sharded_sd)

[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/andgu/pytorch/repro_dcp_compile.py", line 36, in <module>
[rank0]:     set_model_state_dict(model, sharded_sd)
[rank0]:   File "/data/users/andgu/pytorch/torch/distributed/checkpoint/state_dict.py", line 853, in set_model_state_dict
[rank0]:     return _load_model_state_dict(model, model_state_dict, info)
[rank0]:   File "/data/users/andgu/pytorch/torch/distributed/checkpoint/state_dict.py", line 416, in _load_model_state_dict
[rank0]:     state_dict[fqn_with_prefix] = state_dict.pop(fqn)
[rank0]: KeyError: 'buf'

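`set_model_state_dict` calls into `_load_model_state_dict`, which iterates over `named_buffers()`. For a compiled module, each FQN and its prefixed counterpart always differ, because `torch.compile` wraps the module and prepends `_orig_mod.` to every name. As a result, `_load_model_state_dict` tries to move each entry from the un-prefixed key to the prefixed one. However, this does not account for non-persistent buffers: `named_buffers()` yields them, but they are never saved to the state dict, so `state_dict.pop(fqn)` raises `KeyError: 'buf'`.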
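The mismatch between `named_buffers()` and the state dict can be seen without DCP or `torch.compile`. The following is a minimal sketch, assuming a CPU-only PyTorch install; `TinyModel` is a hypothetical module that is smaller than the repro's `Model` and is not part of the original report.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the repro's Model."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        # Non-persistent buffers are tracked by the module but excluded
        # from state_dict().
        self.register_buffer("buf", torch.randn(4), persistent=False)


model = TinyModel()

# named_buffers() reports every registered buffer, persistent or not.
buffer_names = [name for name, _ in model.named_buffers()]

# state_dict() only contains parameters and persistent buffers.
state_dict_keys = list(model.state_dict().keys())

print(buffer_names)     # ['buf']
print(state_dict_keys)  # ['lin.weight', 'lin.bias']
```

Any loop that walks `named_buffers()` and assumes each name is a key in the state dict will therefore fail on `buf`.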

One solution could be to simply `continue` if `fqn not in state_dict`.
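The proposed guard can be sketched in plain Python. This is an illustration only, not the actual code in `torch/distributed/checkpoint/state_dict.py`: `remap_to_prefixed` and `COMPILER_PREFIX` are hypothetical names, and stripping the prefix with a string replace is a simplification of how DCP derives FQNs.

```python
from typing import Dict, Iterable

# Assumption for this sketch: the prefix added to names by the compiled wrapper.
COMPILER_PREFIX = "_orig_mod."


def remap_to_prefixed(
    state_dict: Dict[str, object], prefixed_names: Iterable[str]
) -> Dict[str, object]:
    """Move un-prefixed keys to their prefixed names, skipping absent keys."""
    for fqn_with_prefix in prefixed_names:
        fqn = fqn_with_prefix.replace(COMPILER_PREFIX, "")
        if fqn == fqn_with_prefix:
            # Nothing to rename for this entry.
            continue
        if fqn not in state_dict:
            # For example, a non-persistent buffer: it is listed by
            # named_buffers() but was never saved, so skip it instead
            # of raising KeyError on pop().
            continue
        state_dict[fqn_with_prefix] = state_dict.pop(fqn)
    return state_dict


# Names as a compiled module would report them, including the buffer.
names = ["_orig_mod.lin1.weight", "_orig_mod.buf"]
# The saved state dict has no entry for the non-persistent buffer.
sd = {"lin1.weight": 1.0}

remapped = remap_to_prefixed(sd, names)
print(remapped)
```

Without the `fqn not in state_dict` check, the second name would trigger the same `KeyError: 'buf'` shown in the traceback.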

cc @LucasLLC
