huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Unable to merge checkpoint using `accelerate merge-weights` #2848

Open helloworld1 opened 3 weeks ago

helloworld1 commented 3 weeks ago

System Info

- `Accelerate` version: 0.31.0
- Platform: Linux-5.15.138.1-4.cm2-x86_64-with-glibc2.35
- `accelerate` bash location: /home/jobuser/.local/bin/accelerate
- Python version: 3.10.2
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.3.0.0+git97ff6cf (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False

Reproduction

When I try to merge the checkpoint, I get the following error:

```
$ accelerate merge-weights checkpoint-2000 checkpoint-2000-merged
Traceback (most recent call last):
  File "/home/jobuser/.local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/jobuser/.local/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/home/jobuser/.local/lib/python3.10/site-packages/accelerate/commands/merge.py", line 27, in merge_command
    merge_fsdp_weights(
  File "/home/jobuser/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 270, in merge_fsdp_weights
    save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)
  File "/home/jobuser/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 226, in _distributed_checkpoint_to_merged_weights
    dist_cp_format_utils._load_state_dict(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 208, in _load_state_dict
    central_plan = distW.reduce_scatter("plan", local_step, global_step)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 190, in reduce_scatter
    raise result
torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
Traceback (most recent call last): (RANK 0)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 163, in reduce_scatter
    local_data = map_fun()
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 194, in local_step
    metadata = storage_reader.read_metadata()
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 603, in read_metadata
    with self.fs.create_stream(path, "rb") as metadata_file:
  File "/export/apps/python/3.10/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 359, in create_stream
    with cast(Path, path).open(mode) as stream:
  File "/export/apps/python/3.10/lib/python3.10/pathlib.py", line 1117, in open
    return self._accessor.open(self, mode, buffering, encoding, errors,
FileNotFoundError: [Errno 2] No such file or directory: 'checkpoint-2000/.metadata'
```

The checkpoint itself can be restored with `trainer.train(resume_from_checkpoint="./checkpoint-2000")`.

The contents of the checkpoint are as follows:

```
$ find checkpoint-2000/
checkpoint-2000/
checkpoint-2000/trainer_state.json
checkpoint-2000/rng_state_1.pth
checkpoint-2000/scheduler.pt
checkpoint-2000/pytorch_model_fsdp_0
checkpoint-2000/pytorch_model_fsdp_0/__1_0.distcp
checkpoint-2000/pytorch_model_fsdp_0/.metadata
checkpoint-2000/pytorch_model_fsdp_0/__0_0.distcp
checkpoint-2000/optimizer_0
checkpoint-2000/optimizer_0/__1_0.distcp
checkpoint-2000/optimizer_0/.metadata
checkpoint-2000/optimizer_0/__0_0.distcp
checkpoint-2000/rng_state_0.pth
```

CC: @muellerzr

Expected behavior

`accelerate merge-weights` should be able to merge the sharded weights into a single FULL_STATE_DICT model.

muellerzr commented 3 weeks ago

@helloworld1 I believe you need to pass `checkpoint-2000/pytorch_model_fsdp_0` explicitly, not just `checkpoint-2000`, since we can't tell whether you mean the model or the optimizer (and we support both).
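In short, the path given to `merge-weights` must itself contain a `.metadata` file. A minimal sketch of locating such directories, assuming the Trainer naming visible in the listing above (`pytorch_model_fsdp_*` for the model, `optimizer_*` for the optimizer); `find_dcp_dirs` is a hypothetical helper, not part of accelerate:

```python
from pathlib import Path


def find_dcp_dirs(checkpoint_root):
    """Hypothetical helper: group the immediate subdirectories of a Trainer
    checkpoint that hold a torch.distributed.checkpoint (DCP) shard set,
    i.e. that contain a `.metadata` file."""
    groups = {"model": [], "optimizer": [], "other": []}
    for child in sorted(Path(checkpoint_root).iterdir()):
        # Skip plain files (scheduler.pt, rng_state_*.pth, ...) and
        # directories without DCP metadata.
        if not (child.is_dir() and (child / ".metadata").is_file()):
            continue
        # Name prefixes are assumed from the listing above; other layouts may differ.
        if child.name.startswith("pytorch_model_fsdp_"):
            groups["model"].append(str(child))
        elif child.name.startswith("optimizer_"):
            groups["optimizer"].append(str(child))
        else:
            groups["other"].append(str(child))
    return groups
```

For the `checkpoint-2000` layout in this issue, the `"model"` entry would be `checkpoint-2000/pytorch_model_fsdp_0`, which is the path to hand to `accelerate merge-weights`.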

muellerzr commented 3 weeks ago

(We can probably make that clearer by improving the error message, cc @SunMarc)
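One way the error could be made more explicit, sketched as a pre-check that runs before any loading. `validate_merge_input` is hypothetical and only illustrates the idea; it is not accelerate's actual check:

```python
from pathlib import Path


def validate_merge_input(checkpoint_dir):
    """Hypothetical pre-check: fail early with an actionable message when
    `checkpoint_dir` is not itself a DCP directory (no `.metadata` file)."""
    path = Path(checkpoint_dir)
    if (path / ".metadata").is_file():
        return path  # already points at a model or optimizer shard set
    # Look one level down for DCP subdirectories the user may have meant.
    candidates = []
    if path.is_dir():
        candidates = sorted(
            str(child)
            for child in path.iterdir()
            if child.is_dir() and (child / ".metadata").is_file()
        )
    hint = f" Did you mean one of: {', '.join(candidates)}?" if candidates else ""
    raise ValueError(
        f"{checkpoint_dir} does not contain a `.metadata` file, so it is not "
        f"a torch.distributed.checkpoint directory.{hint}"
    )
```

With the layout from this issue, calling it on `checkpoint-2000` would raise a `ValueError` naming both `pytorch_model_fsdp_0` and `optimizer_0` as candidates, instead of the low-level `FileNotFoundError` from the traceback.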

helloworld1 commented 3 weeks ago

I can now successfully merge the weights. Thanks!

helloworld1 commented 3 weeks ago

I noticed that the merged weights are saved in pickle format, not safetensors, so I created this PR to default to safetensors: https://github.com/huggingface/accelerate/pull/2853
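To check which format a merged file actually ended up in, one can sniff its leading bytes. A best-effort sketch, assuming the safetensors layout (an 8-byte little-endian header length followed by a JSON header) and that `torch.save` writes a zip archive by default in recent PyTorch releases; `detect_weights_format` is a hypothetical helper, not part of accelerate:

```python
import json
import struct
from pathlib import Path


def detect_weights_format(path):
    """Hypothetical, best-effort sniff of a weights file's serialization format.
    Returns "safetensors", "torch-zip", "pickle", or "unknown"."""
    size = Path(path).stat().st_size
    with open(path, "rb") as f:
        head = f.read(8)
        # safetensors: 8-byte little-endian header length N, then N bytes of JSON.
        # Checked first because it is validated structurally, not by one magic byte.
        if len(head) == 8:
            (header_len,) = struct.unpack("<Q", head)
            if 0 < header_len <= size - 8:  # avoid reading past EOF on junk input
                header = f.read(header_len)
                if header[:1] == b"{":
                    try:
                        json.loads(header)
                        return "safetensors"
                    except ValueError:
                        pass  # not valid JSON; fall through to the other checks
    if head.startswith(b"PK\x03\x04"):
        # Zip local-file-header magic: torch.save()'s default zip container.
        return "torch-zip"
    if head[:1] == b"\x80":
        # Raw pickle, protocol >= 2, starts with the PROTO opcode (0x80).
        return "pickle"
    return "unknown"
```

Both `"torch-zip"` and `"pickle"` mean the weights go through Python's pickle machinery on load, which is the behavior the PR above aims to move away from by default.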