huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Cannot restore FSDP checkpoint with LOCAL_STATE_DICT #30811

Open: helloworld1 opened this issue 5 months ago

helloworld1 commented 5 months ago

Who can help?

@pacman100 @muellerzr

Reproduction

I used FSDP with fsdp_state_dict_type = LOCAL_STATE_DICT. The accelerate config is as follows:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: LOCAL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
rdzv_backend: c10d
same_network: true
num_machines: 1
num_processes: 1
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The checkpoint directory has the following structure:

./optimizer_0
./optimizer_0/.metadata
./optimizer_0/__0_0.distcp
./optimizer_0/__1_0.distcp
./optimizer_0/__2_0.distcp
./optimizer_0/__3_0.distcp
./optimizer_0/__4_0.distcp
./optimizer_0/__5_0.distcp
./optimizer_0/__6_0.distcp
./optimizer_0/__7_0.distcp
./pytorch_model_fsdp_rank0.bin
./pytorch_model_fsdp_rank1.bin
./pytorch_model_fsdp_rank2.bin
./pytorch_model_fsdp_rank3.bin
./pytorch_model_fsdp_rank4.bin
./pytorch_model_fsdp_rank5.bin
./pytorch_model_fsdp_rank6.bin
./pytorch_model_fsdp_rank7.bin
./rng_state_0.pth
./rng_state_1.pth
./rng_state_2.pth
./rng_state_3.pth
./rng_state_4.pth
./rng_state_5.pth
./rng_state_6.pth
./rng_state_7.pth
./scheduler.pt
./trainer_state.json
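Before suspecting the loader, it helps to confirm that the checkpoint itself is complete. The Python sketch below assumes that the naming scheme visible in this listing (pytorch_model_fsdp_rankN.bin, rng_state_N.pth, optimizer_0/__N_0.distcp) is the full per-rank layout. missing_ranks is a hypothetical helper for illustration, not a transformers or accelerate API.

```python
import re
from pathlib import Path

# Hypothetical helper: these file-name patterns are inferred from the checkpoint
# listing above and are NOT an official transformers/accelerate API.
PER_RANK_PATTERNS = {
    "model": re.compile(r"^pytorch_model_fsdp_rank(\d+)\.bin$"),
    "rng": re.compile(r"^rng_state_(\d+)\.pth$"),
    "optimizer": re.compile(r"^__(\d+)_0\.distcp$"),
}


def missing_ranks(checkpoint_dir, world_size):
    """Return, for each file kind, the ranks in range(world_size) with no file."""
    root = Path(checkpoint_dir)
    found = {kind: set() for kind in PER_RANK_PATTERNS}
    # Model and RNG files sit at the top level; optimizer shards sit in optimizer_0/.
    candidates = list(root.iterdir())
    opt_dir = root / "optimizer_0"
    if opt_dir.is_dir():
        candidates += list(opt_dir.iterdir())
    for path in candidates:
        for kind, pattern in PER_RANK_PATTERNS.items():
            match = pattern.match(path.name)
            if match:
                found[kind].add(int(match.group(1)))
    # Ranks expected for this world size but not seen on disk.
    return {
        kind: sorted(set(range(world_size)) - ranks)
        for kind, ranks in found.items()
    }
```

For the directory above, missing_ranks("/home/user/checkpoint-10", world_size=8) should report no missing ranks for any kind, which would indicate that the per-rank files are all present and the failure lies in how the directory is recognized rather than in missing data.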

When I try to resume training from the checkpoint with

trainer.train(resume_from_checkpoint="/home/user/checkpoint-10")

I get the following error:

training.py 146 <module>
    main()

training.py 125 main
    train_results = trainer.train(resume_from_checkpoint=checkpoint)

sft_trainer.py 360 train
    output = super().train(*args, **kwargs)

trainer.py 1859 train
    return inner_training_loop(

trainer.py 2037 _inner_training_loop
    self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

trainer.py 2431 _load_from_checkpoint
    raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

ValueError: Can't find a valid checkpoint at /home/user/checkpoint-10

If I use SHARDED_STATE_DICT instead, this error does not occur.
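The traceback shows the ValueError being raised in _load_from_checkpoint before any weights are read, so the directory itself is apparently not recognized as a checkpoint. The Python sketch below is a hypothesis, not the actual transformers source: it assumes the trainer recognizes an FSDP checkpoint only by (a) a subdirectory whose name contains pytorch_model_fsdp, as written for SHARDED_STATE_DICT, or (b) a single consolidated pytorch_model_fsdp.bin file, as written for FULL_STATE_DICT. looks_like_fsdp_checkpoint is a hypothetical name, and the real check may differ between versions.

```python
from pathlib import Path

# ASSUMPTION: a simplified approximation of how a trainer might decide whether a
# directory holds an FSDP checkpoint. Not the actual transformers implementation.
FSDP_MODEL_NAME = "pytorch_model_fsdp"  # prefix seen in the checkpoint listing


def looks_like_fsdp_checkpoint(checkpoint_dir):
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return False
    # SHARDED_STATE_DICT (assumed): model shards live in a subdirectory whose
    # name contains the prefix, e.g. "pytorch_model_fsdp_0".
    has_sharded_dir = any(
        p.is_dir() and FSDP_MODEL_NAME in p.name for p in root.iterdir()
    )
    # FULL_STATE_DICT (assumed): one consolidated "pytorch_model_fsdp.bin" file.
    has_full_file = (root / f"{FSDP_MODEL_NAME}.bin").is_file()
    # Per-rank files such as "pytorch_model_fsdp_rank0.bin" match neither check.
    return has_sharded_dir or has_full_file
```

Under these assumptions, the per-rank pytorch_model_fsdp_rankN.bin files written for LOCAL_STATE_DICT are plain files with a different name, so the check returns False. That would be consistent with the sharded layout resuming while this one fails with "Can't find a valid checkpoint".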

Expected behavior

The checkpoint should be restored successfully, as it is with SHARDED_STATE_DICT.

amyeroberts commented 3 months ago

cc @muellerzr @SunMarc

github-actions[bot] commented 3 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.