[2024-06-25] Support multi-GPU inference with large LLMs now! Try out PyramidKV on Llama-3-70B-Instruct!
[2024-06-10] Support PyramidKV, SnapKV, H2O, and StreamingLLM with Flash Attention v2 and SDPA attention now! If your device (e.g., V100, 3090) does not support Flash Attention v2, you can set attn_implementation=sdpa to try PyramidKV with SDPA attention!
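As a minimal sketch of the two announcements above (the checkpoint name and prompt are illustrative, and this is not the repo's entry point): load a model with `attn_implementation="sdpa"` on devices without Flash Attention v2, and use `device_map="auto"` to shard a large model across multiple GPUs.

```python
# Minimal sketch: SDPA attention for devices without Flash Attention v2,
# and automatic sharding of a large model across multiple GPUs.
# The checkpoint and prompt are illustrative, not prescribed by the repo.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-70B-Instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # use "flash_attention_2" on supported GPUs
    device_map="auto",           # shard the model across all visible GPUs
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```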
[x] Support implementations of StreamingLLM, H2O, and SnapKV
[x] Support Mistral model
[x] Support implementation of Needle in a Haystack
[x] Support KV cache compression without Flash Attention v2 (i.e., SDPA attention) for V100
[x] Support multi-GPU inference for Llama-3-70B
[ ] Introduce new functions for KV cache budget allocation (e.g., percentage-based budgets)
[ ] Support Mixtral
[ ] Support Batch Inference
[ ] Support KV cache compression at the decoding stage
The attention map of the Llama model over 3 documents is shown below.
We provide a notebook, visualization.ipynb,
to reproduce the attention visualization of each Llama-2-7b-hf model layer for a given set of 3 documents.
The attention maps for the different layers will be stored in ./attention
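For reference, per-layer attention maps like these can be extracted with Hugging Face transformers as in the minimal sketch below; this is independent of visualization.ipynb, and the concatenated 3-document prompt is a placeholder:

```python
# Minimal sketch (independent of visualization.ipynb): extract per-layer
# attention maps from Llama-2-7b-hf and store them under ./attention.
# The concatenated 3-document prompt is a placeholder.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="eager",  # eager attention is required to return weights
)

text = "Document 1: ...\nDocument 2: ...\nDocument 3: ...\nQuestion: ..."
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model(**inputs, output_attentions=True)

os.makedirs("./attention", exist_ok=True)
for layer_idx, attn in enumerate(out.attentions):
    # attn has shape (batch, num_heads, seq_len, seq_len); average over heads.
    torch.save(attn.mean(dim=1).cpu(), f"./attention/layer_{layer_idx}.pt")
```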
transformers >= 4.41
flash-attn >= 2.4.0.post1
git clone https://github.com/Zefan-Cai/PyramidKV.git
cd PyramidKV
pip install -r requirements.txt
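As a quick post-install sanity check, the snippet below (a convenience snippet, not part of the repo) verifies the pinned dependencies are importable:

```python
# Quick post-install sanity check (a convenience snippet, not part of the repo).
import transformers
print("transformers", transformers.__version__)  # expect >= 4.41

try:
    import flash_attn
    print("flash-attn", flash_attn.__version__)  # expect >= 2.4.0.post1
except ImportError:
    print("flash-attn not installed; use attn_implementation='sdpa' instead")
```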
We provide inference code on LongBench
to reproduce our results.
Please refer to scripts/scripts_longBench/eval.sh
and modify the parameters according to your requirements.
Our codebase supports Flash Attention v2, SDPA attention, etc. The results reported in our paper are based on Flash Attention v2.
export CUDA_VISIBLE_DEVICES=$1
method=$2 # Support PyramidKV, SnapKV, H2O, StreamingLLM
max_capacity_prompts=64 # 128,2048 in paper
attn_implementation=$3 # Support "flash_attention_2", "sdpa", "eager".
source_path=$4
model_path=$5
save_dir=${source_path}"results_long_bench" # path to result save_dir
python3 run_longbench.py \
--method ${method} \
--model_path ${model_path} \
--max_capacity_prompts ${max_capacity_prompts} \
--attn_implementation ${attn_implementation} \
--save_dir ${save_dir} \
--use_cache True
We support PyramidKV, SnapKV, StreamingLLM, and H2O. After modifying the parameters, run:
sh scripts/scripts_longBench/eval.sh
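For intuition, PyramidKV allocates the per-layer KV cache budget in a pyramid: lower layers retain more entries and higher layers fewer, while the total roughly matches a uniform max_capacity_prompts per layer. The sketch below illustrates one such schedule; it is a toy under stated assumptions, not the repo's exact allocation formula:

```python
# Toy pyramidal budget schedule: lower layers keep more KV entries, higher
# layers fewer, with the total roughly matching a uniform allocation of
# max_capacity_prompts per layer. An illustration of the idea only, not
# the repo's exact allocation formula.
def pyramid_budgets(num_layers: int, max_capacity_prompt: int,
                    min_budget: int = 8) -> list[int]:
    total = num_layers * max_capacity_prompt
    weights = [num_layers - i for i in range(num_layers)]  # linearly decreasing
    scale = total / sum(weights)
    return [max(min_budget, round(w * scale)) for w in weights]

# Example: 32 layers (e.g., Llama-3-8B) with an average budget of 64 per layer.
budgets = pyramid_budgets(num_layers=32, max_capacity_prompt=64)
print(budgets[0], budgets[-1], sum(budgets))  # largest at layer 0, smallest at the top
```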
We provide inference code on Needle-in-a-Haystack
to reproduce our results.
Please refer to scripts/scripts_needle/eval.sh
and modify the parameters according to your requirements.
Our codebase supports Flash Attention v2, SDPA attention, etc. The results reported in our paper are based on Flash Attention v2.
METHOD='pyramidkv' # ['full', 'pyramidkv', 'snapkv', 'streamingllm', 'h2o']
MAX_CAPACITY_PROMPT=96 # [64, 96, 128, 256, 512, 1024, 2048, ...]
attn_implementation="flash_attention_2" # Support "flash_attention_2", "sdpa", "eager".
TAG=test
# For Llama3-8b
(
python -u run_needle_in_haystack.py --s_len 1000 --e_len 8001 \
--model_provider LLaMA3 \
--model_name /mnt/workspace/zhiyuanhu/yuliang/models/llama3-8b_raw \
--attn_implementation ${attn_implementation} \
--step 100 \
--method $METHOD \
--max_capacity_prompt $MAX_CAPACITY_PROMPT \
--model_version LlaMA3_${METHOD}_${MAX_CAPACITY_PROMPT}_${TAG}
) 2>&1 | tee results_needle/logs/LlaMA3_${METHOD}_${MAX_CAPACITY_PROMPT}_${TAG}.log
We support PyramidKV, SnapKV, StreamingLLM, and H2O. To reproduce our results, run:
bash scripts/scripts_needle/eval.sh
After inference, run
python scripts/scripts_needle/visualize.py
to draw the figure. Set FOLDER_PATH
in visualize.py
to your output path (the argument of --model_version
in eval.sh
).
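The heatmap produced by this step plots retrieval score against context length and needle depth. The sketch below only illustrates that kind of plot with a hypothetical, hand-made record format (placeholder values, not real results); visualize.py reads the repo's actual result files under FOLDER_PATH:

```python
# Illustration only: the kind of needle-in-a-haystack heatmap visualize.py
# draws (score vs. context length and needle depth). The records here are
# a hypothetical, hand-made format with placeholder values; visualize.py
# reads the repo's actual result files under FOLDER_PATH.
import numpy as np
import matplotlib.pyplot as plt

# (context_length, depth_percent, score) -- placeholder values
records = [(1000, 0, 1.0), (1000, 50, 0.9), (1000, 100, 1.0),
           (4000, 0, 0.8), (4000, 50, 0.6), (4000, 100, 0.9),
           (8000, 0, 0.7), (8000, 50, 0.4), (8000, 100, 0.8)]

lengths = sorted({r[0] for r in records})
depths = sorted({r[1] for r in records})
grid = np.zeros((len(depths), len(lengths)))
for length, depth, score in records:
    grid[depths.index(depth), lengths.index(length)] = score

plt.imshow(grid, cmap="RdYlGn", aspect="auto", vmin=0, vmax=1)
plt.xticks(range(len(lengths)), lengths)
plt.yticks(range(len(depths)), [f"{d}%" for d in depths])
plt.xlabel("Context length")
plt.ylabel("Needle depth")
plt.colorbar(label="Score")
plt.savefig("needle_heatmap_demo.png")
```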
If you find PyramidKV useful for your research and applications, please kindly cite using this BibTeX:
@misc{cai2024pyramidkv,
title={PyramidKV: Dynamic KV Cache Compression based on Pyramidal Information Funneling},
author={Zefan Cai and Yichi Zhang and Bofei Gao and Tianyu Liu and Keming Lu and Wayne Xiong and Yue Dong and Baobao Chang and Junjie Hu and Wen Xiao},
year={2024},
eprint={2406.02069},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Thanks to SnapKV (SnapKV: LLM Knows What You are Looking for Before Generation) for providing open-source code that supported the development of this project.