OpenBMB / ToolBench

[ICLR'24 spotlight] An open platform for training, serving, and evaluating large language models for tool learning.
https://openbmb.github.io/ToolBench/
Apache License 2.0

Training script fails for Yi-6B: how to adapt it? #201

Open: yhyu13 opened this issue 7 months ago

yhyu13 commented 7 months ago

Hi,

I am interested in applying the ToolBench dataset to Yi-6B: https://huggingface.co/chargoddard/Yi-6B-Llama

The training script has been slightly modified:

export PYTHONPATH=./ && \
deepspeed --master_port=20001 toolbench/train/train_lora.py \
    --model_name_or_path /root/CodeSpace/Yi-6B-Llama \
    --data_path /root/CodeSpace/data/toolllama_G123_dfs_eval.json \
    --eval_data_path /root/CodeSpace/data/toolllama_G123_dfs_eval.json \
    --conv_template tool-llama-single-round \
    --bf16 True \
    --output_dir toolYi_6B_llama_lora \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "epoch" \
    --prediction_loss_only \
    --save_strategy "epoch" \
    --save_total_limit 8 \
    --learning_rate 0.00005 \
    --weight_decay 0 \
    --warmup_ratio 0.04 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --source_model_max_length 4096 \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --deepspeed ds_configs/stage2.json \
    --report_to none

But it fails with the following error:

  File "/root/CodeSpace/ToolBench/toolbench/train/llama_flash_attn_monkey_patch.py", line 28, in forward_2
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[4, 4096, 32, 128]' is invalid for input of size 8388608

Does the flash attention code only work with Llama 2 models and not with Yi-6B?
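The sizes in the RuntimeError are consistent with grouped-query attention (GQA). Here is a minimal sketch that works through the arithmetic. It assumes Yi-6B's config has num_attention_heads=32, num_key_value_heads=4 and head_dim=128 (check config.json to confirm).

Under that assumption, k_proj emits 4 × 128 = 512 features per token. That gives 4 × 4096 × 512 = 8,388,608 elements. The patched `.view(bsz, q_len, self.num_heads, self.head_dim)` instead expects 32 heads. A custom patch would likely need to view keys and values with `self.num_key_value_heads` and then expand them to match the query heads, as Hugging Face's `repeat_kv` does. The helper names below are hypothetical.

```python
# Sketch: reproduce the element counts from the RuntimeError above.
# ASSUMPTION: Yi-6B uses grouped-query attention with 32 query heads,
# 4 key/value heads and head_dim 128; bsz=4 and q_len=4096 come from
# the shape in the error message. Function names are hypothetical.

def kv_element_counts(bsz, q_len, num_heads, num_kv_heads, head_dim):
    """Return (elements k_proj produces, elements the patched .view() expects)."""
    produced = bsz * q_len * num_kv_heads * head_dim  # real k_proj output size
    expected = bsz * q_len * num_heads * head_dim     # view(bsz, q_len, num_heads, head_dim)
    return produced, expected


def kv_head_for_query_head(q_head, num_heads, num_kv_heads):
    """Index of the key/value head that a given query head uses under GQA.

    Follows the ordering of Hugging Face's repeat_kv helper, which repeats
    each KV head num_heads // num_kv_heads times consecutively.
    """
    n_rep = num_heads // num_kv_heads
    return q_head // n_rep


produced, expected = kv_element_counts(bsz=4, q_len=4096, num_heads=32,
                                       num_kv_heads=4, head_dim=128)
print(produced, expected)  # 8388608 67108864
```

The first number matches the input size reported in the error. With these settings, query heads 0 to 7 share KV head 0, heads 8 to 15 share KV head 1, and so on.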

Thanks!

yhyu13 commented 7 months ago

This seems to be an issue with the custom flash attention implementation. Hugging Face Transformers already supports FlashAttention-2 natively:

https://huggingface.co/docs/transformers/perf_infer_gpu_one

I will open a PR for this. I tested it, and simply using Hugging Face Transformers' built-in flash attention works.
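Here is a rough sketch of what "just use HF Transformers' flash attention" could look like when loading the model. It replaces ToolBench's monkey patch with the `attn_implementation` argument of `from_pretrained`. It assumes transformers >= 4.36, the flash-attn package installed, and a supported GPU; older releases used a `use_flash_attention_2=True` flag instead. Both function names are hypothetical and not part of ToolBench.

```python
# Sketch only: select Hugging Face's built-in FlashAttention-2 backend
# instead of ToolBench's llama_flash_attn_monkey_patch.
# ASSUMPTIONS: transformers >= 4.36, flash-attn installed, supported GPU.
# hf_attention_kwargs and load_model are hypothetical helper names.

def hf_attention_kwargs(use_flash_attention_2: bool) -> dict:
    """Keyword arguments that pick the attention backend in from_pretrained."""
    backend = "flash_attention_2" if use_flash_attention_2 else "eager"
    return {"attn_implementation": backend}


def load_model(model_path: str, use_flash_attention_2: bool = True):
    # Imported lazily so hf_attention_kwargs works without torch installed.
    import torch
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,  # FlashAttention-2 requires fp16 or bf16
        **hf_attention_kwargs(use_flash_attention_2),
    )
```

For example, `load_model("/root/CodeSpace/Yi-6B-Llama")` would load the model from the path in the command above. The built-in Llama attention classes read `num_key_value_heads` from the config, so a GQA model like this one should not need a custom patch.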

pooruss commented 7 months ago

Thanks for your great contribution! We will check the PR.