
EXL2 for all

EXL2 is a mixed-bits quantization method proposed in exllamav2. This repo is derived from exllamav2, with support added for more model architectures. Unlike AutoAWQ and AutoGPTQ, which include various kernel fusions, this repo contains only the minimal code needed for quantization and inference.

Installation

The exllamav2 kernels have to be installed first. See requirements.txt for the remaining dependencies.
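
For example, installing the exllamav2 package (which ships the kernels) and then the remaining requirements should work; the exact pinned versions live in requirements.txt:

pip install exllamav2
pip install -r requirements.txt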

Examples

exllamav2 changed its optimization algorithm in v0.0.11. This repo uses the new algorithm by default; if you want to use the old one, pass version="v1" to Exl2Quantizer.
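
For example, to select the old algorithm (other arguments as in the quantization example below):

quant = Exl2Quantizer(bits=4.0, dataset="redpajama", version="v1")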

exllamav2 by default uses standard_cal_data, a mix of c4, code, wiki and so on, as calibration data. To be consistent with other quantization methods, this repo uses the redpajama dataset instead.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from exl2 import Exl2Quantizer

model_name = "meta-llama/Llama-2-7b-hf"
quant_dir = "llama-exl2-4bits"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# quantize to an average of 4.0 bits per weight, calibrating on redpajama
quant = Exl2Quantizer(bits=4.0, dataset="redpajama")
quant_model = quant.quantize_model(model, tokenizer)

# write the quantized weights and tokenizer files to quant_dir
quant.save(quant_model, quant_dir)
tokenizer.save_pretrained(quant_dir)

To run inference, load a quantized model with load_quantized_model, e.g. one of the pre-quantized checkpoints published on the Hugging Face Hub:

import torch
from transformers import AutoTokenizer
from model import load_quantized_model

quant_model = load_quantized_model("turboderp/Llama2-7B-exl2", revision="2.5bpw")
tokenizer = AutoTokenizer.from_pretrained("turboderp/Llama2-7B-exl2", revision="2.5bpw")
input_ids = tokenizer.encode("The capital of France is", return_tensors="pt").cuda()
output_ids = quant_model.generate(input_ids, do_sample=True)
print(tokenizer.decode(output_ids[0]))
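
Assuming load_quantized_model also accepts a local path, as most from_pretrained-style loaders do, the directory produced by the quantization example above can presumably be loaded the same way:

quant_model = load_quantized_model("llama-exl2-4bits")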

An additional parameter, modules_to_not_convert, excludes specific layers from quantization; for Mixtral, the gate layers are usually left unquantized:

quant_model = Exl2ForCausalLM.from_quantized("turboderp/Mixtral-8x7B-instruct-exl2",
                                             revision="3.0bpw",
                                             modules_to_not_convert=["gate"])

Perplexity

LLaMA-2 7B on wikitext:

| bpw  | perplexity |
|------|------------|
| FP16 | 6.23       |
| 2.5  | 10.13      |
| 3.0  | 7.25       |
| 3.5  | 6.88       |
| 4.0  | 6.40       |
| 4.5  | 6.37       |
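
For reference, below is a minimal sketch of how wikitext perplexity figures like the ones above are typically computed. It assumes the Hugging Face datasets package is available, that load_quantized_model returns a standard transformers causal LM, and uses a fixed 2048-token context with non-overlapping windows; it is not this repo's exact evaluation setup.

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from model import load_quantized_model

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
quant_model = load_quantized_model("turboderp/Llama2-7B-exl2", revision="2.5bpw")

# concatenate the wikitext test split into one long token stream
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

seq_len = 2048
input_ids = encodings.input_ids
nlls = []
# evaluate non-overlapping 2048-token windows (the tail remainder is dropped)
for begin in range(0, input_ids.size(1) - seq_len + 1, seq_len):
    chunk = input_ids[:, begin:begin + seq_len].cuda()
    with torch.no_grad():
        # labels=chunk makes the model return the mean cross-entropy over the window
        loss = quant_model(chunk, labels=chunk).loss
    nlls.append(loss.float())

# all windows have equal length, so the mean of window losses is the overall mean
print("perplexity:", torch.exp(torch.stack(nlls).mean()).item())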