OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.
MIT License

Slow decoding compared to AWQ #19

abhinavkulkarni closed this issue 10 months ago

abhinavkulkarni commented 11 months ago

Hey @ChenMnZ,

Thanks for the great work. I was trying out the AWQ (W4A16) and OmniQuant (W4A4) versions of meta-llama/Llama-2-7b-chat-hf and noticed that OmniQuant is much slower than AWQ. I used the following snippet of code to benchmark:

import time

# enc is the HF tokenizer for the model; model is the quantized model.
model.eval()
model = model.cuda()

prompt = "Give me a list of the top 10 dive sites you would recommend around the world. \nThe list is:"
input_ids = enc(prompt, return_tensors='pt').input_ids.cuda()

start_time = time.time()
output = model.generate(inputs=input_ids, do_sample=True, top_k=10, max_new_tokens=128)
end_time = time.time()

speed = len(output[0]) / (end_time - start_time)  # tokens (prompt + generated) per second
print(enc.decode(output[0]))
print(f"speed: {speed:.2f} token/s")

AWQ was able to generate around 36.13 token/s whereas OmniQuant could only generate 6.15 token/s. I ran these on RTX 3060 (12GiB of VRAM).

For AWQ, I ran the model abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq; instructions for running the model are in the model card.

For OmniQuant, I downloaded Llama-2-7b-w4a4.pth and followed the Falcon notebook (slightly modified for Llama 2).

Thanks!

ChenMnZ commented 11 months ago

Yeah, I'm aware of that. The difference comes from the CUDA kernels: the AutoGPTQ kernel is slower than the AWQ kernel. However, we chose the GPTQ kernel because it is compatible with various quantization bit-widths and group sizes, while the AWQ kernel only supports W4A16g128 quantization.

Nevertheless, OmniQuant doesn't introduce any additional parameters or operations into the quantized model, so you can also plug the AWQ kernel into OmniQuant to obtain some speedup for W4A16g128 quantization. If you want to do that, simply replace

q_linear = qlinear_cuda.QuantLinear(wbits, group_size, module.in_features, module.out_features, module.bias is not None, kernel_switch_threshold=128)

with the AWQ linear layer.
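
For reference, a minimal sketch of that swap, assuming the WQLinear module from the llm-awq repository (the import path and constructor signature may differ between versions, so treat this as illustrative rather than definitive):

# Hypothetical sketch: swap the AutoGPTQ QuantLinear for AWQ's WQLinear.
# Assumes llm-awq's awq.quantize.qmodule.WQLinear; check the signature in
# your installed version before relying on it.
from awq.quantize.qmodule import WQLinear

q_linear = WQLinear(
    w_bit=wbits,                        # the AWQ kernel only supports 4-bit weights
    group_size=group_size,              # must be 128 for W4A16g128
    in_features=module.in_features,
    out_features=module.out_features,
    bias=module.bias is not None,
    dev=module.weight.device,
)
# The quantized weights, scales, and zeros still have to be packed into this
# module (e.g. by the checkpoint-loading code that follows in the notebook).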

abhinavkulkarni commented 11 months ago

Hey @ChenMnZ,

Thanks for the reply. Yes, this should work for W4A16 quantization.

I'm only interested in 4-bit quantization, so will this also work for W4A4 quants?

Thanks!

ChenMnZ commented 11 months ago

No, our study only tests the speedup of weight-only quantization. Weight-activation quantization such as W4A4 lacks out-of-the-box hardware support, so our study does not include a hardware implementation for it.

abhinavkulkarni commented 11 months ago

Thanks, so you mean someone would need to write optimized CUDA kernels for W4A4 quantization?

I'm particularly interested in activation quantization since, for large-context models (such as codellama/CodeLlama-7b-Instruct-hf with a 16k context size, extendable up to 100k tokens), activations account for the majority of memory usage with longer input prompts.
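
For a rough sense of scale, here is a back-of-the-envelope estimate of the fp16 KV-cache size at 16k context, assuming standard Llama-7B architecture numbers (32 layers, hidden size 4096, batch size 1, no grouped-query attention):

# Rough KV-cache size estimate for a Llama-7B-style model at long context.
# Assumed config: 32 layers, hidden size 4096, fp16 (2 bytes per element).
n_layers, hidden_size, bytes_per_elem = 32, 4096, 2
seq_len = 16_384  # e.g. CodeLlama's 16k context

# K and V each store (seq_len x hidden_size) per layer.
kv_cache_bytes = 2 * n_layers * seq_len * hidden_size * bytes_per_elem
print(f"{kv_cache_bytes / 2**30:.1f} GiB")  # ~8.0 GiB at 16k tokens

That is already larger than the roughly 3.5 GB that the 4-bit weights of a 7B model occupy, which is why quantizing activations / the KV cache matters for long prompts.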

With CodeLlama models being increasingly used for a variety of code-intelligence projects, there is a need to pass really long prompts to these models.

ChenMnZ commented 11 months ago

Yeah, W4A4 requires additional kernel-design effort.

The point you mention is very interesting. Achieving real speedup for W4A4 quantization, or simply compressing the KV cache, is on our future roadmap.

brisker commented 10 months ago

@ChenMnZ why is all activation quantization INT16, not INT8, in the OmniQuant experiments? W8A8 and W4A8 settings are very common in other LLM-quantization papers (like SmoothQuant), but not in OmniQuant.

ChenMnZ commented 10 months ago

@brisker In our paper we also run experiments on W4A4, and our code also supports W4A8 quantization. However, these use fake quantization and cannot obtain actual speedup due to the lack of out-of-the-box hardware support.
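
To illustrate what "fake" (simulated) quantization means here, below is a minimal sketch of per-token symmetric quant-dequant of activations; it is not necessarily the exact scheme in the OmniQuant code, just the general idea. The tensor stays in floating point, so accuracy can be measured but no low-bit kernel is exercised and there is no speedup.

import torch

def fake_quantize_activation(x: torch.Tensor, n_bits: int = 4) -> torch.Tensor:
    # Per-token symmetric fake quantization: quantize, then immediately
    # dequantize, keeping the tensor in floating point throughout.
    qmax = 2 ** (n_bits - 1) - 1
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / qmax
    x_q = torch.round(x / scale).clamp(-qmax - 1, qmax)
    return x_q * scale  # dequantized back to float

x = torch.randn(1, 8, 4096)           # (batch, tokens, hidden)
x_dq = fake_quantize_activation(x, n_bits=4)
print((x - x_dq).abs().mean())        # quantization error; output is still a float tensor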