
PyramidKV


News

TODO:

Performance



Visualization: Inefficient Attention

The attention map of the Llama model over a 3-document input is shown below:


We provide a notebook, visualization.ipynb, to reproduce the attention-map visualization for each layer of Llama-2-7b-hf on a given 3-document input.

Attention maps for the different layers will be stored in ./attention
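For instance, a minimal way to reproduce the figures locally (assuming Jupyter is installed and the Llama-2-7b-hf weights are accessible, e.g. after a Hugging Face login) is:

jupyter notebook visualization.ipynb
ls ./attention  # per-layer attention maps are written here once the notebook has run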

Requirements

transformers >= 4.41
flash-attn >= 2.4.0.post1
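These can also be installed directly (flash-attn compiles against your local CUDA toolkit, so the build may take a while):

pip install "transformers>=4.41" "flash-attn>=2.4.0.post1"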

Installation


git clone https://github.com/Zefan-Cai/PyramidKV.git
cd PyramidKV
pip install -r requirements.txt .

Inference

We provide inference code on LongBench to reproduce our results.

Please refer to scripts/scripts_longBench/eval.sh to modify the parameters according to your requirements.

Our codebase supports FlashAttention v2, SDPA attention, etc. The results presented in our paper are based on FlashAttention v2.

export CUDA_VISIBLE_DEVICES=$1

method=$2 # Support PyramidKV, SnapKV, H2O, StreamingLLM
max_capacity_prompts=64 # 128,2048 in paper
attn_implementation=$3 # Support "flash_attention_2", "sdpa", "eager".
source_path=$4
model_path=$5
save_dir=${source_path}"results_long_bench" # path to result save_dir

python3 run_longbench.py \
    --method ${method} \
    --model_path ${model_path} \
    --max_capacity_prompts ${max_capacity_prompts} \
    --attn_implementation ${attn_implementation} \
    --save_dir ${save_dir} \
    --use_cache True

After modifying parameters, run:


sh scripts/scripts_longBench/eval.sh
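For example, a hypothetical invocation (the positional arguments map to GPU id, method, attention implementation, source path, and model path, as in the script above; the model path is a placeholder for your local checkpoint):

sh scripts/scripts_longBench/eval.sh 0 PyramidKV flash_attention_2 ./ meta-llama/Meta-Llama-3-8B-Instruct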

Needle in a Haystack

We provide inference code on Needle-in-a-Haystack to reproduce our results.

Please refer to scripts/scripts_needle/eval.sh to modify the parameters according to your requirements.

Our codebase supports FlashAttention v2, SDPA attention, etc. The results presented in our paper are based on FlashAttention v2.


METHOD='pyramidkv'       # ['full', 'pyramidkv', 'snapkv', 'streamingllm', 'h2o']
MAX_CAPACITY_PROMPT=96  # [64, 96, 128, 256, 512, 1024, 2048, ...]
attn_implementation="flash_attention_2" # Support "flash_attention_2", "sdpa", "".
TAG=test

# For Llama3-8b

(
python -u run_needle_in_haystack.py --s_len 1000 --e_len 8001\
    --model_provider LLaMA3 \
    --model_name /mnt/workspace/zhiyuanhu/yuliang/models/llama3-8b_raw \
    --attn_implementation ${attn_implementation} \
    --step 100 \
    --method $METHOD \
    --max_capacity_prompt $MAX_CAPACITY_PROMPT \
    --model_version LlaMA3_${METHOD}_${MAX_CAPACITY_PROMPT}_${TAG}
) 2>&1  | tee results_needle/logs/LlaMA3_${METHOD}_${MAX_CAPACITY_PROMPT}_${TAG}.log

To reproduce our results, run

bash scripts/scripts_needle/eval.sh

After inference, run

python scripts/scripts_needle/visualize.py

to draw the image. Before plotting, change FOLDER_PATH in visualize.py to your output path (the argument of --model_version in eval.sh).
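For example, a hypothetical edit-and-plot sequence, assuming FOLDER_PATH is a top-level assignment in visualize.py and reusing the --model_version value from the script above:

sed -i 's#^FOLDER_PATH.*#FOLDER_PATH = "LlaMA3_pyramidkv_96_test"#' scripts/scripts_needle/visualize.py  # adjust to your own run name
python scripts/scripts_needle/visualize.py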

Citation

If you find PyramidKV useful for your research and applications, please kindly cite using this BibTeX:

@misc{cai2024pyramidkv,
      title={PyramidKV: Dynamic KV Cache Compression based on Pyramidal Information Funneling}, 
      author={Zefan Cai and Yichi Zhang and Bofei Gao and Tianyu Liu and Keming Lu and Wayne Xiong and Yue Dong and Baobao Chang and Junjie Hu and Wen Xiao},
      year={2024},
      eprint={2406.02069},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Acknowledgement

Thanks to SnapKV (SnapKV: LLM Knows What You Are Looking for Before Generation) for providing open-source code that supports the development of this project.