hiyouga / LLaMA-Factory

Unify Efficient Fine-Tuning of 100+ LLMs
Apache License 2.0

Training GLM-4 fails: RuntimeError when using flash attention with 8-bit quantization, while the same arguments work for Llama 3 #4441

Closed: fst813 closed this issue 4 days ago

fst813 commented 5 days ago


System Info

LLaMA-Factory version: 0.8.1, transformers: 4.41.2, flash-attn: 2.5.7

Reproduction

src/train.py \
    --stage sft \
    --model_name_or_path ZhipuAI/glm-4-9b-chat \
    --do_train \
    --preprocessing_num_workers 16 \
    --dataset tmp_train \
    --template llama3_our \
    --finetuning_type lora \
    --lora_target all \
    --output_dir glm4_v7_8_0 \
    --overwrite_cache \
    --overwrite_output_dir \
    --per_device_train_batch_size 3 \
    --gradient_accumulation_steps 3 \
    --flash_attn fa2 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 250 \
    --cutoff_len 4096 \
    --quantization_bit 8 \
    --save_only_model True \
    --learning_rate 1e-4 \
    --num_train_epochs 5.0 \
    --plot_loss \
    --fp16 \
    --group_by_length \
    --use_fast_tokenizer False \
    --lora_alpha 16 \
    --lora_rank 8 \
    --lora_dropout 0.1

Traceback (most recent call last):
host1:   File "/home/ss/train_frame/LLaMA-Factory/src/train.py", line 16, in <module>
host1:     main()
host1:   File "/home/ss/train_frame/LLaMA-Factory/src/train.py", line 7, in main
host1:     run_exp()
host1:   File "/home/ss/train_frame/LLaMA-Factory/src/llamafactory/train/tuner.py", line 119, in run_exp
host1:     run_exe(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
host1:   File "/home/ss/train_frame/LLaMA-Factory/src/llamafactory/train/tuner.py", line 30, in run_exe
host1:     run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
host1:   File "/home/ss/train_frame/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 73, in run_sft
host1:     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
host1:     return inner_training_loop(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
host1:     tr_loss_step = self.training_step(model, inputs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer.py", line 3238, in training_step
host1:     loss = self.compute_loss(model, inputs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer.py", line 3264, in compute_loss
host1:     outputs = model(**inputs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
host1:     else self._run_ddp_forward(*inputs, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
host1:     return self.module(*inputs, **kwargs)  # type: ignore[index]
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
host1:     return model_forward(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
host1:     return convert_to_fp32(self.model_forward(*args, **kwargs))
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
host1:     return func(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/peft/peft_model.py", line 1430, in forward
host1:     return self.base_model(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
host1:     return self.model.forward(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
host1:     output = module._old_forward(*args, **kwargs)
host1:   File "/home/ss/.cache/huggingface/modules/transformers_modules/modeling_chatglm.py", line 1001, in forward
host1:     transformer_outputs = self.transformer(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
host1:     output = module._old_forward(*args, **kwargs)
host1:   File "/home/ss/.cache/huggingface/modules/transformers_modules/modeling_chatglm.py", line 897, in forward
host1:     hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
host1:     output = module._old_forward(*args, **kwargs)
host1:   File "/home/ss/.cache/huggingface/modules/transformers_modules/modeling_chatglm.py", line 712, in forward
host1:     layer_ret = torch.utils.checkpoint.checkpoint(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
host1:     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
host1:     return fn(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
host1:     return fn(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
host1:     ret = function(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
host1:     output = module._old_forward(*args, **kwargs)
host1:   File "/home/ss/.cache/huggingface/modules/transformers_modules/modeling_chatglm.py", line 625, in forward
host1:     attention_output, kv_cache = self.self_attention(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
host1:     output = module._old_forward(*args, **kwargs)
host1:   File "/home/ss/.cache/huggingface/modules/transformers_modules/modeling_chatglm.py", line 522, in forward
host1:     context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
host1:     return self._call_impl(*args, **kwargs)
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
host1:     return forward_call(*args, **kwargs)
host1:   File "/home/ss/.cache/huggingface/modules/transformers_modules/modeling_chatglm.py", line 320, in forward
host1:     attn_output_unpad = flash_attn_varlen_func(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1066, in flash_attn_varlen_func
host1:     return FlashAttnVarlenFunc.apply(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
host1:     return super().apply(*args, **kwargs)  # type: ignore[misc]
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 581, in forward
host1:     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
host1:   File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 86, in _flash_attn_varlen_forward
host1:     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
host1: RuntimeError: FlashAttention only support fp16 and bf16 data type
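The last frame is FlashAttention's own input check: its CUDA kernel accepts only fp16 and bf16 tensors, so at least one of the query/key/value tensors reached `flash_attn_varlen_func` in another dtype, most likely float32. One plausible reason Llama 3 is unaffected: as far as I know, the Llama FlashAttention-2 code built into transformers 4.41 casts float32 q/k/v back to a half-precision dtype before calling the kernel (preferring the autocast dtype, then the model's pre-quantization dtype, then the `q_proj` weight dtype), while the remote `modeling_chatglm.py` code in this traceback may not. The stdlib-only sketch below models that check and that cast decision, with dtype names as strings; `check_flash_attn_inputs` and `pick_flash_attn_dtype` are hypothetical helpers, not a transformers or flash-attn API.

```python
from typing import Optional

# The FlashAttention-2 CUDA kernels accept only these input dtypes.
FLASH_ATTN_DTYPES = {"float16", "bfloat16"}


def check_flash_attn_inputs(*dtypes: str) -> None:
    """Reject unsupported dtypes, like the kernel check in the traceback."""
    for dtype in dtypes:
        if dtype not in FLASH_ATTN_DTYPES:
            raise RuntimeError("FlashAttention only support fp16 and bf16 data type")


def pick_flash_attn_dtype(
    input_dtype: str,
    weight_dtype: str,
    autocast_dtype: Optional[str] = None,
    pre_quantization_dtype: Optional[str] = None,
) -> str:
    """Choose the dtype q/k/v should be cast to before the kernel call.

    Assumed preference order: active autocast dtype, then the dtype the
    model had before quantization, then the projection weight dtype
    (which is int8 for an 8-bit model, hence the earlier fallbacks).
    """
    if input_dtype != "float32":
        return input_dtype  # already half precision, nothing to cast
    if autocast_dtype is not None:
        return autocast_dtype
    if pre_quantization_dtype is not None:
        return pre_quantization_dtype
    return weight_dtype


# An 8-bit model trained with --fp16 (autocast active) whose q/k/v came out as float32.
target = pick_flash_attn_dtype("float32", weight_dtype="int8", autocast_dtype="float16")
check_flash_attn_inputs(target, target, target)  # passes once the cast is applied
```

Without such a cast, passing the raw `"float32"` dtype to `check_flash_attn_inputs` raises the same error as above.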

Expected behavior

I saw the same problem in an earlier issue, where the fix was to add `--fp16`, but I am already passing `--fp16`.
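The traceback suggests why `--fp16` alone does not help: autocast is already active (see the `decorate_autocast` frame), but autocast only recasts the inputs of specific ops such as matmul and linear. It does not cast the arguments of a custom kernel like `flash_attn_varlen_func`, and elementwise ops follow ordinary type promotion, so a single float32 operand upstream turns q or k into float32. Where that operand comes from is not established in this thread; an fp32 rotary-embedding cache or a layer upcast to fp32 for k-bit training are assumptions. The sketch below models PyTorch's promotion rule for the three floating dtypes involved; `promote_float` is a hypothetical stand-in for `torch.promote_types`.

```python
# Bit widths, used only to rank the floating dtypes by precision.
_BITS = {"float16": 16, "bfloat16": 16, "float32": 32}


def promote_float(a: str, b: str) -> str:
    """Result dtype of an elementwise op (mul, add, ...) on two float tensors.

    Models torch.promote_types for float16, bfloat16 and float32 only.
    """
    if a == b:
        return a
    if _BITS[a] == _BITS[b]:
        # float16 and bfloat16 cannot represent each other, so PyTorch widens to float32.
        return "float32"
    return a if _BITS[a] > _BITS[b] else b


# Assumed case: an fp16 query (autocast output of the QKV projection)
# multiplied by an fp32 rotary cache, then fed to FlashAttention.
query_dtype = promote_float("float16", "float32")
print(query_dtype)  # float32
```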


hiyouga commented 4 days ago

This is probably a problem in the model's own code; I suggest using SDPA attention instead.
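In LLaMA-Factory the attention backend is chosen with `--flash_attn`; as far as I know, 0.8.x accepts `off`, `sdpa`, `fa2` and `auto`, so following this suggestion means rerunning the reproduction with `--flash_attn sdpa` instead of `--flash_attn fa2`. SDPA (PyTorch's `scaled_dot_product_attention`) can fall back to kernels that also accept float32 inputs, so it should not hit this dtype check. A minimal stdlib sketch that rewrites an argument list accordingly; `with_attention_backend` is a hypothetical helper and only handles the space-separated `--flash_attn VALUE` form used in the reproduction.

```python
from typing import List


def with_attention_backend(argv: List[str], backend: str = "sdpa") -> List[str]:
    """Return a copy of `argv` with the --flash_attn value set to `backend`."""
    args = list(argv)  # never mutate the caller's list
    if "--flash_attn" in args:
        i = args.index("--flash_attn")
        if i + 1 < len(args) and not args[i + 1].startswith("--"):
            args[i + 1] = backend  # replace the existing value, e.g. fa2
        else:
            args.insert(i + 1, backend)  # flag present but its value is missing
    else:
        args += ["--flash_attn", backend]  # flag absent: append it
    return args


# Abridged version of the reproduction's arguments (other flags unchanged).
argv = ["src/train.py", "--stage", "sft", "--quantization_bit", "8",
        "--flash_attn", "fa2", "--fp16"]
new_argv = with_attention_backend(argv)
print(" ".join(new_argv))
```

Everything else in the command, including `--quantization_bit 8` and `--fp16`, stays the same; only the attention backend changes.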