🐛 Describe the bug

In fine-tuning cases, you might want to save a subset of your model to reduce the size of your checkpoints. This is particularly important when techniques such as LoRA are used with very large models.

The suggested way to do this is to filter the keys of the model's `state_dict`.
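For illustration only (the toy module and the `lora_` naming convention below are assumptions, not part of the repro), the filtering amounts to something like:

```python
import torch
import torch.nn as nn


class ToyLoRAModel(nn.Module):
    # Toy stand-in for a fine-tuned network; the "lora_" naming is a hypothetical
    # convention used here only to show what "filter the keys" means.
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(100, 50, bias=False)   # frozen base weight, not worth saving
        self.lora_a = nn.Linear(100, 4, bias=False)  # adapter weights we do want to save
        self.lora_b = nn.Linear(4, 50, bias=False)


model = ToyLoRAModel()
# Keep only the adapter entries to shrink the checkpoint
trimmed = {k: v for k, v in model.state_dict().items() if "lora_" in k}
torch.save(trimmed, "adapter_only.pt")
```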
However, this seems to break FSDP loading:
```python
import os

import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint import FileSystemReader, load_state_dict, FileSystemWriter, save_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(100, 50, bias=False)
        self.l2 = nn.Linear(50, 1, bias=False)


def work(rank):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "1234"
    dist.init_process_group("nccl", world_size=1, rank=rank)
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    model = MyModel().to(device)
    model = FSDP(model)

    path = "tmp/pytorch_debug_sharded"

    # Save a sharded state_dict with one key filtered out
    with FSDP.state_dict_type(module=model, state_dict_type=StateDictType.SHARDED_STATE_DICT):
        sd = model.state_dict()
        print(list(sd))
        # Trim off some layers
        del sd["l2.weight"]
        writer = FileSystemWriter(path=path, single_file_per_rank=True)
        save_state_dict(sd, writer)

    # Load it back through the full (unfiltered) holder state_dict
    reader = FileSystemReader(path=path)
    with FSDP.state_dict_type(module=model, state_dict_type=StateDictType.SHARDED_STATE_DICT):
        holder_state = model.state_dict()
        load_state_dict(holder_state, reader)
        model.load_state_dict(holder_state)
    print("good!")


def run():
    mp.spawn(work, nprocs=1)


if __name__ == "__main__":
    run()
```
```
Process SpawnProcess-1:
Traceback (most recent call last):
File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/carmocca/git/lightning/kk.py", line 44, in work
load_state_dict(holder_state, reader)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 111, in load_state_dict
central_plan = distW.reduce_scatter("plan", local_step, global_step)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 200, in reduce_scatter
raise result
torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
Traceback (most recent call last): (RANK 0)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 173, in reduce_scatter
local_data = map_fun()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 101, in local_step
local_plan = planner.create_local_plan()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 199, in create_local_plan
return create_default_local_load_plan(self.state_dict, self.metadata)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 255, in create_default_local_load_plan
md = metadata.state_dict_metadata[fqn]
KeyError: 'l2.weight'
Traceback (most recent call last):
File "/home/carmocca/git/lightning/kk.py", line 55, in <module>
run()
File "/home/carmocca/git/lightning/kk.py", line 51, in run
mp.spawn(work, nprocs=1)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
while not context.join():
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 149, in join
raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with exit code 1
```
A related feature request of mine is https://github.com/pytorch/pytorch/issues/103136, where I asked if FSDP could be made to work if the model didn't include all layers in the `state_dict`.
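For context, the manual workaround this implies today is sketched below against the repro above (reusing its `model`, `reader`, and imports); whether `strict=False` interacts cleanly with FSDP's sharded state-dict hooks is an assumption I haven't verified:

```python
# Sketch only: ask the load planner for nothing more than what was saved,
# then load the partial result non-strictly into the FSDP module.
metadata = reader.read_metadata()  # metadata written by save_state_dict
saved_keys = set(metadata.state_dict_metadata)

with FSDP.state_dict_type(module=model, state_dict_type=StateDictType.SHARDED_STATE_DICT):
    holder_state = model.state_dict()
    # Drop the entries that were filtered out at save time so that
    # create_default_local_load_plan doesn't hit the KeyError above
    holder_state = {k: v for k, v in holder_state.items() if k in saved_keys}
    load_state_dict(holder_state, reader)
    # Non-strict so the keys that were never saved keep their current values
    model.load_state_dict(holder_state, strict=False)
```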
Versions
cc @zhaojuanmao @mrshenli @rohan-varma @awgu