Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache pruning methods and benchmarks using 🤗 transformers, aiming to simplify the development of new methods for researchers and developers in this field.
```bash
pip install kvpress
```
We recommend using flash attention if possible:
```bash
pip install flash-attn --no-build-isolation
```
This repository provides a set of "presses" that compress the KV cache by pruning the least important key-value pairs in each attention head. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that controls the amount of pruning. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`, which is automatically registered as a transformers pipeline under the name `"kv-press-text-generation"` when kvpress is imported. It handles chat templates and tokenization for you:
```python
from kvpress import ExpectedAttentionPress
from transformers import pipeline

device = "cuda:0"
model = "microsoft/Phi-3.5-mini-instruct"
pipe = pipeline("kv-press-text-generation", model=model, device=device, torch_dtype="auto", model_kwargs={"attn_implementation": "flash_attention_2"})

context = "A very long text you want to compress once and for all"
question = "\nA question about the compressed context"  # optional

press = ExpectedAttentionPress(compression_ratio=0.4)
answer = pipe(context, question=question, press=press)["answer"]
```
In the snippet above, the compression is only applied to the context tokens, so you can evaluate the compression for different questions. Check the Wikipedia notebook demo for a more detailed example.
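For instance, you can probe the same context with several questions while keeping the same press. This is a minimal sketch reusing `pipe`, `context`, and `press` from the snippet above; the question strings are placeholders:

```python
# Minimal sketch: evaluate the same compressed context with different questions.
# `pipe`, `context`, and `press` come from the snippet above; the questions are placeholders.
questions = [
    "\nWhat is the main topic of the text?",
    "\nWhich entities are mentioned in the text?",
]
for q in questions:
    answer = pipe(context, question=q, press=press)["answer"]
    print(q.strip(), "->", answer)
```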
> [!IMPORTANT]
> We focus on pruning during the pre-filling phase, as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long-context prompts. This would typically apply to improving prompt caching systems.

> [!NOTE]
> To use the `ObservedAttentionPress`, use `model_kwargs={"attn_implementation": "eager"}` in order to materialize the attention weights (this method is not compatible with flash attention).
We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the FAQ for more information on how presses work and how to create new ones, or check the new_press.ipynb notebook for a step-by-step guide.
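As a rough illustration of what a new press can look like, here is a sketch of a scorer-style press. The base class name and the exact `score` signature below are assumptions on our part; check the FAQ and the new_press.ipynb notebook for the actual interface:

```python
# Rough sketch of a custom press (hypothetical example: the ScorerPress base
# class and the exact `score` signature are assumptions; see the FAQ and
# new_press.ipynb for the real interface).
from dataclasses import dataclass

import torch
from kvpress import ScorerPress  # assumed base class name


@dataclass
class ValueNormPress(ScorerPress):
    """Hypothetical press: score each KV pair by the norm of its value vector."""

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        # Higher scores are treated as more important and are kept after pruning.
        return values.norm(dim=-1)
```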
All current presses are training free. We provide the following presses, associated with the following scores (a usage sketch follows the list):

- `RandomPress`: random score
- `KnormPress`: inverse norm of the key (paper)
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to H2O or TOVA)
- `SnapKVPress`: average attention weight of the last 64 queries (paper)
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see this notebook)
- `StreamingLLMPress`: keep only the first and last tokens (paper)

For a detailed list of existing KV cache compression methods, check Awesome-KV-Cache-Compression or Awesome-LLM-Compression.
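Since all presses expose the same `compression_ratio` parameter, they can be swapped freely in the pipeline call. A minimal sketch, reusing `pipe`, `context`, and `question` from the first example (the compression ratio is arbitrary):

```python
# Minimal sketch: compare several presses on the same context and question.
# `pipe`, `context`, and `question` come from the first example above.
from kvpress import KnormPress, RandomPress, SnapKVPress, StreamingLLMPress

for press_cls in [RandomPress, KnormPress, SnapKVPress, StreamingLLMPress]:
    press = press_cls(compression_ratio=0.4)
    answer = pipe(context, question=question, press=press)["answer"]
    print(press_cls.__name__, "->", answer)
```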
We provide a simple CLI to evaluate the performance of the different presses on several long-context datasets.
*Average performance on the RULER dataset with 4k context length and the Loogle Short Dependency QA task, for 3 models and 7 presses*
Please refer to the evaluation directory for more details and results.
We support KV cache quantization through the transformers `QuantizedCache` class (see the HF blog post). To use it, simply pass a cache object to your pipeline:
```python
from transformers import QuantizedCacheConfig, QuantoQuantizedCache

config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)

pipe(..., cache=cache)
```
By default, the `DynamicCache` is used (no quantization).
> [!IMPORTANT]
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto==0.2.4`, see also this issue).
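A press and a quantized cache can presumably be combined in the same pipeline call; the sketch below reuses `pipe`, `context`, `question`, and `press` from the earlier snippets and is an assumption rather than a benchmarked recipe:

```python
# Sketch (assumption): combining a press with a quantized cache in one call.
# `pipe`, `context`, `question`, and `press` come from the earlier snippets.
from transformers import QuantizedCacheConfig, QuantoQuantizedCache

config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
answer = pipe(context, question=question, press=press, cache=cache)["answer"]
```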