yxli2123 / LoftQ

MIT License

Can't reproduce reported results on GSM8K #5

Closed: BaohaoLiao closed this issue 6 months ago

BaohaoLiao commented 7 months ago

Thank you for this great work.

I ran your training script train_gsm8k.sh with only one modification: I changed --per_device_train_batch_size 2 and --gradient_accumulation_steps 8 to 1 and 16, respectively, since my A100 only has 40GB of memory. On a single GPU this keeps the effective batch size at 16. I also used --fake_quantization, since it is more stable for optimization.
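As a quick sanity check on that change (a sketch assuming a single GPU and a per-example mean loss; the helper names are hypothetical): the effective batch size is the per-device batch size times the accumulation steps times the number of devices, so 2 * 8 and 1 * 16 both give 16, and averaging equal-sized micro-batch means reproduces the full-batch mean.

```python
from statistics import fmean

def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    # Examples contributing to a single optimizer step.
    return per_device * grad_accum * num_devices

def accumulated_mean(losses, micro_batch):
    # Mean of each micro-batch, scaled by 1/num_micro_batches and summed,
    # which is how a per-example mean loss is typically accumulated.
    chunks = [losses[i:i + micro_batch] for i in range(0, len(losses), micro_batch)]
    return sum(fmean(chunk) / len(chunks) for chunk in chunks)

print(effective_batch_size(2, 8), effective_batch_size(1, 16))

losses = [0.1 * i for i in range(16)]  # hypothetical per-example losses
print(accumulated_mean(losses, 2), accumulated_mean(losses, 1), fmean(losses))
```

If the loss is instead averaged over tokens and micro-batches contain different numbers of target tokens, accumulating per-micro-batch means is only approximately equal to the full-batch mean, so this rules out only the most obvious difference between the two settings.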

However, my results are:

epoch 0: accuracy: 0.244882486732373
epoch 1: accuracy: 0.2880970432145565
epoch 2: accuracy: 0.3055344958301744
epoch 3: accuracy: 0.2979529946929492
epoch 4: accuracy: 0.29492039423805916

There is a large gap between my best result (0.3055 at epoch 2) and your reported result of 35.0. May I ask what might cause this?
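To make the gap concrete, the snippet below parses the flattened log line and picks the best epoch. It is a small sketch: `best_epoch` is a hypothetical helper, and it assumes accuracy is measured on the full 1,319-problem GSM8K test split, in which case 0.3055 corresponds to 403 correct answers.

```python
import re

GSM8K_TEST_SIZE = 1319  # assumption: accuracy is over the full GSM8K test split

def best_epoch(log: str) -> tuple:
    # Pull (epoch, accuracy) pairs out of entries like "epoch 2: accuracy: 0.3055...".
    pairs = [(int(e), float(a)) for e, a in re.findall(r"epoch (\d+): accuracy: ([0-9.]+)", log)]
    return max(pairs, key=lambda pair: pair[1])

log = (
    "epoch 0: accuracy: 0.244882486732373 epoch 1: accuracy: 0.2880970432145565 "
    "epoch 2: accuracy: 0.3055344958301744 epoch 3: accuracy: 0.2979529946929492 "
    "epoch 4: accuracy: 0.29492039423805916"
)
epoch, acc = best_epoch(log)
print(f"best epoch {epoch}: {100 * acc:.1f}% ({round(acc * GSM8K_TEST_SIZE)} correct), "
      f"{35.0 - 100 * acc:.1f} points below 35.0")
```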

yxli2123 commented 7 months ago

For llama-2-7b, please try --learning_rate 3e-4 and for llama-2-13b, please try --learning_rate 1e-4 as reported in the paper.

BaohaoLiao commented 7 months ago

However, according to Table 15 of https://openreview.net/pdf?id=LzPWWPAdY4, lr=1e-4; only the mixed-precision setting uses 3e-4.

In addition, how should I set --num_iter if I want to quantize the model myself? In the paper, I can only find that you set T=1 for BART. What T did you use for the reported LLAMA and DeBERTa results?
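For context on what T controls: as described in the LoftQ paper, initialization alternates T times between quantizing the residual W - AB^T and refitting A, B as the best rank-r approximation (via SVD) of W - Q. Below is a minimal NumPy sketch under two assumptions: `--num_iter` plays the role of T, and a plain 4-bit uniform absmax quantizer stands in for the NF4 quantizer the paper uses. `fake_quantize` and `loftq_init` are hypothetical names, not functions from the repository.

```python
import numpy as np

def fake_quantize(w: np.ndarray, bits: int = 4) -> np.ndarray:
    # Stand-in quantizer: symmetric uniform absmax over the whole matrix (NOT NF4).
    levels = 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / levels
    if scale == 0:
        return np.zeros_like(w)
    return np.clip(np.round(w / scale), -levels, levels) * scale

def loftq_init(w: np.ndarray, rank: int, num_iter: int, bits: int = 4):
    """Alternate T = num_iter times between quantization and a rank-r SVD fit."""
    if num_iter < 1:
        raise ValueError("num_iter must be at least 1")
    a = np.zeros((w.shape[0], rank))
    b = np.zeros((w.shape[1], rank))
    for _ in range(num_iter):
        q = fake_quantize(w - a @ b.T, bits)          # Q_t = q_N(W - A_{t-1} B_{t-1}^T)
        u, s, vt = np.linalg.svd(w - q, full_matrices=False)
        a = u[:, :rank] * np.sqrt(s[:rank])           # A_t
        b = vt[:rank].T * np.sqrt(s[:rank])           # B_t, so A_t B_t^T is the rank-r fit of W - Q_t
    return q, a, b

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 48))
for t in (1, 2, 5):
    q, a, b = loftq_init(w, rank=8, num_iter=t)
    print(t, np.linalg.norm(w - q - a @ b.T))
```

At T=1 the low-rank correction can only reduce the error relative to plain quantization, since truncated SVD gives the best rank-r fit of the residual; whether extra iterations help further depends on the quantizer and the weights.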

BaohaoLiao commented 7 months ago

I tried lr=3e-4 for LLAMA-2-7b. My script is:

accelerate launch $TOOL \
  --fake_quantization \
  --model_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64 \
  --output_dir $SAVE_DIR \
  --learning_rate 3e-4  \
  --seed 202 \
  --dataset_name gsm8k \
  --dataset_config main \
  --pad_to_max_length \
  --max_source_length 128 \
  --max_target_length 256 \
  --num_train_epochs 5 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 16 \
  --with_tracking \
  --report_to wandb 2>&1 | tee $SAVE_DIR/out

I didn't modify any other settings. However, the results are even worse than with lr=1e-4 (epoch 2: accuracy: 0.3055344958301744):

epoch 0: accuracy: 0.19484457922668688
epoch 1: accuracy: 0.2350265352539803
epoch 2: accuracy: 0.26611068991660347
epoch 3: accuracy: 0.15466262319939347
epoch 4: accuracy: 0.0075815011372251705

I also tried lr=5e-5. The results are:

epoch 0: accuracy: 0.2175890826383624
epoch 1: accuracy: 0.2539802880970432
epoch 2: accuracy: 0.3017437452615618
epoch 3: accuracy: 0.3161485974222896
epoch 4: accuracy: 0.29567854435178165
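Putting the three LLAMA-2-7b runs side by side (numbers copied from the comments above; the dictionary layout and `summarize` helper are only for illustration) shows that lr=3e-4 peaks at epoch 2 and then collapses to under 1%, while lr=5e-5 gives the best single epoch. This only summarizes the reported numbers; it does not diagnose the cause.

```python
# Per-epoch GSM8K accuracy (epochs 0-4) for LLAMA-2-7b, copied from the comments above.
runs = {
    "1e-4": [0.244882486732373, 0.2880970432145565, 0.3055344958301744,
             0.2979529946929492, 0.29492039423805916],
    "3e-4": [0.19484457922668688, 0.2350265352539803, 0.26611068991660347,
             0.15466262319939347, 0.0075815011372251705],
    "5e-5": [0.2175890826383624, 0.2539802880970432, 0.3017437452615618,
             0.3161485974222896, 0.29567854435178165],
}

def summarize(accs):
    # Index of the best epoch, its accuracy, and the final-epoch accuracy.
    best = max(range(len(accs)), key=accs.__getitem__)
    return best, accs[best], accs[-1]

for lr, accs in runs.items():
    best, best_acc, final_acc = summarize(accs)
    print(f"lr={lr}: best epoch {best} ({100 * best_acc:.1f}%), final {100 * final_acc:.1f}%")
```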

yxli2123 commented 7 months ago

Hi, thanks for pointing this out. I have updated the model on Hugging Face, and I can reproduce the results now. Make sure you are using transformers~=4.31; the latest transformers~=4.35 causes a loading error due to the HF Hub structure.
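One caveat on the version pin, based on PEP 440 rather than anything in this repository: `~=4.31` is a compatible-release specifier equivalent to `>=4.31, ==4.*`, so it also admits 4.35. To stay on the 4.31 series, `==4.31.*` (or `~=4.31.0`) is the stricter form. The pure-Python check below is a simplified sketch that only handles numeric release segments (no `rc` or `.dev` suffixes); the `packaging` library's `SpecifierSet` implements the full rules.

```python
def parse_release(version: str, length: int) -> tuple:
    # "4.35.2" -> (4, 35, 2), zero-padded to `length` so that "4.31" == "4.31.0".
    parts = [int(p) for p in version.split(".")]
    return tuple(parts + [0] * (length - len(parts)))

def compatible_release(version: str, spec: str) -> bool:
    """True if `version` satisfies `~=spec` (simplified PEP 440 semantics)."""
    n_spec = len(spec.split("."))
    if n_spec < 2:
        raise ValueError("~= requires at least two release components")
    n = max(n_spec, len(version.split(".")))
    v, s = parse_release(version, n), parse_release(spec, n)
    # Must be >= spec and share every component of spec except the last one.
    return v >= s and v[: n_spec - 1] == s[: n_spec - 1]

def prefix_match(version: str, prefix: str) -> bool:
    """True if `version` satisfies `==prefix.*`."""
    p = tuple(int(x) for x in prefix.split("."))
    return parse_release(version, len(p))[: len(p)] == p

for candidate in ("4.31.0", "4.35.2", "5.0.0"):
    print(candidate, compatible_release(candidate, "4.31"), prefix_match(candidate, "4.31"))
```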

PuNeal commented 7 months ago

@BaohaoLiao Hi, have you reproduced the GSM8K results? I ran into the same problem.

BaohaoLiao commented 7 months ago

Hi @PuNeal, I haven't tried the newly updated HF model yet.

yxli2123 commented 6 months ago

Hi, this issue should now be resolved. Please use the command in https://github.com/yxli2123/LoftQ/blob/main/scripts/train_gsm8k.sh#L22.

Let me know if there are still other issues.

BaohaoLiao commented 6 months ago

I still can't reproduce your reported results. I used your training script from https://github.com/yxli2123/peft/tree/loftq/examples/loftq_finetuning

python train_gsm8k_llama.py \
    --model_name_or_path LoftQ/Llama-2-13b-hf-4bit-64rank \
    --output_dir exp_results/gsm8k/llama-2-13b/bit4-rank64/lr1e-4 \
    --learning_rate 1e-4  \
    --weight_decay 0.1 \
    --lr_scheduler_type cosine \
    --num_warmup_steps 100 \
    --seed 202 \
    --dataset_name gsm8k \
    --dataset_config main \
    --pad_to_max_length \
    --max_source_length 128 \
    --max_target_length 256 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --with_tracking \
    --report_to tensorboard
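This command adds `--lr_scheduler_type cosine` with `--num_warmup_steps 100`. Assuming the script builds the schedule the way the `transformers` cosine-with-warmup scheduler does (linear warmup, then half a cosine down to zero), the sketch below shows the learning rate over training. The step count is also an assumption: 7,473 GSM8K training examples, one GPU, an effective batch of 4 * 4 = 16, and steps per epoch rounded up.

```python
import math

def cosine_with_warmup(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    # Linear warmup from 0 to base_lr, then a half cosine from base_lr down to 0.
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# Assumed step count: 7,473 training examples, effective batch 16, 5 epochs.
steps_per_epoch = math.ceil(7473 / 16)
total_steps = 5 * steps_per_epoch
for step in (0, 50, 100, total_steps // 2, total_steps):
    print(step, f"{cosine_with_warmup(step, 1e-4, 100, total_steps):.2e}")
```

With only 100 warmup steps out of roughly 2,340, the peak learning rate is reached within the first quarter of epoch 0.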

Here are the results:

epoch 0: accuracy: 0.26459438968915844
epoch 1: accuracy: 0.30932524639878695
epoch 2: accuracy: 0.35405610310841545
epoch 3: accuracy: 0.3434420015163002
epoch 4: accuracy: 0.3593631539044731

This is still far from the reported result of 45.0.

May I ask whether you used fake quantization for your reported results?

yxli2123 commented 6 months ago

Please pull the latest head of this repo and use the training script in it: https://github.com/yxli2123/LoftQ/blob/main/scripts/train_gsm8k.sh#L22. We are still working on merging the up-to-date training files into PEFT.

We use the quantization-equivalent weights for both training and testing. We have verified that the results with quantization-equivalent weights and with bitsandbytes 4-bit weights are almost the same on GSM8K.
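The idea behind quantization-equivalent weights can be illustrated with a toy quantizer: if the stored full-precision weights are already the dequantized values, quantizing them again yields the same codes, so a real 4-bit model sees (up to floating-point rounding) the same weights that fake quantization trained and evaluated with. The blockwise absmax quantizer below is a hypothetical stand-in, not the NF4 scheme bitsandbytes uses; in practice, differences such as compute dtype could still make results differ slightly.

```python
import numpy as np

def quantize(w, bits=4, block=64):
    # Hypothetical blockwise symmetric absmax quantizer (a stand-in for NF4).
    levels = 2 ** (bits - 1) - 1
    blocks = w.reshape(-1, block)
    scale = np.abs(blocks).max(axis=1, keepdims=True) / levels
    scale[scale == 0] = 1.0  # avoid division by zero for all-zero blocks
    codes = np.clip(np.round(blocks / scale), -levels, levels).astype(np.int8)
    return codes, scale

def dequantize(codes, scale, shape):
    return (codes * scale).reshape(shape)

rng = np.random.default_rng(0)
w = rng.standard_normal((128, 64)).astype(np.float32)

# Fake-quantized weights: full-precision storage of the dequantized values.
w_fq = dequantize(*quantize(w), w.shape)

# Quantizing the fake-quantized weights again recovers the same codes and values.
w_real = dequantize(*quantize(w_fq), w.shape)
print(np.abs(w_fq - w_real).max())
```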

BaohaoLiao commented 6 months ago

Thank you for the update. I can reproduce the 4-bit result with Llama-7b.