hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

--shift_attn with bf16 giving error #2952

Closed · xhz0809 closed this 7 months ago

xhz0809 commented 7 months ago

Reproduction

sft code:

CUDA_VISIBLE_DEVICES=1 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path "meta-llama/Llama-2-7b-hf" \
    --dataset ICLR2023 \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir /data/hanzi/Weakness/checkpoints/llamaf/llama2_7b  \
    --overwrite_cache \
    --max_length 8192 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 10 \
    --plot_loss \
    --bf16 \
    --shift_attn

error message:

Traceback (most recent call last):
  File "/data/hanzi/Weakness/llamaf/src/train_bash.py", line 14, in <module>
    main()
  File "/data/hanzi/Weakness/llamaf/src/train_bash.py", line 5, in main
    run_exp()
  File "/data/hanzi/Weakness/llamaf/src/llmtuner/train/tuner.py", line 32, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/data/hanzi/Weakness/llamaf/src/llmtuner/train/sft/workflow.py", line 71, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/transformers/trainer.py", line 2118, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/transformers/trainer.py", line 3036, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/transformers/trainer.py", line 3059, in compute_loss
    outputs = model(**inputs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/peft/peft_model.py", line 1091, in forward
    return self.base_model(
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 160, in forward
    return self.model.forward(*args, **kwargs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
    outputs = self.model(
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 990, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  File "/data/Hanzi/env/llamaf/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1076, in _update_causal_mask
    causal_mask = torch.triu(causal_mask, diagonal=1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

Expected behavior

Thanks for solving the previous --shift_attn issue. After pulling the newer version, the error above occurs. Dropping --bf16, or switching to --fp16, works fine.
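
For reference, the failure reproduces outside the trainer. A minimal sketch, assuming a CUDA device and torch 2.0.x; the shape and fill value only mirror what _update_causal_mask does, they are not the exact transformers code:

import torch

# torch.triu dispatches to triu_tril_cuda_template on GPU, which has no
# BFloat16 kernel in torch 2.0.x; float16 and float32 take supported paths.
for dtype in (torch.float16, torch.float32, torch.bfloat16):
    mask = torch.full((8, 8), torch.finfo(dtype).min, dtype=dtype, device="cuda")
    try:
        torch.triu(mask, diagonal=1)  # the call made by _update_causal_mask
        print(dtype, "-> ok")
    except RuntimeError as err:
        print(dtype, "-> failed:", err)  # bfloat16 fails on torch 2.0.x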

System Info

CUDA version: 11.7
transformers version: 4.38.2
Platform: Linux-5.4.0-171-generic-x86_64-with-glibc2.31
Python version: 3.10.13
Huggingface_hub version: 0.21.4
PyTorch version (GPU?): 2.0.1 (True)
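
To confirm which versions the active environment actually resolves, a minimal sketch (nothing LLaMA-Factory-specific):

import platform

import torch
import transformers

# Mirrors the fields above; torch.version.cuda is the CUDA toolkit
# that this torch build was compiled against.
print("Platform:", platform.platform())
print("Python version:", platform.python_version())
print("PyTorch version (GPU?):", torch.__version__, f"({torch.cuda.is_available()})")
print("CUDA version:", torch.version.cuda)
print("transformers version:", transformers.__version__)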

Others

No response

codemayq commented 7 months ago

See https://github.com/pytorch/pytorch/issues/101932 and https://github.com/huggingface/diffusers/issues/3453.

This is caused by the PyTorch library: torch.triu has no CUDA BFloat16 kernel in torch 2.0.x. Update your torch version and try again (my PyTorch version is 2.1.2, which works fine).
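
If upgrading torch is not possible right away, one stopgap is to route BFloat16 inputs through float32 for this single op. A hedged sketch, not an official LLaMA-Factory or transformers fix; apply it before run_exp() is invoked, e.g. at the top of train_bash.py:

import torch

_orig_triu = torch.triu

def _triu_bf16_safe(input, diagonal=0, *, out=None):
    # Workaround for "triu_tril_cuda_template not implemented for 'BFloat16'":
    # compute triu in float32 for CUDA BFloat16 tensors, then cast back.
    if out is None and input.is_cuda and input.dtype == torch.bfloat16:
        return _orig_triu(input.float(), diagonal=diagonal).to(torch.bfloat16)
    return _orig_triu(input, diagonal=diagonal, out=out)

torch.triu = _triu_bf16_safe  # monkey-patch; remove once torch is upgraded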