OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.
MIT License

reproduce evaluation results #60

Open oujieww opened 5 months ago

oujieww commented 5 months ago

I used the LLaMA weights from "huggyllama/llama-7b" and tried to reproduce the "LLaMA-7B W4A4" result with script/llama/llama-7b/w4a4.sh:

```shell
CUDA_VISIBLE_DEVICES=0 python main.py \
    --model ${MODEL_PATH} --eval_ppl \
    --epochs 20 --output_dir ./log/llama-7b-w4a4 \
    --wbits 4 --abits 4 --lwc --let --aug_loss
```

But I got WikiText2: 11.583588600158691 and C4: 14.935160636901855, while your paper (Table A.4) reports WikiText2: 11.23 (epoch-40), 11.26 (epoch-20) and C4: 14.61.

I want to check whether there are any details I have missed.

I think that if we use the same seed, we should get the same result?
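For what it's worth, exact reproducibility in PyTorch usually needs more than a single seed. Here is a minimal sketch of the RNG setup I use (standard PyTorch calls, not OmniQuant-specific; even with all of this, some CUDA kernels are still non-deterministic):

```python
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Seed every RNG a training run may touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    # Disable cuDNN autotuning, which can pick different kernels per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(2)
```

Two runs with the same seed on the same hardware and library versions should then match; across different GPUs or PyTorch/cuDNN versions, small numeric differences can remain.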

ChenMnZ commented 5 months ago

This discrepancy is caused by a bug that has since been fixed. Please refer to https://github.com/OpenGVLab/OmniQuant/issues/12 and https://github.com/OpenGVLab/OmniQuant/pull/36.

oujieww commented 5 months ago

Thank you. So I should follow this? OmniQuant/quantize/omniquant.py

Lines 203 to 205 in 834847a

```python
if args.epochs > 0:
    with torch.no_grad():
        qlayer.float()  # required for AMP training
```

to

```python
with torch.no_grad():
    qlayer.float()  # required for AMP training
if args.epochs > 0:
```

oujieww commented 5 months ago

@ChenMnZ Hi, I have changed the code as mentioned in https://github.com/OpenGVLab/OmniQuant/issues/12, i.e.:

```python
with torch.no_grad():
    qlayer.float()  # required for AMP training
if args.epochs > 0:
```

1. Using the parameters saved from my own training run: WikiText2: 11.583588600158691; paper: 11.23 (epoch-40), 11.26 (epoch-20).

2. Using the parameters from your Hugging Face checkpoint "llama-7b-w4a4.pth": WikiText2: 11.614770889282227.

So, is my base model (huggyllama/llama-7b) wrong, or is my code still not right?

Here is my test shell script:

```shell
MODEL_PATH=/mnt/data/oujie/huggingface_cache/hub/models--huggyllama--llama-7b/snapshots/huggyllama_llama-7b
CUDA_VISIBLE_DEVICES=0 python main.py \
    --model ${MODEL_PATH} --eval_ppl \
    --epochs 0 --output_dir ./log/llama-7b-w4a4 \
    --wbits 4 --abits 4 --lwc --let --aug_loss \
    --resume ./pretrained/llama-7b-w4a4.pth
```
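For context, both of my runs land within a few percent of the paper's epoch-20 number. A quick check of the gaps (all numbers copied from this thread):

```python
# Gap between reproduced and reported WikiText2 perplexity (LLaMA-7B W4A4).
paper_ppl = 11.26  # paper, epoch-20 (Table A.4)
results = {
    "retrained (my run)": 11.583588600158691,
    "resumed (released .pth)": 11.614770889282227,
}

for name, ppl in results.items():
    gap = ppl - paper_ppl
    print(f"{name}: ppl={ppl:.4f}, gap={gap:+.4f} ({100 * gap / paper_ppl:+.2f}%)")
```

So the absolute gap is about 0.32 to 0.35 perplexity, roughly 3% relative.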

ChenMnZ commented 5 months ago

Please refer to https://github.com/OpenGVLab/OmniQuant/blob/main/scripts/Llama-2/Llama-2-7b/w4a4.sh for the training script; you can also try setting --alpha to 0.75 to improve the performance.

ChenMnZ commented 5 months ago


Just give it a try with the latest code; I think you can reproduce the results.

oujieww commented 5 months ago

Thank you. To "use the parameters from your Hugging Face llama-7b-w4a4.pth", do you mean:

```shell
MODEL_PATH=/mnt/data/oujie/huggingface_cache/hub/models--huggyllama--llama-7b/snapshots/huggyllama_llama-7b
CUDA_VISIBLE_DEVICES=0 python main.py \
    --model ${MODEL_PATH} --eval_ppl \
    --epochs 0 --output_dir ./log/llama-7b-w4a4 \
    --wbits 4 --abits 4 --lwc --let --aug_loss \
    --alpha 0.75 --resume ./pretrained/llama-7b-w4a4.pth
```

oujieww commented 5 months ago

Yes, thank you so much! I have downloaded the latest code for reproducing, and I will try LLaMA-2 following your https://github.com/OpenGVLab/OmniQuant/blob/main/scripts/Llama-2/Llama-2-7b/w4a4.sh. But I want to reproduce both LLaMA-1-7B and LLaMA-2-7B as in your paper, and since there is still a gap in the results, I want to make sure I am not missing any details.

ChenMnZ commented 5 months ago

The parameters on my Hugging Face have a problem (as mentioned in https://github.com/OpenGVLab/OmniQuant/issues/12#issuecomment-1798066350 and solved in https://github.com/OpenGVLab/OmniQuant/pull/36), so resuming from them cannot reproduce the results.

So, to reproduce the results, you can train the model yourself with the scripts in https://github.com/OpenGVLab/OmniQuant/tree/main/scripts.

And thank you for reporting this error; I will also re-train these models and update the Hugging Face checkpoints for easier assessment.


oujieww commented 5 months ago

@ChenMnZ Thank you. I retrained following your https://github.com/OpenGVLab/OmniQuant/tree/main/scripts/llama/llama-7b/w4a4.sh.

I changed the code as in #12 and checked that it matches #36.

I get WikiText2: 11.583588600158691 (epoch-20), while the paper reports 11.23 (epoch-40), 11.26 (epoch-20).