hiyouga / LLaMA-Factory

Efficiently Fine-Tune 100+ LLMs in WebUI (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

Invalid device string: 'float32' #4698

Closed. OnewayLab closed this issue 3 weeks ago.

OnewayLab commented 1 month ago

Reminder

System Info

Reproduction

Command

llamafactory-cli train \
    --do_train \
    --stage sft \
    --finetuning_type full \
    --use_unsloth \
    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
    --flash_attn fa2 \
    --template llama3 \
    --dataset $DATASET \
    --dataset_dir data \
    --cutoff_len 8192 \
    --preprocessing_num_workers 24 \
    --output_dir output/tmp \
    --overwrite_output_dir \
    --save_steps 100 \
    --logging_steps 10 \
    --plot_loss \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 32 \
    --per_device_eval_batch_size 4 \
    --learning_rate 1e-5 \
    --max_steps 1000 \
    --lr_scheduler_type cosine \
    --warmup_ratio 0.1 \
    --bf16

Error Log

[INFO|trainer.py:3478] 2024-07-06 17:38:36,193 >> Saving model checkpoint to output/tmp/checkpoint-5
[INFO|configuration_utils.py:472] 2024-07-06 17:38:36,207 >> Configuration saved in output/tmp/checkpoint-5/config.json
[INFO|configuration_utils.py:769] 2024-07-06 17:38:36,214 >> Configuration saved in output/tmp/checkpoint-5/generation_config.json
[INFO|modeling_utils.py:2698] 2024-07-06 17:40:06,590 >> The model is bigger than the maximum size per checkpoint (5GB) and is going to be split in 7 checkpoint shards. You can find where each parameters has been saved in the index located at output/tmp/checkpoint-5/model.safetensors.index.json.
Traceback (most recent call last):
  File "/opt/conda/envs/lf/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
  File "~/LLaMA-Factory/src/llamafactory/cli.py", line 111, in main
    run_exp()
  File "~/LLaMA-Factory/src/llamafactory/train/tuner.py", line 50, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "~/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 90, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
  File "<string>", line 367, in _fast_inner_training_loop
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/transformers/trainer.py", line 3307, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/transformers/trainer.py", line 3338, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/unsloth/models/llama.py", line 856, in _CausalLM_fast_forward
    outputs = self.model(
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/lf/lib/python3.10/site-packages/unsloth/models/llama.py", line 561, in LlamaModel_fast_forward
    inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
RuntimeError: Invalid device string: 'float32'
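
The log line about 7 checkpoint shards is consistent with the weights being held in float32 at save time. Here is a back-of-the-envelope check. It assumes roughly 8.03 billion parameters for Meta-Llama-3-8B and takes "5GB" as 5 × 10^9 bytes. The result is only a lower bound, because shards are filled tensor by tensor:

```python
import math

NUM_PARAMS = 8.03e9          # assumption: approximate parameter count of Meta-Llama-3-8B
MAX_SHARD_BYTES = 5 * 10**9  # the 5GB per-checkpoint limit quoted in the log
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def min_shards(dtype_name):
    """Lower bound on the number of checkpoint shards for the given dtype."""
    total_bytes = NUM_PARAMS * BYTES_PER_PARAM[dtype_name]
    return math.ceil(total_bytes / MAX_SHARD_BYTES)


for name in BYTES_PER_PARAM:
    print(name, min_shards(name))
```

This gives at least 7 shards for float32 and at least 4 for bfloat16, so the 7 shards in the log point to float32 parameters. That matches the `'float32'` string in the error.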

Expected behavior

Successful training.

Others

The problem occurs in the first training step after the first checkpoint is saved. I suspect that saving the checkpoint changes model.config.torch_dtype from torch.bfloat16 to the string "float32". unsloth then passes that string to Tensor.to(), which parses it as a device string.
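
Below is a minimal sketch of the suspected failure, assuming only PyTorch is installed. `Tensor.to` interprets a bare string as a device, so the string `"float32"` fails the same way as in the traceback, while a real `torch.dtype` works. The `resolve_dtype` helper is hypothetical and is not part of unsloth or Transformers. It shows one defensive way to accept either form, and it is not necessarily what the eventual unsloth fix does.

```python
import torch


def resolve_dtype(dtype):
    """Hypothetical helper: accept a torch.dtype or a string such as "float32"."""
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        # Accept both "float32" and "torch.float32".
        resolved = getattr(torch, dtype.removeprefix("torch."), None)
        if isinstance(resolved, torch.dtype):
            return resolved
    raise TypeError(f"cannot interpret {dtype!r} as a torch.dtype")


embeds = torch.ones(2, 3, dtype=torch.bfloat16)

# The failure in the traceback: a dtype *string* is parsed as a device string.
try:
    embeds.to("float32")
except RuntimeError as err:
    print("RuntimeError:", err)

# Normalizing the value first makes the cast behave as intended.
print(embeds.to(resolve_dtype("float32")).dtype)
```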

OnewayLab commented 1 month ago

Maybe it's an issue in Hugging Face Transformers. I found the following code in transformers/modeling_utils.py:

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
    def save_pretrained(...):
        ...
        # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
        # we currently don't use this setting automatically, but may start to use with v5
        dtype = get_parameter_dtype(model_to_save)
        model_to_save.config.torch_dtype = str(dtype).split(".")[1]
        ...

What could we do to fix it?
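
One possible caller-side workaround is sketched below, under assumptions. It snapshots `config.torch_dtype` before saving and restores it afterwards, so code that reads the value later (such as the unsloth line in the traceback) still gets a `torch.dtype`. Both `save_preserving_dtype` and `fake_save_pretrained` are hypothetical. The latter only mimics the dtype line quoted from `save_pretrained` above; it is not the real method.

```python
from types import SimpleNamespace

import torch


def save_preserving_dtype(model, save_fn, *args, **kwargs):
    """Hypothetical wrapper: run save_fn, then restore model.config.torch_dtype."""
    original_dtype = model.config.torch_dtype
    try:
        return save_fn(*args, **kwargs)
    finally:
        # Undo the in-place stringification performed while saving.
        model.config.torch_dtype = original_dtype


def fake_save_pretrained(model):
    """Stand-in for save_pretrained that reproduces only the quoted dtype line."""
    dtype = torch.float32  # pretend get_parameter_dtype() returned float32
    model.config.torch_dtype = str(dtype).split(".")[1]


model = SimpleNamespace(config=SimpleNamespace(torch_dtype=torch.bfloat16))

fake_save_pretrained(model)
print(repr(model.config.torch_dtype))  # a plain string after an unwrapped save

model.config.torch_dtype = torch.bfloat16
save_preserving_dtype(model, fake_save_pretrained, model)
print(model.config.torch_dtype)  # the original torch.dtype object again
```

This restores the original object instead of re-parsing the string, so the in-memory model keeps whatever dtype the training code set before the save. The config.json written to disk still contains the string form. In a `transformers.Trainer` run, the same idea might be applied from a `TrainerCallback.on_save` hook, which fires after a checkpoint is written. That wiring is not shown or tested here.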

skiiwoo commented 1 month ago

Maybe it's an issue with unsloth: I get the same error when use_unsloth: true is set.

### model
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
flash_attn: fa2
# use_unsloth: true

### method
stage: sft
do_train: true
finetuning_type: full
bf16: true

### dataset
dataset: mine
template: qwen
cutoff_len: 4000
overwrite_cache: true
preprocessing_num_workers: 8

### output
output_dir: Qwen2-1.5B-Instruct
logging_steps: 10
save_steps: 10
save_strategy: steps
plot_loss: true
overwrite_output_dir: true

per_device_train_batch_size: 8
gradient_accumulation_steps: 4
learning_rate: 1.0e-4
num_train_epochs: 1
lr_scheduler_type: cosine
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 8
eval_strategy: steps
eval_steps: 30

### log
report_to: wandb
run_name: Qwen2-1.5B-Instruct

relic-yuexi commented 3 weeks ago

@hiyouga This has been fixed in unsloth. The fix will be included in the next unsloth release. Until then, users can apply it by following https://github.com/unslothai/unsloth/pull/874#issue-2447675422

OnewayLab commented 3 weeks ago


Thanks for the heads-up! I will close this issue.