huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer/accelerate doesn't save model when using FSDP with SHARDED_STATE_DICT #30491

Closed: alexghergh closed this issue 1 month ago

alexghergh commented 4 months ago

Who can help?

@pacman100

Reproduction

Hey folks,

The issue seems to be simple enough. I tried to train an FSDP model using a multi-node setup, with transformers + accelerate. I'm launching my training script (which doesn't seem relevant to the issue for now, so I'll skip it for brevity) using accelerate launch --config-file config.yaml ... python train.py. The config.yaml looks like this:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The issue, though, arises when using the built-in Trainer class from transformers and trying to save the model at the end of training. I run:

trainer = Trainer(...)

trainer.train(...)

trainer.save_model(save_dir)

This, however, does not save any files in save_dir (in fact, it does not even create the directory).

The problem seems to be in Trainer.save_model(), which checks whether FSDP is enabled and then only saves when the state dict type is FULL_STATE_DICT. This leads to the obvious question: what happens to SHARDED_STATE_DICT models? I couldn't find anything about this in the docs.
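To make the branching concrete, here is a simplified model of the check described above. It is a sketch, not the actual transformers source: the function name `save_model_branch` and its return labels are hypothetical, and state dict types are passed as plain strings.

```python
def save_model_branch(is_fsdp_enabled: bool, state_dict_type: str) -> str:
    """Simplified model of the FSDP branch in Trainer.save_model().

    Hypothetical helper: returns a label for the action taken instead of
    actually saving anything.
    """
    if not is_fsdp_enabled:
        return "save"  # non-FSDP path, unaffected by this issue
    if state_dict_type == "FULL_STATE_DICT":
        # Full weights are gathered and written to save_dir.
        return "gather_and_save"
    # SHARDED_STATE_DICT and LOCAL_STATE_DICT have no branch of their own,
    # so in this model nothing is written and save_dir is never created.
    return "no_op"


for sd_type in ("FULL_STATE_DICT", "SHARDED_STATE_DICT", "LOCAL_STATE_DICT"):
    print(sd_type, "->", save_model_branch(True, sd_type))
```

Under this model, a SHARDED_STATE_DICT run reaches neither the gather step nor the write step, which would also be consistent with save_dir never being created.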

Am I missing something? Are you supposed to change the state dict to a FULL_STATE_DICT before saving?

Note that checkpoints do seem to save normally (they are sharded, so every node appears to save its own part, which is expected). It is only the final save_model call that seems to falter.

Thanks a lot!

Expected behavior

The model should be saved in safetensors format, with all of its weights, inside save_dir, regardless of whether accelerate uses FULL_STATE_DICT or SHARDED_STATE_DICT with FSDP.

amyeroberts commented 3 months ago

cc @muellerzr @SunMarc

muellerzr commented 3 months ago

Are you supposed to change the state dict to a FULL_STATE_DICT before saving?

No, we support SHARDED_STATE_DICT, but we may need to upstream this to the saving side; I only recently brought that into accelerate. Let me get to this today.

carolius commented 2 months ago

Any updates on this?

katzurik commented 2 months ago

Note that checkpoints do indeed seem to save normally

Do you see the checkpoints saved by the trainer? At least during training, I do not see any checkpoint saved by the trainer.

alexghergh commented 2 months ago

@katzurik Hmm, now that you mention it, I'm double-checking in my mind whether that was really the case. I'm pretty sure the answer is yes (I would have said otherwise in my initial post if that weren't the case), but it was quite some time ago and I can't remember exactly.

In the meantime, my team and I switched to manually setting the state dict type to FULL at the end of training (trainer.accelerator.state.fsdp_plugin.set_state_dict_type('FULL_STATE_DICT')), which works exactly as expected.
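The workaround described in this comment can be wrapped in a small helper. This is a minimal sketch: `save_full_model` is a hypothetical name, and it assumes the Trainer exposes `is_fsdp_enabled` and `accelerator` attributes (present in recent transformers versions) alongside the `set_state_dict_type` call quoted in the comment.

```python
def save_full_model(trainer, save_dir: str) -> None:
    """Save a Trainer's model as full weights, even if FSDP was configured
    with SHARDED_STATE_DICT (sketch of the workaround from this thread)."""
    if getattr(trainer, "is_fsdp_enabled", False):
        # Switch the FSDP plugin to a full state dict so save_model gathers
        # and writes the complete weights instead of silently skipping.
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model(save_dir)
```

Call `save_full_model(trainer, save_dir)` in place of `trainer.save_model(save_dir)` after `trainer.train(...)`. Run it on every process rather than only on rank 0, because gathering a full FSDP state dict is a collective operation.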

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

farzadab commented 3 days ago

I'd like to bump this issue up since it has not been updated.

The code explicitly ignores the SHARDED_STATE_DICT and LOCAL_STATE_DICT cases, as you can see here: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3533-3539

TBH I'm not sure what the correct fix is, since I'm new to the FSDP world, but this seems to be what's needed: https://huggingface.co/docs/accelerate/en/usage_guides/fsdp#saving-and-loading
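For the SHARDED_STATE_DICT case, one option is to consolidate the shards the Trainer writes into a checkpoint folder once training has finished. This is a minimal sketch, assuming accelerate >= 0.30 (which provides `accelerate.utils.merge_fsdp_weights`) and assuming the model shards live in a `pytorch_model_fsdp_0` subfolder of the checkpoint, following accelerate's naming for the first model. The helper names `sharded_model_dir` and `merge_to_safetensors` are hypothetical.

```python
import os


def sharded_model_dir(checkpoint_dir: str, model_index: int = 0) -> str:
    # Assumption: accelerate stores SHARDED_STATE_DICT model shards in
    # "<checkpoint>/pytorch_model_fsdp_<model_index>".
    return os.path.join(checkpoint_dir, f"pytorch_model_fsdp_{model_index}")


def merge_to_safetensors(checkpoint_dir: str, output_dir: str) -> None:
    # Imported lazily so this module loads even without accelerate installed.
    # Assumption: accelerate >= 0.30, which provides merge_fsdp_weights.
    from accelerate.utils import merge_fsdp_weights

    # Intended to run once, after training has finished.
    merge_fsdp_weights(
        sharded_model_dir(checkpoint_dir),
        output_dir,
        safe_serialization=True,
    )
```

For example, `merge_to_safetensors("output/checkpoint-500", "merged_model")` (hypothetical paths). This consolidates whatever the chosen checkpoint contains, so the result only matches the final weights if a checkpoint was written at the last training step.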