pytorch / torchsnapshot

A performant, memory-efficient checkpointing library for PyTorch applications, designed with large, complex distributed workloads in mind.
https://pytorch.org/torchsnapshot

Issue Loading FSDP wrapped module using FULL_STATE_DICT type. #141

Open hbikki opened 1 year ago

hbikki commented 1 year ago

🐛 Describe the bug

Hello, I am training a pretrained Hugging Face model, "t5-small". Using the torchsnapshot examples provided in the documentation, I am able to save/load a checkpoint with the LOCAL_STATE_DICT type, and I am also able to save the model checkpoint with FULL_STATE_DICT. However, when loading the FULL_STATE_DICT checkpoint I hit the issue below.

Versions: pytorch = 2.0.0+cu117, torchx-nightly >= 2023.3.15, torchsnapshot = 0.1.0

Host details: the training below was tested on a single node with NPROC_PER_NODE=8.

Code:

Model training code:

def train() -> None:
    init_process_group(backend="nccl")
    torch.cuda.empty_cache()
    torch.cuda.set_device(local_rank())
    model = load_model("t5-small")

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
        ),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_id=local_rank(),
    )
    # ... training loop ...
    # ... save_checkpoint ...

self.stateDictType = StateDictType.FULL_STATE_DICT

Related saving/loading code:

    def save_checkpoint(self) -> None:
        with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
            Snapshot.take(path=str(save_dir), app_state=app_state)

    def load_checkpoint(self) -> None:
        with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
            Snapshot(path=str(load_dir)).restore(app_state=app_state)
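For reference, FSDP.state_dict_type also accepts an optional state-dict config object alongside the StateDictType. A minimal sketch, assuming the stock torch.distributed.fsdp API (whether an explicit FullStateDictConfig changes the behavior reported here is an assumption, not something confirmed in this thread):

```python
# Sketch: building an explicit FULL_STATE_DICT config for FSDP.
# FullStateDictConfig's offload_to_cpu and rank0_only flags control where the
# full (unsharded) state dict is materialized during save/load.
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

# Inside the training script (model, save_dir, app_state as in the issue):
# with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
#     Snapshot.take(path=str(save_dir), app_state=app_state)
```

The context-manager call itself is left commented out because it requires an initialized process group and an FSDP-wrapped module.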

Error stack trace: https://pastebin.com/ih9qSbwR

.snapshot_metadata for the model on local rank: https://pastebin.com/t6grkKyX

Does anyone know how to resolve this? Thanks!

kiukchung commented 1 year ago

@hbikki can you edit the markdown so that the stack trace displays in a code block? Could you also include the full stack trace (if it's too long, feel free to put it on Pastebin and provide a link here)?

yifuwang commented 1 year ago

Hey @hbikki, could you please share the snapshot metadata in question? It's the .snapshot_metadata file under the snapshot folder/prefix in question.

hbikki commented 1 year ago

Hello, I updated the issue with the requested data. Thanks!