pytorch / pytorch


FSDP loading with a partial state triggers KeyError #105379

Open carmocca opened 1 year ago

carmocca commented 1 year ago

🐛 Describe the bug

In fine-tuning cases, you might want to save a subset of your model to reduce the size of your checkpoints. This is particularly important when techniques such as LoRA are used with very large models.

The suggested way to do this is to filter the keys of the model's state_dict before saving it, for example as sketched below.
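
For instance, with a LoRA-style fine-tune one might keep only the adapter weights. A minimal sketch, assuming the adapter parameters can be identified by a "lora_" substring in their names (that naming is purely illustrative, not something defined here):

# Hypothetical filtering step: keep only LoRA adapter weights before saving.
# "lora_" is an assumed naming convention used for illustration only.
full_sd = model.state_dict()
trimmed_sd = {k: v for k, v in full_sd.items() if "lora_" in k}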

However, this seems to break FSDP loading:

import os

import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint import FileSystemReader, load_state_dict, FileSystemWriter, save_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(100, 50, bias=False)
        self.l2 = nn.Linear(50, 1, bias=False)

def work(rank):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "1234"
    dist.init_process_group("nccl", world_size=1, rank=rank)

    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    model = MyModel().to(device)
    model = FSDP(model)

    path = "tmp/pytorch_debug_sharded"

    with FSDP.state_dict_type(module=model, state_dict_type=StateDictType.SHARDED_STATE_DICT):
        sd = model.state_dict()

    print(list(sd))
    # Trim off some layers
    del sd["l2.weight"]

    # Save the trimmed sharded state dict with the distributed checkpoint writer
    writer = FileSystemWriter(path=path, single_file_per_rank=True)
    save_state_dict(sd, writer)

    # Load the checkpoint back into a full (untrimmed) state dict -- this raises the KeyError
    reader = FileSystemReader(path=path)
    with FSDP.state_dict_type(module=model, state_dict_type=StateDictType.SHARDED_STATE_DICT):
        holder_state = model.state_dict()
        load_state_dict(holder_state, reader)
        model.load_state_dict(holder_state)

    print("good!")

def run():
    mp.spawn(work, nprocs=1)

if __name__ == "__main__":
    run()

Running this produces:

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/carmocca/git/lightning/kk.py", line 44, in work
    load_state_dict(holder_state, reader)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 111, in load_state_dict
    central_plan = distW.reduce_scatter("plan", local_step, global_step)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 200, in reduce_scatter
    raise result
torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
Traceback (most recent call last): (RANK 0)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 173, in reduce_scatter
    local_data = map_fun()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 101, in local_step
    local_plan = planner.create_local_plan()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 199, in create_local_plan
    return create_default_local_load_plan(self.state_dict, self.metadata)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 255, in create_default_local_load_plan
    md = metadata.state_dict_metadata[fqn]
KeyError: 'l2.weight'

Traceback (most recent call last):
  File "/home/carmocca/git/lightning/kk.py", line 55, in <module>
    run()
  File "/home/carmocca/git/lightning/kk.py", line 51, in run
    mp.spawn(work, nprocs=1)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 149, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with exit code 1
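
The KeyError comes from create_default_local_load_plan, which looks up every key of the in-memory state_dict in the checkpoint's metadata, so any key that was trimmed before saving is missing there. A possible workaround (an untested sketch built on the repro above, using FileSystemReader.read_metadata() to discover which keys the checkpoint actually contains) is to drop the missing keys before loading and restore with strict=False:

reader = FileSystemReader(path=path)
metadata = reader.read_metadata()  # describes the keys actually stored in the checkpoint

with FSDP.state_dict_type(module=model, state_dict_type=StateDictType.SHARDED_STATE_DICT):
    holder_state = model.state_dict()
    # Keep only entries present in the checkpoint so the default load planner
    # never looks up the trimmed keys.
    holder_state = {k: v for k, v in holder_state.items() if k in metadata.state_dict_metadata}
    load_state_dict(holder_state, reader)
    # strict=False leaves the trimmed parameters (here l2.weight) at their current values.
    model.load_state_dict(holder_state, strict=False)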

A related feature request of mine is https://github.com/pytorch/pytorch/issues/103136, where I asked whether FSDP could be made to work when the model does not include all layers in the state_dict.

Versions

torch                        2.1.0.dev20230616+cu118

cc @zhaojuanmao @mrshenli @rohan-varma @awgu

awgu commented 1 year ago

cc @fegin since we have talked about this partial load question