snu-mllab / Context-Memory

PyTorch implementation of "Compressed Context Memory for Online Language Model Interaction" (ICLR'24)
https://arxiv.org/abs/2312.03414
MIT License

Compressed Context Memory


Paper | arXiv | Project Page

Main features of our method:

  • Online compression of accumulating contexts (attention key/value pairs) into a compact memory during language model interaction.
  • Efficient inference with a reduced KV cache and attention computation.

Setup

conda create --name ccm python=3.9
conda activate ccm
pip install -r requirements.txt

Supported Models: LLaMA / LLaMA-2-chat / Mistral

[!IMPORTANT]

  • Set the directory configurations in ./path_config.py.
  • To use LLaMA, convert the LLaMA weights into the Hugging Face Transformers format by following the guideline.
  • [Update 24.02.21] We support Mistral models! To use them, please upgrade the packages: pip install transformers==4.37.2 accelerate==0.27.2
  • You can train and test models with the --model [llama-7b, llama-2-7b-chat, mistral-7b-inst] flag.

We release datasets and models via gdown (see below).

[!TIP]

  • If gdown raises errors, please download the files directly from the dataset link and model link (put the model subfolders in SAVEPATH and the dataset subfolders in DATAPATH from path_config.py; see the sketch below).
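
A minimal sketch of what ./path_config.py defines, assuming only the SAVEPATH and DATAPATH names referenced above (the paths themselves are placeholders):

```python
# path_config.py -- sketch only; SAVEPATH and DATAPATH are the names referenced in the
# tip above. The actual file may define additional variables.
SAVEPATH = "/path/to/save"  # model/adapter subfolders (downloads and training outputs)
DATAPATH = "/path/to/data"  # dataset subfolders
```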

Demo: Interactive inference with compressed memory

python download.py --type model --name [unified,pretrain]  # Download adapters
python inference.py -i -m [llama-7b,llama-2-7b-chat] --eval_name concat_recur
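
For example, to chat interactively with LLaMA-2-chat using the unified adapter (picking one option from each bracket above):

python download.py --type model --name unified
python inference.py -i -m llama-2-7b-chat --eval_name concat_recur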

Streaming setting

Dataset

Training

[!IMPORTANT]

  • Our experiments typically run on a single A100 80GB GPU within 5–24 hours. DailyDialog, which has shorter contexts, can be trained on a single RTX 3090 GPU with 24GB of memory.
  • Set up a Wandb account for logging, and replace the username with yours in the wandb.entity field of src/conf/config.yaml (see the sketch after this list).
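
A sketch of the relevant part of src/conf/config.yaml; only the wandb.entity field is taken from the note above, the remaining keys are illustrative:

```yaml
# excerpt -- replace the entity with your own Wandb username; the project name is hypothetical
wandb:
  entity: your-username
  project: context-memory
```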

Step 1 (optional): Finetuning LLaMA. We recommend first finetuning the pretrained LLaMA model on a dataset:

python run.py --train --dataset [unified,metaicl,dialog,lamp] --model llama-7b \
    --comp_type no
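
For example, finetuning LLaMA-7B on MetaICL without compression (picking one dataset from the bracket above):

python run.py --train --dataset metaicl --model llama-7b --comp_type no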

Step 2: Training a compression adapter.

python run.py --train --dataset [unified,metaicl,dialog,lamp] --model llama-7b \
    --load_path llama-7b-no \
    --attn_type [concat_recur,merge_recur] --n_tok [# <COMP> tokens]
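
For example, training a compression adapter on MetaICL with recurrent concatenation (the --n_tok value below is only illustrative):

python run.py --train --dataset metaicl --model llama-7b \
    --load_path llama-7b-no \
    --attn_type concat_recur --n_tok 2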

Evaluation

Reference

Citation

@inproceedings{kim2024compressed,
    title={Compressed Context Memory for Online Language Model Interaction},
    author={Jang-Hyun Kim and Junyoung Yeom and Sangdoo Yun and Hyun Oh Song},
    booktitle={ICLR},
    year={2024}
}