jy-yuan / KIVI

KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache
https://arxiv.org/abs/2402.02750
MIT License

run example.py with llama2-7B-hf: only saves 500MB of KV cache memory compared to base transformers? #17

Open riou-chen opened 3 months ago

riou-chen commented 3 months ago

I ran example.py with llama2-7B-hf, setting the input length to 4096 tokens and the output length to 100 tokens, with config.k_bits = 2 and config.v_bits = 2. The KV cache occupies 5.6GB of memory, which is only about 500MB less than base transformers. If k_bits and v_bits are 2, shouldn't the KV cache occupy less than 1GB? Why doesn't it? The inference speed also did not improve.

zirui-ray-liu commented 3 months ago

Thank you for your interest! Below I provide some analysis. Hope it helps:

First, under 4096 context length, for llama-2-7b,

KV cache size = BS (1, I assume) × context_len (4096 + 100) × 2 (K & V) × num_layers (32) × num_heads (32) × head_dim (128) × num_bytes (2 for fp16) / 1024^3 = 2.05 GB. Did you set a larger batch size to get 5.6 GB of memory?
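For reference, here is a minimal sketch of that arithmetic in Python (the shapes are the standard llama-2-7b ones; adjust batch_size and context_len to your run):

# Back-of-envelope fp16 KV cache size for llama-2-7b (same shapes as above).
batch_size  = 1
context_len = 4096 + 100   # prompt + generated tokens
num_layers  = 32
num_heads   = 32
head_dim    = 128
bytes_fp16  = 2

kv_bytes = (batch_size * context_len * 2            # K and V
            * num_layers * num_heads * head_dim * bytes_fp16)
print(f"fp16 KV cache: {kv_bytes / 1024**3:.2f} GB")   # -> 2.05 GB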

Second, example.py is the 5-shot GSM8K example. The input length is determined by the number of tokens in the few-shot prompt, so the actual length cannot be easily controlled. Did you double-check the input length fed to the model?
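If useful, a quick way to check is to tokenize the exact prompt text fed to the model (a hypothetical snippet assuming the Hugging Face tokenizer for llama-2-7b-hf):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
prompt = "..."  # replace with the actual few-shot GSM8K prompt built in example.py
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print("input length in tokens:", input_ids.shape[-1])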

Third, what are your group size and residual length? If the defaults are used, then under this 4096-length setting the compression ratio is around 5X~6X.
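As a rough back-of-envelope (not the exact bookkeeping in the kernels, and assuming group_size = 32; check your own config), the steady-state per-token storage looks like this:

# Rough estimate only: 2-bit packed values plus one fp16 scale and one fp16
# zero-point per group of `group_size` elements. The full-precision residual
# window and other bookkeeping lower the overall ratio somewhat.
head_dim   = 128
group_size = 32   # assumed default; check your config
bits       = 2

fp16_bytes_per_token  = head_dim * 2                          # 256 bytes
quant_bytes_per_token = (head_dim * bits / 8                  # packed 2-bit data
                         + (head_dim / group_size) * 2 * 2)   # fp16 scale + zero per group
print(f"per-token ratio: {fp16_bytes_per_token / quant_bytes_per_token:.1f}x")   # ~5.3x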

Fourth, regarding the speedup. TL;DR: your KV cache is small, and our current implementation has not been optimized for the small-KV-cache setting. If you want a decent speedup, enlarge your batch size and sequence length. For more details, please check our reply.
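To see where the savings start to matter, you can extend the same arithmetic to larger batches (an illustration only, using the roughly 5x ratio from the previous point):

# fp16 KV cache size vs. an assumed ~5x-compressed cache as the batch grows.
# Once the fp16 cache dominates GPU memory, the quantized cache allows larger
# batches and longer sequences, which is where the throughput gains show up.
context_len, num_layers, num_heads, head_dim = 4096, 32, 32, 128
per_seq_fp16_gb = context_len * 2 * num_layers * num_heads * head_dim * 2 / 1024**3
for batch_size in (1, 8, 32, 64):
    fp16_gb = batch_size * per_seq_fp16_gb
    print(f"BS={batch_size:3d}: fp16 KV ~{fp16_gb:5.1f} GB, ~2-bit KV ~{fp16_gb / 5:5.1f} GB")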

Hope this helps! Let me know more details about your experiments.

zirui-ray-liu commented 2 months ago

@riou-chen

Thank you for your patience. We just released a new branch, develop, where we extensively optimized the codebase. I will write a new blog post about the detailed optimizations.

Since we rewrote the low-level CUDA kernels, using the new implementation requires rebuilding the CUDA extension:

git fetch origin
git checkout develop
git pull
cd quant && pip install -e .

Currently it only supports Llama models. I have tested the new implementation with Llama-7B-hf on LongBench and the accuracy looks good.

Let me know if you have any problems with it!