huggingface / alignment-handbook

Robust recipes to align language models with human and AI preferences
https://huggingface.co/HuggingFaceH4
Apache License 2.0

Different dtype while saving optimizer with FSDP #153

Closed · heraclex12 closed this issue 5 months ago

heraclex12 commented 5 months ago

Hi,

I'm trying to train Mixtral 8x22B using FSDP. I followed the zephyr-141b-A35b recipe to create my training config below:

model_name_or_path: mistral-community/Mixtral-8x22B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

dataset_mixer:
  HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
preprocessing_num_workers: 12

bf16: true
do_eval: false
evaluation_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 2.0e-5
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_seq_length: 2048
num_train_epochs: 3
optim: adamw_bnb_8bit
output_dir: ./models/zephyr-sft-141b-A35b
overwrite_output_dir: true
per_device_train_batch_size: 1
remove_unused_columns: true
push_to_hub: false
report_to: "none"
save_strategy: "steps"
save_steps: 5
seed: 42
warmup_steps: 0.1

However, an error occurred while FSDP was saving the optimizer state. What did I do wrong in the config?

Here is the log (tracebacks from two ranks):

[-02:1]:Traceback (most recent call last):
[-02:1]:  File "/home/user/alignment-handbook/scripts/run_sft.py", line 233, in <module>
[-02:1]:    main()
[-02:1]:  File "/home/user/alignment-handbook/scripts/run_sft.py", line 188, in main
[-02:1]:    train_result = trainer.train(resume_from_checkpoint=checkpoint)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
[-02:1]:    output = super().train(*args, **kwargs)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
[-02:1]:    return inner_training_loop(
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2193, in _inner_training_loop
[-02:1]:    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2588, in _maybe_log_save_evaluate
[-02:1]:    self._save_checkpoint(model, trial, metrics=metrics)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2660, in _save_checkpoint
[-02:1]:    self._save_optimizer_and_scheduler(output_dir)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2757, in _save_optimizer_and_scheduler
[-02:1]:    save_fsdp_optimizer(
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 157, in save_fsdp_optimizer
[-02:1]:    optim_state = FSDP.optim_state_dict(model, optimizer)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
[-02:1]:    return FullyShardedDataParallel._optim_state_dict_impl(
[-01:2]:    self._save_checkpoint(model, trial, metrics=metrics)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2660, in _save_checkpoint
[-01:2]:    self._save_optimizer_and_scheduler(output_dir)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2757, in _save_optimizer_and_scheduler
[-01:2]:    save_fsdp_optimizer(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 157, in save_fsdp_optimizer
[-01:2]:    optim_state = FSDP.optim_state_dict(model, optimizer)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
[-01:2]:    return FullyShardedDataParallel._optim_state_dict_impl(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
[-01:2]:    return _optim_state_dict(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[-01:2]:    return func(*args, **kwargs)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
[-01:2]:    fsdp_osd_state = convert_fn(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
[-01:2]:    _gather_all_orig_param_state(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
[-01:2]:    output_states = _allgather_orig_param_states(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
[-01:2]:    dtype, state_buffers = _convert_all_state_info(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1378, in _convert_all_state_info
[-01:2]:    assert dtype == info.dtype
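
For context on the failing assertion: _convert_all_state_info flattens each parameter's gathered optimizer-state tensors into a single buffer and so expects them to share one dtype, while the adamw_bnb_8bit optimizer in the config keeps quantized uint8 state buffers next to float32 quantization statistics, which would explain the assert dtype == info.dtype failure. Below is a minimal, hypothetical sketch (my own, not from the handbook) to see that dtype mix locally; it assumes a CUDA device and bitsandbytes installed, and uses a toy nn.Linear instead of Mixtral:

import torch
import bitsandbytes as bnb  # assumption: installed, as implied by optim: adamw_bnb_8bit

# Toy stand-in for the real model; the parameter is large enough (>= 4096 elements)
# that bitsandbytes actually keeps its optimizer state quantized in 8-bit.
model = torch.nn.Linear(4096, 4096, bias=False).cuda()

# Same optimizer family as optim: adamw_bnb_8bit in the training config.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2e-5)

# One forward/backward/step so the optimizer state gets materialized.
model(torch.randn(2, 4096, device="cuda")).sum().backward()
optimizer.step()

# Print the dtype of every state tensor: expect a mix of torch.uint8 (quantized
# moment buffers) and torch.float32 (quantization statistics), whereas FSDP's
# optim_state_dict gathering expects a single dtype per parameter state.
for state in optimizer.state.values():
    for name, value in state.items():
        if torch.is_tensor(value):
            print(f"{name}: dtype={value.dtype}, shape={tuple(value.shape)}")

If that is indeed the trigger, the dtype mix comes from bitsandbytes' blockwise 8-bit quantization itself, not from anything specific to this recipe.
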
alvarobartt commented 5 months ago

Hi @heraclex12, I believe you'll need to add these lines to run_sft.py:

https://github.com/huggingface/alignment-handbook/blob/70769f9e9ba41c7f08ba6c4ff3725441b68b7ca3/scripts/run_orpo.py#L235-L236
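
If I recall correctly, those lines switch FSDP to a full state dict right before saving; roughly like the sketch below, placed just before the save call. Check the linked commit for the exact code; treat this as a hedged sketch, and note that trainer and training_args are the objects already defined in run_sft.py.

# Hedged sketch, not a verbatim quote of run_orpo.py: gather a full (unsharded)
# state dict on rank 0 before saving when training with FSDP.
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(training_args.output_dir)  # the save call already in run_sft.py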

cc @lewtun

heraclex12 commented 5 months ago

I see. Thank you for your support; this should help me fix the issue.