OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.
MIT License

aug_loss option in OmniQuant Scripts #12

Closed MarsJacobs closed 11 months ago

MarsJacobs commented 11 months ago

Hi! Thanks for the awesome quantization work! I learned a lot. I have a question regarding the aug_loss option. I noticed that this option is present in some of the LLaMA scripts.

https://github.com/OpenGVLab/OmniQuant/blob/d797d51e7acae9804a33d8e16a3232982e433ed3/quantize/omniquant.py#L220-L221

I'm curious to know what this additional loss specifically does, and the reasons behind its inclusion. I'd also like to understand why it's only included in certain models and with specific bit precisions.

ChenMnZ commented 11 months ago

To compute the MSE loss, one must feed input data into both the quantized block and the full-precision block. The key to understanding the aug_loss mechanism is how that input data is sourced.

In the original loss, the output from the full-precision block serves as the input for the subsequent computations (refer to the code or the pseudocode in the paper's Appendix for details). In this situation, the full-precision block and the quantized block receive different inputs.

As for the augmented loss (--aug_loss), the output from the quantized model is used as the input for the full-precision block. This ensures that both the quantized and full-precision blocks receive identical inputs.
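
Here is a minimal, self-contained sketch of how the two loss terms differ. The toy fp_layer / q_layer modules and tensors below are illustrative stand-ins for one transformer block, not the actual block-wise training code in omniquant.py:

```python
import torch
import torch.nn as nn

# Toy stand-ins for one transformer block: fp_layer is the full-precision block,
# q_layer plays the role of its quantized counterpart.
fp_layer = nn.Linear(64, 64)
q_layer = nn.Linear(64, 64)

fp_input = torch.randn(2, 64)                        # full-precision input stream
quant_input = fp_input + 0.01 * torch.randn(2, 64)   # quantized input stream (has drifted)

loss_func = nn.MSELoss()

with torch.no_grad():
    target_fp = fp_layer(fp_input)      # full-precision block fed the full-precision input
    target_aug = fp_layer(quant_input)  # full-precision block fed the quantized input

quant_out = q_layer(quant_input)        # quantized block always sees the quantized input

# Original loss: the two blocks receive different inputs.
loss = loss_func(target_fp, quant_out)

aug_loss = True  # corresponds to the --aug_loss flag
if aug_loss:
    # Augmented term: both blocks now receive the identical (quantized) input.
    loss = loss + loss_func(target_aug, quant_out)
```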

In our evaluations, we observed that --aug_loss significantly enhances the performance of the LLaMA family, particularly for lower-bit (e.g., 2-bit) quantization. For higher-bit (e.g., 4-bit) quantization, however, the benefit of --aug_loss is minimal and comes at the cost of increased VRAM usage. We therefore recommend using --aug_loss only for lower-bit quantization.

I hope this answers your question.

MarsJacobs commented 11 months ago

Thank you for the detailed response. I understand now. Closing this issue.

brisker commented 10 months ago

@ChenMnZ Is there a bug in the W4A8 quantization setting? python main.py --model ./Llama-2-7b --epochs 20 --output_dir ./log/debug --eval_ppl --wbits 4 --abits 8 --let gives something weird like:

[2023-11-06 20:43:14 root](omniquant.py 141): INFO === Start quantize layer 0 ===
[2023-11-06 20:43:38 root](omniquant.py 216): INFO layer 0 iter 0 loss:6.076880708860699e-06 norm:nan max memory_allocated 15862.96337890625
[2023-11-06 20:43:58 root](omniquant.py 216): INFO layer 0 iter 1 loss:6.076880708860699e-06 norm:nan max memory_allocated 15862.96337890625
[2023-11-06 20:44:19 root](omniquant.py 216): INFO layer 0 iter 2 loss:6.076880708860699e-06 norm:nan max memory_allocated 15862.96337890625
[2023-11-06 20:44:39 root](omniquant.py 216): INFO layer 0 iter 3 loss:6.076880708860699e-06 norm:nan max memory_allocated 15862.96337890625

Without changing any code, python main.py --model ./Llama-2-7b --epochs 20 --output_dir ./log/debug --eval_ppl --wbits 4 --abits 4 --let gives reasonable logs:

[2023-11-06 20:45:47 root](omniquant.py 141): INFO === Start quantize layer 0 ===
[2023-11-06 20:46:11 root](omniquant.py 216): INFO layer 0 iter 0 loss:4.7241879656212404e-05 norm:6.737659714417532e-05 max memory_allocated 15863.83056640625
[2023-11-06 20:46:31 root](omniquant.py 216): INFO layer 0 iter 1 loss:3.508200461510569e-05 norm:4.012686986243352e-05 max memory_allocated 15863.83056640625
ChenMnZ commented 10 months ago

You should use --deactive_amp to run 8-bit quantization. Please refer to #8 for more details. Moreover, activating both --let and --lwc may boost performance.

brisker commented 10 months ago

@ChenMnZ python main.py --model /Llama-2-7b --epochs 20 --output_dir ./log/Llama2-7b-848-learning --eval_ppl --wbits 4 --abits 8 --let --lwc --deactive_amp runs for several hours and gives me the following PPL results:

[2023-11-07 16:07:20 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_wikitext2_all.cache
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 166/166 [01:08<00:00,  2.44it/s]
[2023-11-07 16:08:28 root](main.py 146): INFO wikitext2 : 5.8246283531188965
[2023-11-07 16:08:28 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_ptb_all.cache
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 49/49 [00:20<00:00,  2.44it/s]
[2023-11-07 16:08:48 root](main.py 146): INFO ptb : 34.367431640625
[2023-11-07 16:08:48 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_c4_all.cache
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [01:44<00:00,  2.44it/s]
[2023-11-07 16:10:33 root](main.py 146): INFO c4 : 7.511241912841797

But when I reload the omni_parameters.pth file generated by the training process above, like this:

python main.py --model /LLM/Llama-2-7b --epochs 0 --output_dir ./log/debug --eval_ppl --wbits 4 --abits 8 --lwc --let --resume ./log/Llama2-7b-848-learning/omni_parameters.pth

it gives me PPL results different from the above (5.82 vs 5.83 for wiki, 34.36 vs 39.34 for PTB), which is really weird (especially for the PTB dataset):

[2023-11-07 16:17:52 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_wikitext2_all.cache
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 166/166 [02:13<00:00,  1.24it/s]
[2023-11-07 16:20:06 root](main.py 146): INFO wikitext2 : 5.8349175453186035
[2023-11-07 16:20:06 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_ptb_all.cache
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 49/49 [00:39<00:00,  1.24it/s]
[2023-11-07 16:20:45 root](main.py 146): INFO ptb : 39.341251373291016
[2023-11-07 16:20:45 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_c4_all.cache
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [03:26<00:00,  1.24it/s]
[2023-11-07 16:24:12 root](main.py 146): INFO c4 : 7.514362812042236

Besides, even when I use the resume parameters from https://huggingface.co/ChenMnZ/OmniQuant/tree/main, like this:

python main.py --model /LLM/Llama-2-7b --epochs 0 --output_dir ./log/debug --eval_ppl --wbits 4 --abits 4 --lwc --let --resume ./from_paper/Llama-2-7b-w4a4.pth it gives W4A4 results similar to those in your paper, but still slightly different (14.37 vs 14.26 for wiki, 18.21 vs 18.02 for C4, and a very bad result for PTB, which is not reported in the paper):

[2023-11-07 16:30:14 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_wikitext2_all.cache
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 166/166 [02:12<00:00,  1.25it/s]
[2023-11-07 16:32:27 root](main.py 146): INFO wikitext2 : 14.372145652770996
[2023-11-07 16:32:27 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_ptb_all.cache
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 49/49 [00:39<00:00,  1.25it/s]
[2023-11-07 16:33:06 root](main.py 146): INFO ptb : 1405.301025390625
[2023-11-07 16:33:06 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_c4_all.cache
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [03:24<00:00,  1.25it/s]
[2023-11-07 16:36:31 root](main.py 146): INFO c4 : 18.21759033203125


Is there a bug in the code, or did I do something wrong to get these results?

ChenMnZ commented 10 months ago

@brisker Thanks for your report, I will check this bug.

brisker commented 9 months ago

@ChenMnZ Has the potential bug been found, or is it not a real bug?

ChenMnZ commented 9 months ago

@brisker I think this is caused by the precision transition during training. Can you retest your checkpoint after modifying https://github.com/OpenGVLab/OmniQuant/blob/834847adcee9575b89cd14ed2a3623c770743b4a/quantize/omniquant.py#L203-L205 to the following (so that qlayer.float() is applied even when --epochs is 0):

    with torch.no_grad(): 
        qlayer.float() 
    if args.epochs > 0: 
brisker commented 9 months ago

@ChenMnZ With the modified code, the result differs from both runs above; all three are different (for PTB: 34, 37, 39, three different PPL results...). python main.py --model /LLM/Llama-2-7b --epochs 0 --output_dir ./log/Llama2-7b-848-learning/resume_test_modify_code --eval_ppl --wbits 4 --abits 8 --lwc --let --resume ./log/Llama2-7b-848-learning/omni_parameters.pth

[2023-11-17 10:06:12 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_wikitext2_all.cache
100%|████████████████████████████████████████████████████████████████████████████████████████| 166/166 [02:14<00:00,  1.24it/s]
[2023-11-17 10:08:26 root](main.py 146): INFO wikitext2 : 5.829050540924072
[2023-11-17 10:08:26 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_ptb_all.cache
100%|██████████████████████████████████████████████████████████████████████████████████████████| 49/49 [00:39<00:00,  1.23it/s]
[2023-11-17 10:09:06 root](main.py 146): INFO ptb : 37.58609390258789
[2023-11-17 10:09:06 root](main.py 102): INFO load calibration from ./cache/testloader_Llama_c4_all.cache
100%|████████████████████████████████████████████████████████████████████████████████████████| 256/256 [03:27<00:00,  1.23it/s]
[2023-11-17 10:12:34 root](main.py 146): INFO c4 : 7.512151718139648
brisker commented 9 months ago

@ChenMnZ So this does seem to be a real bug?

ChenMnZ commented 9 months ago

@brisker Yes, this is a bug. However, I have been too busy recently and have to leave it for later.

brisker commented 9 months ago

@ChenMnZ By the way, you set qlayer to float32 here: https://github.com/OpenGVLab/OmniQuant/blob/main/quantize/omniquant.py#L205

Is the reason that trainable parameters have to be float32 in mixed-precision training? Or would accuracy suffer if this were also fp16?

ChenMnZ commented 9 months ago

@brisker Yes, you are right. AMP requires float32 or bfloat16 parameters. I also tried bfloat16, but I failed to make it work, so I finally chose float32.
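
For reference, here is a minimal sketch of the standard torch.cuda.amp pattern (not OmniQuant's actual training loop) showing why the trainable parameters stay in float32: autocast only casts selected ops to half precision in the forward pass, while the optimizer updates float32 master parameters.

```python
import torch

# Parameters stay in float32; only the autocast forward runs in reduced precision.
model = torch.nn.Linear(16, 16).cuda().float()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 16, device="cuda")
target = torch.randn(4, 16, device="cuda")

with torch.cuda.amp.autocast():                      # selected ops run in fp16
    loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()                        # scale to avoid fp16 gradient underflow
scaler.step(optimizer)                               # unscales, then updates the fp32 weights
scaler.update()
```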

brisker commented 9 months ago

@ChenMnZ I have made a PR, you can check it.

ChenMnZ commented 9 months ago

@brisker Thanks for your contribution!