Outsider565 / LoRA-GA


Question about performance of gsm8k #13

Open LZY-the-boys opened 3 hours ago

LZY-the-boys commented 3 hours ago

Hi, thanks for your awesome work. I noticed some problems with GSM8K performance when reproducing your results:

First, I checked that PiSSA and LoRA-GA both train on the MetaMathQA dataset and evaluate on GSM8K. In the PiSSA paper, their GSM8K performance is 53.07:

[screenshot: results table from the PiSSA paper]

However, in the LoRA-GA paper, the GSM8K performance of PiSSA is only 44.54, while the LoRA performance in both papers is similar (42.30 vs. 42.08). I find this very strange:

[screenshot: Table 2 from the LoRA-GA paper]

Second, I ran the LoRA-GA code in LoRA-GA/examples/float_llama2-7b_metamath.py and evaluated with both vLLM and the HuggingFace generation code. The result is only 42.3 vs. 44.1, far from the 53.60 reported in the paper. I used the following script to train:

accelerate launch --main_process_port 20001 \
--config_file accelerate_config.yaml \
-m float_llama2-7b_metamath

and ran the evaluation with eval_gsm8k.py:

    from peft import PeftModel
    from utils import initialize_text_to_text_model, load_gsm8k  # assumed location of the repo's helpers

    _, _, test_set = load_gsm8k()
    model_type = "CausalLM"
    model, tokenizer = initialize_text_to_text_model(
        "meta-llama/Llama-2-7b-hf", model_type, "bf16",
        tokenizer="meta-llama/Llama-2-7b-hf", flash_attention=True,
    )
    # load the trained adapter on top of the base model
    model = PeftModel.from_pretrained(
        model, "LoRA-GA/examples/results/peft_test/model=llama_d=meta_math_a=8_r=32_s=128_sd=31"
    )
    model = model.to("cuda")
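
For reference, my scoring step boils down to comparing final numeric answers. Here is a minimal sketch, assuming the standard GSM8K "#### <number>" answer format; the regexes and function names are mine, not necessarily what eval_gsm8k.py does:

    import re

    def extract_answer(text):
        # GSM8K references end with "#### <number>"; MetaMathQA-tuned models
        # usually emit "The answer is: <number>" instead.
        m = re.search(r"####\s*(-?[\d,.]+)", text) or re.search(
            r"[Tt]he answer is:?\s*(-?[\d,.]+)", text
        )
        return m.group(1).replace(",", "").rstrip(".") if m else None

    def gsm8k_accuracy(predictions, references):
        # a prediction counts as correct only if an answer was extracted
        # and it matches the reference's final answer exactly
        hits = sum(
            extract_answer(p) is not None and extract_answer(p) == extract_answer(r)
            for p, r in zip(predictions, references)
        )
        return hits / len(references)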

So is the released checkpoint not the optimized version? Could you share the original hyperparameters? Thanks.

Outsider565 commented 2 hours ago
  1. The discrepancy between PiSSA's results and ours is due to different rank settings. PiSSA reports accuracies on various tasks in its Table 1 using a rank of 128 (according to Figures 6(c) and 14(b) in their paper, only rank 128 achieves such high performance), while our accuracies (in Table 2) use a rank of 8 for all LoRA variants. Comparing at rank 8, the accuracy for PiSSA that we implemented on GSM8K is 44.54 (Table 2), which surpasses the reported accuracy of less than 40 in PiSSA's paper (Figures 6(c) and 14(b)). Remarkably, comparing at rank 128, LoRA-GA still outperforms PiSSA on GSM8K (55.07 vs. 53.07) and HumanEval (23.05 vs. 21.95). Furthermore, even LoRA-GA at rank 8 surpasses PiSSA at rank 128 on GSM8K (53.60 vs. 53.07). The rank difference amounts to one adapter-config argument; see the sketch after this list.
  2. You can find the hyperparameters in the appendix of the paper. (You could also try the legacy code in the reproduce folder, which should exactly reproduce the results in the paper; however, numerical issues have been reported on some hardware.) If there's any problem, feel free to contact me again.
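
For concreteness on point 1: the rank comparison is just the `r` (and correspondingly scaled `lora_alpha`) argument of the adapter config. A minimal sketch with peft's LoraConfig; the target modules and alpha values below are illustrative assumptions, not the exact training setup:

    from peft import LoraConfig

    # rank-8 setting (our Table 2); alpha and target modules here are assumptions
    rank8 = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )

    # rank-128 setting (PiSSA's Table 1 numbers); everything else held fixed
    rank128 = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )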