Cannot run FSDP2 with low bit optim from AO
Open · nighting0le01 opened 1 day ago
@gau-nernst can you please take a look? aten.is_pinned is not implemented for the low-bit optimizer state subclass, which causes issues when saving optimizer states.
Do you have a small reproduction? Yeah, I don't think we currently test saving/loading optimizer state under FSDP2. Will add tests for it.
Unfortunately I cannot share the original code, but being able to save optimizer states properly with FSDP2 (and other parallelisms implemented with DTensor) is crucial for making use of these low-bit optimizers. @gau-nernst
It looks like you are using PyTorch Lightning, which calls torch.distributed.checkpoint.state_dict_saver.save(). This function requires the tensor subclass to implement aten.is_pinned.
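For reference, handling that op on a __torch_dispatch__ wrapper subclass looks roughly like the sketch below. QuantizedOptimState and its codes/scale fields are placeholders, not the actual torchao subclass; it only illustrates delegating is_pinned (and detach) to the inner plain tensors.

import torch

aten = torch.ops.aten

# Hypothetical wrapper subclass standing in for torchao's quantized optim state;
# the real class, field names, and dispatch table differ.
class QuantizedOptimState(torch.Tensor):
    @staticmethod
    def __new__(cls, codes, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, codes.shape, dtype=torch.float32, device=codes.device
        )

    def __init__(self, codes, scale):
        self.codes = codes    # uint8 payload
        self.scale = scale    # per-block scales

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is aten.detach.default:
            x = args[0]
            return cls(x.codes.detach(), x.scale.detach())
        if func is aten.is_pinned.default:
            # The DCP save path queries pinned-ness when planning host copies;
            # delegating to the inner plain tensor is enough here.
            return args[0].codes.is_pinned()
        raise NotImplementedError(f"{cls.__name__} does not support {func}")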
I tested that the normal torch.save(optim.state_dict(), "state_dict.ckpt") works fine. @nighting0le01 In the meantime, is it possible for you to switch to the plain torch.save() for checkpointing?
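Something along these lines should work (a rough sketch; the per-rank file names, the torchao import path, and the requirement to resume with the same world size/sharding are my assumptions):

import torch
import torch.distributed as dist
from torchao.prototype.low_bit_optim import AdamW8bit

fsdp_model = ...  # FSDP2-wrapped model
fsdp_optim = AdamW8bit(fsdp_model.parameters())
rank = dist.get_rank()

# save: each rank writes its own state dict (the DTensor shards it owns)
torch.save(fsdp_optim.state_dict(), f"optim_rank{rank}.ckpt")

# load: rebuild the model/optimizer with the same world size and sharding,
# then restore the matching shard on each rank
new_fsdp_optim = AdamW8bit(fsdp_model.parameters())
new_fsdp_optim.load_state_dict(torch.load(f"optim_rank{rank}.ckpt", weights_only=False))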
@awgu What are the benefits of using torch.distributed.checkpoint.state_dict_saver.save() over the plain torch.save()? From my understanding of https://pytorch.org/docs/stable/distributed.checkpoint.html, it seems like the former handles some kind of resharding when loading? Is the saving side the same?
Implementing the aten.is_pinned op is simple, but since I'm not too familiar with torch.distributed.checkpoint, what is the recommended/correct way to save and load a (sharded) optim state dict with it, so that I can add the correct tests? Is it something like this:
from torch.distributed.checkpoint import state_dict_saver, state_dict_loader
from torchao.prototype.low_bit_optim import AdamW8bit

fsdp_model = ...
fsdp_optim = AdamW8bit(fsdp_model.parameters())
# do some training, so optimizer states are initialized

# DCP save/load are collective calls, so every rank passes the same checkpoint directory
state_dict_saver.save(fsdp_optim.state_dict(), checkpoint_id="optim_checkpoint")

# new sharded optim. optimizer states are not initialized
new_fsdp_optim = AdamW8bit(fsdp_model.parameters())
state_dict = new_fsdp_optim.state_dict()

# this requires aten.detach, and it doesn't seem to load optim state when the new optim state is empty (i.e. not initialized)
state_dict_loader.load(state_dict, checkpoint_id="optim_checkpoint")
new_fsdp_optim.load_state_dict(state_dict)
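Or should it go through the helpers in torch.distributed.checkpoint.state_dict? My rough understanding is something like the following (a sketch only, reusing fsdp_model/fsdp_optim/AdamW8bit from above; whether these helpers handle the low-bit subclasses correctly is exactly what the tests would need to check):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

# save: collective call, same checkpoint directory on every rank
optim_sd = get_optimizer_state_dict(fsdp_model, fsdp_optim)
dcp.save({"optim": optim_sd}, checkpoint_id="optim_checkpoint")

# load: create a fresh optimizer, fetch a template state dict for it,
# load the checkpoint into that template, then push it back into the optimizer
new_fsdp_optim = AdamW8bit(fsdp_model.parameters())
optim_sd = get_optimizer_state_dict(fsdp_model, new_fsdp_optim)
dcp.load({"optim": optim_sd}, checkpoint_id="optim_checkpoint")
set_optimizer_state_dict(fsdp_model, new_fsdp_optim, optim_state_dict=optim_sd)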