OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.
MIT License

AutoGPTQ or AutoGPTQ-bugfix? #57

Open Alvant opened 6 months ago

Alvant commented 6 months ago

Some time ago, the README linked to a "fixed version" of AutoGPTQ: AutoGPTQ-bugfix. However, the current README links to the original repo: AutoGPTQ.

So, does this mean that real quantization with AutoGPTQ is fine now and the fixed repo is no longer needed?

I am asking because, for example, the fix for qlinear_triton was the following (link1, link2):

# qlinear_triton.py (excerpt from QuantLinear.pack())
# ...

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

# zeros -= 1  # This line removed in the fix
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0

# ...

However, AutoGPTQ still contains this zeros modification (link). So it seems the original AutoGPTQ might still have problems?
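
To make this concrete, here is a toy re-implementation of that packing loop in plain NumPy (my own simplified sketch with made-up zero points, not AutoGPTQ's actual code), just to show what the removed line changes in the stored qzeros:

# toy_pack_qzeros.py (illustration only)
import numpy as np

bits = 4
pack_factor = 32 // bits  # 8 zero points fit into one uint32 for 4-bit

# Made-up asymmetric zero points for eight output channels of one group.
zeros = np.array([[7, 8, 6, 9, 7, 8, 6, 9]], dtype=np.int64)

def pack_qzeros(zeros, apply_offset):
    z = zeros.copy()
    if apply_offset:
        z -= 1  # the line that AutoGPTQ-bugfix removes
    z = z.astype(np.uint32)
    qzeros = np.zeros((z.shape[0], z.shape[1] // pack_factor), dtype=np.uint32)
    for col in range(qzeros.shape[1]):
        for j in range(pack_factor):
            qzeros[:, col] |= z[:, col * pack_factor + j] << (bits * j)
    return qzeros

print(hex(pack_qzeros(zeros, apply_offset=True)[0, 0]))   # original AutoGPTQ: stores z - 1
print(hex(pack_qzeros(zeros, apply_offset=False)[0, 0]))  # AutoGPTQ-bugfix: stores z as is

So the two repos write different qzeros into the checkpoint, and whichever kernel later unpacks them has to agree on the convention.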

Alvant commented 6 months ago

Actually, I also tried quantizing some models with AutoGPTQ from https://github.com/AutoGPTQ/AutoGPTQ, and the quality seemed worse.

Alvant commented 5 months ago

Whether or not PPL is better with AutoGPTQ-bugfix, the following is worth noting: if you save a checkpoint with AutoGPTQ-bugfix, the model will not work properly with vLLM, because vLLM's GPTQ kernels seem to rely on this "zeros +- 1" trick: https://github.com/vllm-project/vllm/blob/main/csrc/quantization/gptq/q_gemm.cu#L172-L175.
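
A toy sketch of that dequantization side (my own code, not vLLM's; I am only assuming the kernel effectively uses stored_zero + 1, as the linked lines suggest):

# toy_dequant_mismatch.py (illustration only)
import numpy as np

scale = 0.05
true_zero = 8                              # asymmetric zero point chosen at quantization time
q = np.array([3, 8, 12], dtype=np.int64)   # some 4-bit quantized weights

def dequant_gptq_convention(q, scale, stored_zero):
    # Kernels following the GPTQ packing convention undo the pack-time "-1"
    # by using (stored_zero + 1) as the effective zero point.
    return scale * (q - (stored_zero + 1))

# Checkpoint packed by original AutoGPTQ: stored_zero = true_zero - 1 -> correct result.
print(dequant_gptq_convention(q, scale, stored_zero=true_zero - 1))

# Checkpoint packed by AutoGPTQ-bugfix: stored_zero = true_zero -> every weight off by one scale step.
print(dequant_gptq_convention(q, scale, stored_zero=true_zero))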

ChenMnZ commented 5 months ago

AutoGPTQ-bugfix is OK. Sorry for the previous confusion; the official AutoGPTQ repo had merged the "zeros +- 1" fix at one point. However, it was reverted due to some incompatibility; please refer to https://github.com/AutoGPTQ/AutoGPTQ/pull/354 for more details.

Alvant commented 4 months ago

@ChenMnZ Ok! Thank you.

Well, the picture has become clearer, but not entirely :sweat_smile: Which is better: the fixed version or the original? And why was such a fix needed in the first place?

As far as I understand, the issue is the following: AutoGPTQ assumes that quantization is symmetric. However, that may not be the case (OmniQuant uses asymmetric quantization by default). What is more, there is no way to tell AutoGPTQ's QuantLinear that the quantization is not symmetric.

All the public GPTQ-packed real-quantized models I have come across are symmetrically quantized (for example, Llama-2-13B by TheBloke).

Personally, I tested OmniQuant on the Phi-2 model. Symmetric quantization resulted in a good-quality GPTQ model, whereas asymmetric quantization led to a broken real-quantized GPTQ model.

So it seems that this "AutoGPTQ or AutoGPTQ-bugfix" question may really be a "symmetric or asymmetric quantization" question. At least, it seems the real-quant (GPTQ) backend may not always be compatible with the way a model is actually quantized.
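
To illustrate my guess about why the off-by-one hurts asymmetric quantization specifically (this is only my reading, and a toy sketch rather than any repo's actual code): with symmetric quantization the zero point is the fixed mid-point 2**(bits - 1), which survives the store-as-(z - 1), restore-as-(stored + 1) round trip, while asymmetric zero points are data-dependent and can be 0, which does not survive it:

# toy_zero_point_round_trip.py (illustration only)
import numpy as np

bits = 4
qmax = 2 ** bits - 1

def asymmetric_params(w):
    # Data-dependent scale and zero point; the zero point can be anywhere in [0, 15], including 0.
    scale = (w.max() - w.min()) / qmax
    zero = int(np.round(-w.min() / scale))
    return scale, zero

def zero_round_trip(zero):
    stored = (zero - 1) & qmax       # pack-time "zeros -= 1"; only the low 4 bits survive packing
    return (stored & qmax) + 1       # unpack the low 4 bits, then add 1 (no second mask)

w = np.array([0.0, 0.3, 0.7, 1.5])       # all-positive weights => asymmetric zero point is 0
print(asymmetric_params(w))              # (0.1, 0)
print(zero_round_trip(0))                # -> 16 instead of 0: this group dequantizes completely wrong
print(zero_round_trip(2 ** (bits - 1)))  # -> 8: the symmetric mid-point zero is unaffected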

P.S. Sorry for my late response :sweat_smile:

lqzzy commented 4 months ago

Whether or not PPL is better with AutoGPTQ-bugfix, the following is worth noting: if you save a checkpoint with AutoGPTQ-bugfix, the model will not work properly with vLLM, because vLLM's GPTQ kernels seem to rely on this "zeros +- 1" trick: https://github.com/vllm-project/vllm/blob/main/csrc/quantization/gptq/q_gemm.cu#L172-L175.

Could you please tell me how to make the model produced by running ./scripts/Llama-2/Llama-2-7b/w4a4.sh work properly with vLLM? I want to accelerate inference by combining W4A4 with vLLM.

lqzzy commented 4 months ago

@Alvant @ChenMnZ

Alvant commented 3 months ago

@lqzzy Hello! I am afraid I can't help you with this :sweat_smile: Personally, I have used vLLM only with W4A16-quantized models (and there were no problems with those; vLLM handles them). As far as I understand, GPTQ does not quantize activations at all. For example, the OmniQuant code also offers no way to obtain a GPTQ model with real-quantized weights and activations (OmniQuant uses GPTQ for real quantization).

GPTQ reduces model size, while vLLM speeds up inference (it accelerates even FP-precision models). So maybe there is no need to quantize activations at all if you use vLLM? However, if you really want to use OmniQuant W4A4 models, then as far as I understand there is no easy way to make them compatible with vLLM...
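
For completeness, this is roughly how I load W4A16 GPTQ-packed checkpoints in vLLM (a minimal sketch; the model name is just an example, and newer vLLM versions usually pick up the quantization method from the checkpoint config automatically):

# load_gptq_with_vllm.py (sketch)
from vllm import LLM, SamplingParams

# Weights are 4-bit GPTQ-packed, activations stay in 16-bit,
# so vLLM's standard GPTQ support is enough here.
llm = LLM(model="TheBloke/Llama-2-13B-GPTQ", quantization="gptq")

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)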