[Closed] @heraclex12 closed this issue 5 months ago
Hi,
I'm trying to train Mixtral 8x22B using FSDP. I followed the zephyr-141b-A35b recipe to create my training config below:
```yaml
model_name_or_path: mistral-community/Mixtral-8x22B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

dataset_mixer:
  HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
preprocessing_num_workers: 12

bf16: true
do_eval: false
evaluation_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 2.0e-5
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_seq_length: 2048
num_train_epochs: 3
optim: adamw_bnb_8bit
output_dir: ./models/zephyr-sft-141b-A35b
overwrite_output_dir: true
per_device_train_batch_size: 1
remove_unused_columns: true
push_to_hub: false
report_to: "none"
save_strategy: "steps"
save_steps: 5
seed: 42
warmup_steps: 0.1
```
However, there was a saving error that happened while FSDP was saving the optimizer. What did I do wrong with the config?
Here is the log:
(output from two ranks was interleaved; deduplicated into a single traceback below)

```
Traceback (most recent call last):
  File "/home/user/alignment-handbook/scripts/run_sft.py", line 233, in <module>
    main()
  File "/home/user/alignment-handbook/scripts/run_sft.py", line 188, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/user/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2193, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2588, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2660, in _save_checkpoint
    self._save_optimizer_and_scheduler(output_dir)
  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2757, in _save_optimizer_and_scheduler
    save_fsdp_optimizer(
  File "/home/user/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 157, in save_fsdp_optimizer
    optim_state = FSDP.optim_state_dict(model, optimizer)
  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
    fsdp_osd_state = convert_fn(
  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1378, in _convert_all_state_info
    assert dtype == info.dtype
```
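The failing assertion (`assert dtype == info.dtype`) is FSDP verifying that every shard of a given optimizer state reports the same dtype before all-gathering it. One plausible culprit in this config is `optim: adamw_bnb_8bit`, since 8-bit optimizers keep quantized state alongside full-precision scaling buffers. A hypothetical, simplified sketch of that invariant (not PyTorch's actual code; `uniform_state_dtype` and the dict-based `state_infos` are illustrative inventions):

```python
# Hypothetical simplification of the check that fails in
# torch/distributed/fsdp/_optim_utils.py: every shard of one optimizer state
# must report the same dtype, or the all-gather cannot proceed.

def uniform_state_dtype(state_infos):
    """Return the dtype shared by all shards, or raise AssertionError."""
    dtype = None
    for info in state_infos:
        if dtype is None:
            dtype = info["dtype"]
        # Mirrors `assert dtype == info.dtype` from the traceback above.
        assert dtype == info["dtype"], "optimizer state has mixed dtypes"
    return dtype
```

Mixed-precision optimizer state (e.g. uint8 tensors plus fp32 scale factors) breaks this invariant, which is why switching `optim` to a plain optimizer such as `adamw_torch` also tends to sidestep the assert.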
Hi there @heraclex12, I believe you'll need to add these lines from run_orpo.py to run_sft.py:
https://github.com/huggingface/alignment-handbook/blob/70769f9e9ba41c7f08ba6c4ff3725441b68b7ca3/scripts/run_orpo.py#L235-L236
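For context, a minimal sketch of the kind of change linked above, assuming those lines switch FSDP to a full (unsharded) state dict before saving. `Trainer.is_fsdp_enabled` and `FullyShardedDataParallelPlugin.set_state_dict_type` are real transformers/accelerate APIs, but the helper name and its exact placement in run_sft.py are assumptions, not a verbatim copy of the linked code:

```python
# Hedged sketch: before the Trainer saves a checkpoint, tell the accelerate
# FSDP plugin to gather a full state dict instead of a sharded one.
# Where exactly this belongs in run_sft.py is an assumption based on the
# linked run_orpo.py lines.

def use_full_state_dict_for_saving(trainer):
    """Switch the accelerate FSDP plugin to FULL_STATE_DICT before saving."""
    if getattr(trainer, "is_fsdp_enabled", False):
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
```

With a full state dict, each optimizer state is materialized as ordinary unsharded tensors before serialization, avoiding the per-shard dtype gathering path that tripped the assert.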
cc @lewtun
I see. Thank you for your support; this should help me fix the issue.