eosphoros-ai / DB-GPT-Hub

A repository that contains models, datasets, and fine-tuning techniques for DB-GPT, with the purpose of enhancing model performance in Text-to-SQL
MIT License
1.21k stars, 168 forks

Prediction stage: poetry run sh ./dbgpt_hub/scripts/predict_sft.sh gets Killed #269

Open GuokaiLiu opened 3 weeks ago

GuokaiLiu commented 3 weeks ago
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
    --model_name_or_path $model_name_or_path \
    --quantization_bit 4 \
    --do_train \
    --dataset $dataset \
    --max_source_length 2048 \
    --max_target_length 512 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --template llama2 \
    --lora_rank 32 \
    --lora_alpha 32 \
    --output_dir $output_dir \
    --overwrite_cache \
    --overwrite_output_dir \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --lr_scheduler_type cosine_with_restarts \
    --logging_steps 50 \
    --save_steps 2000 \
    --learning_rate 2e-4 \
    --num_train_epochs 8 \
    --plot_loss >> ${train_log}
    # \
    # --bf16  >> ${train_log}
    # --bf16  # V100 does not support bf16
CUDA_VISIBLE_DEVICES=0  python dbgpt_hub/predict/predict.py \
    --model_name_or_path /home/lgk/Downloads/CodeLlama-7b-Instruct-hf \
    --template llama2 \
    --finetuning_type lora \
    --predicted_input_filename dbgpt_hub/data/example_text2sql_dev.json \
    --checkpoint_dir dbgpt_hub/output/adapter/CodeLlama-7b-sql-lora \
    --predicted_out_filename dbgpt_hub/output/pred/pred_codellama7b.sql >> ${pred_log}   
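The two commands above rely on shell variables ($model_name_or_path, $dataset, $output_dir, ${train_log} and ${pred_log}) that are defined elsewhere in the original scripts and are not shown in the report. A minimal sketch, assuming a POSIX sh such as the one poetry run sh invokes, of failing fast when one of them is empty instead of launching a long job with a blank path. The function name check_vars is hypothetical and not part of DB-GPT-Hub.

```shell
#!/bin/sh
# Hypothetical helper (not part of DB-GPT-Hub): report every listed variable
# that is unset or empty, and return non-zero if any were missing.
# Note: it uses the global names "missing" and "value" as scratch variables.
check_vars() {
    missing=0
    for name in "$@"; do
        # Indirect lookup of the variable whose name is stored in $name.
        # eval is acceptable here because the names are fixed identifiers.
        eval "value=\${$name-}"
        if [ -z "$value" ]; then
            echo "error: \$$name is not set" >&2
            missing=1
        fi
    done
    return "$missing"
}

# Example: guard the prediction step before launching it.
# check_vars pred_log || exit 1
```

Checking up front costs nothing, whereas an empty ${pred_log} would make the >> redirection itself fail only after the shell has parsed the whole command.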
(dbgpt_hub) lgk@WIN-20240401VAM:~/Projects/DB-GPT-Hub$ poetry run sh ./dbgpt_hub/scripts/predict_sft.sh
Warning: Found deprecated priority 'default' for source 'mirrors' in pyproject.toml. You can achieve the same effect by changing the priority to 'primary' and putting the source first.
/home/lgk/.conda/envs/dbgpt_hub/lib/python3.10/site-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
  warnings.warn(
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Killed
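A bare Killed with no Python traceback usually means the process received SIGKILL, most often from the Linux out-of-memory (OOM) killer while the checkpoint shards were being loaded into host RAM. By shell convention such a death is reported as exit status 137 (128 + signal 9). A minimal sketch, assuming a POSIX sh, for telling a signal kill apart from an ordinary error exit; describe_exit is a hypothetical helper name. If the OOM killer was responsible, the kernel log (for example dmesg output, where readable) normally contains an "Out of memory" line naming the killed process.

```shell
#!/bin/sh
# Hypothetical helper: turn a command's exit status into a short diagnosis.
describe_exit() {
    status=$1
    if [ "$status" -eq 0 ]; then
        echo "ok"
    elif [ "$status" -gt 128 ]; then
        # By shell convention, a status above 128 usually means the process
        # was terminated by signal number (status - 128).
        sig=$((status - 128))
        if [ "$sig" -eq 9 ]; then
            echo "killed by SIGKILL (signal 9); check the kernel log for the OOM killer"
        else
            echo "terminated by signal $sig"
        fi
    else
        echo "exited with error status $status"
    fi
}

# Usage (sketch): run the prediction script, then inspect its status.
# sh ./dbgpt_hub/scripts/predict_sft.sh; describe_exit $?
# dmesg 2>/dev/null | grep -i "out of memory"
```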
Oops322 commented 4 days ago

Hi, has this been solved? My predict process is also killed outright. I suspect it is related to the "--quantization_bit 4" argument.
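The suspicion can be checked with simple arithmetic: weight memory is roughly the number of parameters times the bytes per parameter. Training above used --quantization_bit 4, while the predict command passes no quantization flag, so if prediction loads the base model unquantized it needs several times more memory. The sketch assumes a 7B Llama-family model has about 6.7 billion parameters (an approximation, not stated in this thread) and computes only a lower bound for the weights; peak usage while loading checkpoint shards can be higher. weight_mb is a hypothetical helper.

```shell
#!/bin/sh
# Hypothetical helper: lower-bound estimate of weight memory in MB
# (1 MB = 10^6 bytes here), ignoring activations, KV cache, LoRA weights
# and loader overhead.
# $1 = parameter count in millions, $2 = bits per parameter.
weight_mb() {
    echo $(( $1 * $2 / 8 ))
}

params_m=6700   # assumption: ~6.7 billion parameters for a 7B model
for bits in 32 16 8 4; do
    echo "${bits}-bit weights: ~$(weight_mb "$params_m" "$bits") MB"
done
# prints:
# 32-bit weights: ~26800 MB
# 16-bit weights: ~13400 MB
# 8-bit weights: ~6700 MB
# 4-bit weights: ~3350 MB
```

In other words, fp16 weights alone need about four times the memory of the 4-bit weights used during training.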

Oops322 commented 4 days ago

However, when I added --quantization_bit 4, I got a different error instead: KeyError: 'base_model.model.model.layers.0.self_attn.q_proj.weight'
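The key in this KeyError follows PEFT's naming for a LoRA-wrapped model. The PeftModel wrapper contributes the base_model. prefix, the LoRA model contributes another model., and what remains, model.layers.0.self_attn.q_proj.weight, is the ordinary Hugging Face LlamaForCausalLM parameter name. The failing lookup is therefore for a PEFT-wrapped name. The thread does not show which code performs it, so the cause (for example, a mismatch between how the adapter was saved and how the quantized model is rebuilt) remains unconfirmed. A minimal sketch that strips the wrapper prefix with POSIX parameter expansion, so that the remaining name can be compared against the base model's own parameter names. strip_peft_prefix is a hypothetical helper.

```shell
#!/bin/sh
# Hypothetical helper: drop the "base_model.model." prefix that PEFT's
# PeftModel + LoRA wrappers put in front of the underlying model's names.
strip_peft_prefix() {
    key=$1
    # ${var#pattern} removes the shortest matching prefix; it is a no-op
    # when the key does not start with the pattern.
    echo "${key#base_model.model.}"
}

key='base_model.model.model.layers.0.self_attn.q_proj.weight'
strip_peft_prefix "$key"
# prints: model.layers.0.self_attn.q_proj.weight
```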