huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Loading a model trained on different nodes with FSDP and SHARDED_STATE_DICT fails (.metadata not found on all nodes) #3010

Open mroberto166 opened 1 month ago

mroberto166 commented 1 month ago

System Info

- `Accelerate` version: 0.32.1
- Platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/roberto/.cache/pypoetry/virtualenvs/model-prod-MCJRev74-py3.11/bin/accelerate
- Python version: 3.11.9
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.2+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 944.84 GB
- GPU type: NVIDIA H100 80GB HBM3
- Accelerate Config:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
mixed_precision: bf16
downcast_bf16: 'no'
num_processes: 16
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: HYBRID_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
  fsdp_transformer_layer_cls_to_wrap: RotaryDiTBlockV2
main_training_function: main
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
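The topology this config implies can be sketched in plain Python. The PyTorch FSDP documentation describes HYBRID_SHARD as full sharding within a node, with the shards replicated across nodes. The sketch makes three assumptions. It assumes 16 processes spread over 2 nodes with 8 GPUs each; the per-node count is inferred from the node-0/node-1 checkpoint folders shown further down. It assumes ranks are numbered node by node. It assumes one shard per rank. The function names are hypothetical, and nothing here calls PyTorch.

```python
from collections import defaultdict


def hybrid_shard_layout(num_processes: int, gpus_per_node: int) -> dict[int, tuple[int, int]]:
    """Map each global rank to (node index, shard index) under a HYBRID_SHARD-style layout.

    Assumes ranks are numbered node by node, so ranks 0..gpus_per_node-1 live on node 0.
    """
    if num_processes % gpus_per_node != 0:
        raise ValueError("num_processes must be a multiple of gpus_per_node")
    # divmod gives (node, local rank); the local rank doubles as the shard index
    # because each node holds one full set of shards.
    return {rank: divmod(rank, gpus_per_node) for rank in range(num_processes)}


def replica_groups(layout: dict[int, tuple[int, int]]) -> dict[int, list[int]]:
    """Group ranks that hold the same shard (one rank per node)."""
    groups: dict[int, list[int]] = defaultdict(list)
    for rank, (_node, shard) in sorted(layout.items()):
        groups[shard].append(rank)
    return dict(groups)


layout = hybrid_shard_layout(num_processes=16, gpus_per_node=8)
print(layout[8])                  # (1, 0): rank 8 is on node 1 and holds shard 0
print(replica_groups(layout)[0])  # [0, 8]
```

Under these assumptions, ranks 8 to 15 each hold a copy of a shard that one of ranks 0 to 7 also holds. These are the ranks listed in the CheckpointException further down.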

Reproduction

Hi everyone, I recently started getting the following error when I try to load a model that was previously trained with FSDP and SHARDED_STATE_DICT:

0: Traceback (most recent call last):
0:   File "/model/backbone/train.py", line 781, in <module>
0:     main()
0:   File "/model/backbone/train.py", line 751, in main
0:     trainer = Trainer(
0:               ^^^^^^^^
0:   File "/model/backbone/train.py", line 291, in __init__
0:     self.checkpointer.load_checkpoint(
0:   File "/model/checkpointing/checkpointer.py", line 170, in load_checkpoint
0:     self.accelerator.load_state(input_dir=node_path)
0:   File "/usr/local/lib/python3.11/site-packages/accelerate/accelerator.py", line 3084, in load_state
0:     load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i)
0:   File "/usr/local/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 146, in load_fsdp_model
0:     dist_cp.load_state_dict(
0:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 31, in load_state_dict
0:     return _load_state_dict(state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 164, in _load_state_dict
0:     central_plan = distW.reduce_scatter("plan", local_step, global_step)
0:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 200, in reduce_scatter
0:     raise result
0: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([8, 9, 10, 11, 12, 13, 14, 15])
0: Traceback (most recent call last): (RANK 8)
0:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 173, in reduce_scatter
0:     local_data = map_fun()
0:                  ^^^^^^^^^
0:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 150, in local_step
0:     metadata = storage_reader.read_metadata()
0:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/checkpoint/filesystem.py", line 497, in read_metadata
0:     with (self.path / ".metadata").open("rb") as metadata_file:
0:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/usr/local/lib/python3.11/pathlib.py", line 1044, in open
0:     return io.open(self, mode, buffering, encoding, errors, newline)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: FileNotFoundError: [Errno 2] No such file or directory: '/mnt/fast/slurmwork/runs/2024-08-13-causal-husky-1/loading-checkpoint/node-1/pytorch_model_fsdp_0/.metadata'
0: Traceback (most recent call last): (RANK 9)

This is the content of the folder /mnt/fast/slurmwork/runs/2024-08-13-causal-husky-1/loading-checkpoint/node-1/pytorch_model_fsdp_0/ on the second node:

ls -la /mnt/fast/slurmwork/runs/2024-08-13-causal-husky-1/loading-checkpoint/node-1/pytorch_model_fsdp_0/
total 8
drwxr-xr-x 2 root root 4096 Aug 13 10:28 .
drwxr-xr-x 4 root root 4096 Aug 13 10:28 ..
-rw-r--r-- 1 root root    0 Aug 13 10:10 __10_0.distcp
-rw-r--r-- 1 root root    0 Aug 13 10:10 __11_0.distcp
-rw-r--r-- 1 root root    0 Aug 13 10:10 __12_0.distcp
-rw-r--r-- 1 root root    0 Aug 13 10:10 __13_0.distcp
-rw-r--r-- 1 root root    0 Aug 13 10:10 __14_0.distcp
-rw-r--r-- 1 root root    0 Aug 13 10:10 __15_0.distcp
-rw-r--r-- 1 root root    0 Aug 13 10:10 __8_0.distcp
-rw-r--r-- 1 root root    0 Aug 13 10:10 __9_0.distcp

and this is the content of the corresponding folder on the first node:

ls -la /mnt/fast/slurmwork/runs/2024-08-13-causal-husky-1/loading-checkpoint/node-0/pytorch_model_fsdp_0/
total 2617840
drwxr-xr-x 2 root root      4096 Aug 13 10:30 .
drwxr-xr-x 4 root root      4096 Aug 13 10:31 ..
-rw-r--r-- 1 root root   2108952 Aug 13 10:10 .metadata
-rw-r--r-- 1 root root 357688636 Aug 13 10:11 __0_0.distcp
-rw-r--r-- 1 root root 331547496 Aug 13 10:11 __1_0.distcp
-rw-r--r-- 1 root root 331547496 Aug 13 10:11 __2_0.distcp
-rw-r--r-- 1 root root 331547496 Aug 13 10:11 __3_0.distcp
-rw-r--r-- 1 root root 331547496 Aug 13 10:11 __4_0.distcp
-rw-r--r-- 1 root root 331547496 Aug 13 10:11 __5_0.distcp
-rw-r--r-- 1 root root 331547496 Aug 13 10:11 __6_0.distcp
-rw-r--r-- 1 root root 331547496 Aug 13 10:11 __7_0.distcp

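As you can see, the .metadata file is missing from the second node's folder but present in the first node's folder, and all .distcp files on the second node are 0 bytes. The model is saved with accelerator.save_state(output_dir=...) and loaded with accelerator.load_state(input_dir=...):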

import logging

# Excerpt: methods of the checkpointer class (/model/checkpointing/checkpointer.py)


def checkpoint(self, *, epoch: int, sync: bool = False) -> None:
    """
    Saves the model, optimizer and scheduler state_dict to a checkpoint file
    using the accelerator.save_state() function.
    """

    epoch_path = self.checkpoint_root_path / f"epoch-{epoch}"
    node_path = epoch_path / f"node-{self.machine_rank}"

    if self.accelerator.is_local_main_process:
        node_path.mkdir(parents=True, exist_ok=False)

    # Wait for the main processes to finish creating the checkpoint directories
    self.accelerator.wait_for_everyone()

    self.accelerator.save_state(output_dir=node_path)
    logging.info("Saved checkpoint to %s", node_path)

    # Wait for all processes to finish saving the checkpoint
    self.accelerator.wait_for_everyone()


def load_checkpoint(
    self, *, run_id: str, epoch: int, overwrite_checkpoint: bool = True
) -> None:
    """
    Loads the model, optimizer and scheduler state_dict from a checkpoint file
    using the accelerator.load_state() function.
    """
    node_path = self.checkpoint_loading_path / f"node-{self.machine_rank}"

    # Wait for the local main processes to finish downloading the checkpoint
    self.accelerator.wait_for_everyone()

    self.accelerator.load_state(input_dir=node_path)
    logging.info("Loaded checkpoint from %s", node_path)
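There is a possible explanation for the missing .metadata and the 0-byte .distcp files on node 1. It is an assumption, not a confirmed diagnosis. In PyTorch 2.2, the default distributed-checkpoint save planner deduplicates replicated tensors: it keeps the write on the first (lowest) rank that holds each one. Also, only the coordinator rank (rank 0 by default) writes .metadata. Because each node passes its own node-{machine_rank} folder to save_state() in the code above, one logical checkpoint would then be split across two folders. The toy model below reproduces the two listings. It is pure Python, the function name is hypothetical, and it simplifies by giving each rank exactly one shard.

```python
def simulate_per_node_save(
    num_processes: int, gpus_per_node: int, coordinator_rank: int = 0
) -> dict[str, dict[str, bool]]:
    """Toy model of a sharded save where every node writes into its own folder.

    Returns {folder: {file name: file has data}}. Simplifications (assumptions):
    each rank holds exactly one shard, shard index = local rank, and replicated
    shards are written only by the first (lowest) rank that holds them.
    """
    first_writer: dict[int, int] = {}  # shard index -> rank that keeps the write
    folders: dict[str, dict[str, bool]] = {}
    for rank in range(num_processes):  # rank order, like the gathered save plans
        node, shard = divmod(rank, gpus_per_node)
        writer = first_writer.setdefault(shard, rank)
        folder = folders.setdefault(f"node-{node}", {})
        # Every rank creates its own __{rank}_0.distcp file, but only the first
        # holder of a shard puts data in it; the others end up with empty files.
        folder[f"__{rank}_0.distcp"] = writer == rank
    # Only the coordinator writes .metadata, and only into the folder it was given.
    folders[f"node-{coordinator_rank // gpus_per_node}"][".metadata"] = True
    return folders


folders = simulate_per_node_save(num_processes=16, gpus_per_node=8)
print(".metadata" in folders["node-0"], ".metadata" in folders["node-1"])  # True False
print(any(folders["node-1"].values()))  # False: every .distcp file on node 1 is empty
```

If this model is right, copying .metadata into node-1's folder would not be enough on its own. The .metadata written by rank 0 points at data files that exist only in node-0's folder.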

Expected behavior

I would expect load_state() to find the correct files, or at least not to look for files that do not exist.
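One workaround to try is to pass the same folder to save_state() and load_state() on every rank, so that the distributed checkpoint stays in one place. This assumes the checkpoint root is on a filesystem mounted by both nodes; both node-* folders in the listings sit under the same /mnt/fast path. The sketch below covers only the path handling plus a fail-fast check before load_state(). shared_checkpoint_dir and check_before_load are hypothetical helpers, and the Accelerate calls are left as comments so the block runs without Accelerate installed.

```python
import tempfile
from pathlib import Path


def shared_checkpoint_dir(root: Path, epoch: int) -> Path:
    """Hypothetical helper: the same folder on every rank (no machine_rank in the path)."""
    return root / f"epoch-{epoch}"


def check_before_load(checkpoint_dir: Path) -> None:
    """Hypothetical fail-fast check to run on every rank before load_state()."""
    # The sharded model folder is named pytorch_model_fsdp_<index>, as in the listings.
    model_dirs = sorted(checkpoint_dir.glob("pytorch_model_fsdp_*"))
    if not model_dirs:
        raise FileNotFoundError(f"no pytorch_model_fsdp_* folder in {checkpoint_dir}")
    for model_dir in model_dirs:
        if not (model_dir / ".metadata").is_file():
            raise FileNotFoundError(f"{model_dir / '.metadata'} is missing")


# In the checkpointer, save and load would then look roughly like this
# (Accelerate calls shown as comments so the sketch runs without it):
#   path = shared_checkpoint_dir(self.checkpoint_root_path, epoch)
#   if self.accelerator.is_main_process:  # a single process creates the shared folder
#       path.mkdir(parents=True, exist_ok=False)
#   self.accelerator.wait_for_everyone()
#   self.accelerator.save_state(output_dir=path)
#   ...
#   check_before_load(load_path)  # load_path: the same folder on every rank
#   self.accelerator.load_state(input_dir=load_path)

with tempfile.TemporaryDirectory() as tmp:
    path = shared_checkpoint_dir(Path(tmp), epoch=3)
    (path / "pytorch_model_fsdp_0").mkdir(parents=True)
    try:
        check_before_load(path)
    except FileNotFoundError as err:
        print("caught:", err)  # .metadata has not been written yet
    (path / "pytorch_model_fsdp_0" / ".metadata").touch()
    check_before_load(path)  # passes once .metadata exists
```

The check does not repair a checkpoint that was already saved into per-node folders. It only replaces the deep CheckpointException with an earlier, clearer error. If the optimizer is also saved as a sharded folder, the same check could be extended to cover it.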

yuyu2015 commented 1 month ago

I have run into a similar issue when using load_best_model_at_end while training with FSDP in a multi-node, multi-GPU setup. The worker processes try to locate pytorch_model_fsdp, but it is only saved by the master process.

github-actions[bot] commented 6 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.