OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.

How to quantize a LLaMA-structure model and run it with a sampling process? #11

Closed gesanqiu closed 11 months ago

gesanqiu commented 11 months ago

I trained some LLaMA-structure models and used the following command to quantize my model:

CUDA_VISIBLE_DEVICES=0 python main.py --model /workdir/hf_models/aquila-chat-7b/ --eval_ppl --epochs 20 --output_dir ./log/aquila-chat-7b-w4a16 --wbits 4 --abits 16 --lwc --net llama-7b

It failed with the following output:

[2023-09-20 05:57:45 root](omniquant.py 233): INFO layer 25 iter 13 loss:0.8996272683143616 norm:0.0016859474126249552 max memory_allocated 14942.64794921875
[2023-09-20 05:58:02 root](omniquant.py 233): INFO layer 25 iter 14 loss:0.8993780016899109 norm:0.0017346511594951153 max memory_allocated 14942.64794921875
[2023-09-20 05:58:18 root](omniquant.py 233): INFO layer 25 iter 15 loss:0.899031937122345 norm:0.001769484020769596 max memory_allocated 14942.64794921875
[2023-09-20 05:58:35 root](omniquant.py 233): INFO layer 25 iter 16 loss:0.8987679481506348 norm:0.0017901259707286954 max memory_allocated 14942.64794921875
[2023-09-20 05:58:52 root](omniquant.py 233): INFO layer 25 iter 17 loss:0.8985449075698853 norm:0.0018109744414687157 max memory_allocated 14942.64794921875
[2023-09-20 05:59:08 root](omniquant.py 233): INFO layer 25 iter 18 loss:0.8984124660491943 norm:0.0018458825070410967 max memory_allocated 14942.64794921875
[2023-09-20 05:59:25 root](omniquant.py 233): INFO layer 25 iter 19 loss:0.8984500169754028 norm:0.0018331403844058514 max memory_allocated 14942.64794921875
[2023-09-20 05:59:28 root](omniquant.py 158): INFO === Start quantize layer 26 ===
[2023-09-20 05:59:37 root](omniquant.py 223): INFO Loss is NAN, stopping training
> /workdir/OmniQuant/quantize/omniquant.py(226)omniquant()
-> loss_list.append(loss.data)

In which cases will a NaN loss be produced? BTW, this command works fine on llama-7b.

Besides PPL evaluation, I also need to do subjective/objective evaluation on some datasets. I know there is runing_falcon180b_on_single_a100_80g.ipynb, which shows how to run a quantized Falcon-180B, but it seems I have to learn AutoGPTQ first before I can figure out how to load a model's quantized weights?

I converted the quantized llama-7b-w4a16.pth to HF format with https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py. I tried replacing FalconLinear with nn.Linear, but I'm not sure if that's correct. I also wonder what the group size is: if I didn't use a group size in the quantization phase, what value should I set in QuantLinear?

Are there any examples or docs that can help me run a LLaMA model quickly?

ChenMnZ commented 11 months ago

If you want to run your own real-quantized model, you should first save it with --save_dir and --real_quant. For example:

CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/LLaMA/llama-7b  \
--epochs 0 --output_dir temp \
--eval_ppl --wbits 3 --abits 16 --group_size 128 --lwc \
--resume PATH/TO/OMNI_PARAMETERS \
--real_quant --save_dir PATH/TO/SAVE/MODEL

Then, you can run the quantized model by replacing the following variables in runing_falcon180b_on_single_a100_80g.ipynb with the corresponding values:

#model_path = './pre_quantized_models/falcon-180b-omniquant-w3a16g512'
#wbits = 3
#group_size = 512
model_path = PATH/TO/SAVE/MODEL
wbits = 3
group_size = 128

You can set group_size to -1 to disable group quantization. If you want to test on more datasets, you can use --tasks. For example:

CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/LLaMA/llama-7b  \
--epochs 0 --output_dir temp \
--eval_ppl --wbits 3 --abits 16 --group_size 128 --lwc \
--tasks piqa,arc_easy,arc_challenge,boolq,hellaswag,winogrande

As for the NaN, it may be caused by mixed-precision training; you can use --deactive_amp to switch to full-precision training.
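
Regarding the sampling process mentioned in the title: once the notebook has loaded the quantized checkpoint into model and built the tokenizer, generation goes through the standard transformers generate API. A minimal sketch (the prompt and sampling parameters below are only placeholders):

import torch

# Assumes `model` and `tokenizer` were built by the notebook's loading cells.
prompt = "Tell me about large language model quantization."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        do_sample=True,        # sampling instead of greedy decoding
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=128,
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))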

gesanqiu commented 11 months ago

Thanks for your detailed reply. I found that the NaN was caused by my own model, so I fell back to the official llama-2-7b to explore OmniQuant.

I tried to run a quantized llama-2-7b-chat-w4a16g128 model with runing_falcon180b_on_single_a100_80g.ipynb, but it failed with the following logs:

Traceback (most recent call last):
  File "/workdir/test/test_omniquant.py", line 35, in <module>
    layers = model.transformer.h
  File "/root/anaconda3/envs/omniquant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'LlamaForCausalLM' object has no attribute 'transformer'

Besides this problem, I think LlamaForCausalLM also doesn't contain any FalconLinear instances, so it may also fail in get_named_linears(layer)?

gesanqiu commented 11 months ago

I replaced the code as follows and it worked:

from torch import nn

def get_named_linears(module):
    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

# layers = model.transformer.h
layers = model.model.layers
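
For reference, here is a minimal sketch of how the adapted pieces might fit together for a LLaMA-style checkpoint saved with --real_quant. This follows the structure of the Falcon notebook but is my own assumption rather than code from the repo; in particular, I assume AutoGPTQ's CUDA QuantLinear with constructor (bits, group_size, infeatures, outfeatures, bias), and model_path, wbits, and group_size must match your own run (check the notebook for the exact imports it uses):

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights
from accelerate.utils import load_checkpoint_in_model
from auto_gptq.nn_modules.qlinear.qlinear_cuda import QuantLinear  # assumed import path

model_path = "PATH/TO/SAVE/MODEL"   # the --save_dir used with --real_quant
wbits, group_size = 4, 128          # must match the quantization settings

def get_named_linears(module):
    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

def set_op_by_name(layer, name, new_module):
    # Walk the dotted attribute path (e.g. "self_attn.q_proj") and swap in the quantized module.
    *parents, leaf = name.split(".")
    parent = layer
    for p in parents:
        parent = getattr(parent, p)
    setattr(parent, leaf, new_module)

config = AutoConfig.from_pretrained(model_path)
# If the tokenizer was not saved alongside the quantized weights, load it from the original FP16 model directory.
tokenizer = AutoTokenizer.from_pretrained(model_path)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# LLaMA keeps its decoder blocks under model.model.layers (Falcon uses model.transformer.h).
for layer in model.model.layers:
    for name, module in get_named_linears(layer).items():
        q_linear = QuantLinear(wbits, group_size, module.in_features,
                               module.out_features, module.bias is not None)
        set_op_by_name(layer, name, q_linear)

model.tie_weights()
load_checkpoint_in_model(model, checkpoint=model_path, device_map={"": 0})
model.eval()

After this, sampling works as in the earlier sketch, i.e. model.generate(..., do_sample=True, ...).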