This repository contains code and data for "Learning to Compress Prompts with Gist Tokens."
This codebase has been tested with Python 3.9.16 and PyTorch 2.0.0. I recommend creating a new virtual env (e.g. with Conda), then installing torch manually from pytorch.org. pip install -r requirements.txt should take care of the remaining dependencies.
[!IMPORTANT] This codebase requires a specific version of Transformers: commit fb366b9a. Installing from requirements.txt should install the correct version. Newer Transformers releases will not work, as some function signatures have changed in modeling_llama.py (see this issue).
[!IMPORTANT] Training results are reproducible only with DeepSpeed version 0.8.3. For some (currently unknown) reason, newer DeepSpeed versions result in some performance degradation (see this issue).
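As a quick sanity check before training, the DeepSpeed pin above can be verified from Python. This is only a minimal sketch: `REQUIRED_VERSIONS` and `check_pins` are illustrative names that are not part of this repository, and the Transformers requirement is a commit rather than a release, so it cannot be checked by version string this way.

```python
from importlib import metadata

# Pin taken from the note above; the dictionary name is illustrative.
REQUIRED_VERSIONS = {"deepspeed": "0.8.3"}


def check_pins(required=REQUIRED_VERSIONS, get_version=metadata.version):
    """Return {package: (expected, found)} for every pin that is not satisfied."""
    problems = {}
    for package, expected in required.items():
        try:
            found = get_version(package)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed at all
        if found != expected:
            problems[package] = (expected, found)
    return problems


if __name__ == "__main__":
    for package, (expected, found) in check_pins().items():
        print(f"{package}: expected {expected}, found {found}")
```

An empty result means every listed pin matches the installed version.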
By default, experiment runs and model checkpoints are saved to the exp/ directory in the root directory, and cached models (downloaded from the Hugging Face Hub) and datasets are saved to .cache/. Be sure to create these directories before running for the first time.
I recommend either changing these directories in the config, or symlinking them to wherever you have plenty of space on your machine.
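One way to set this up is sketched below. The helper `prepare_dirs` and its arguments are hypothetical (nothing like it ships with the repository): it creates exp/ and .cache/ in the repository root, or, if a storage location is given, creates them there and symlinks them into the root.

```python
from pathlib import Path


def prepare_dirs(repo_root, storage_root=None, names=("exp", ".cache")):
    """Create exp/ and .cache/ under repo_root, or symlink them into storage_root."""
    repo_root = Path(repo_root)
    for name in names:
        link = repo_root / name
        if link.exists() or link.is_symlink():
            continue  # leave anything that is already there untouched
        if storage_root is None:
            link.mkdir(parents=True)
        else:
            target = Path(storage_root) / name
            target.mkdir(parents=True, exist_ok=True)
            link.symlink_to(target, target_is_directory=True)
```

For example, `prepare_dirs(".", "/path/with/space")` run from the repository root would keep large checkpoints and caches on the bigger disk; the storage path here is a placeholder.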
LLaMA-7B experiments expect a folder called llama-7b
in the root directory
with model weights and tokenizer.
If you'd like to train models (not just use them), set up a Weights & Biases account for experiment logging, and replace my username with yours in the wandb.entity field of src/conf/config.yaml.
Checkpoints for the 1 token gist models for LLaMA-7B and FLAN-T5-XXL (as well as positive and negative controls) are now available on Hugging Face:
[!NOTE] The released LLaMA-7B checkpoints are weight diffs. You must have the base LLaMA-7B weights to recover the original model. Please use the src/weight_diff.py script to recover the original model given the weight diffs above, following the instructions in the Alpaca repository (but using my script instead). Alternatively, if you use the compress.py script below and specify one of the Hugging Face diffs, the weight diff will be automatically applied for you if you supply --base_llama_path.
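Conceptually, recovering a model from a weight diff is an element-wise addition. The sketch below assumes the diff was stored as tuned weights minus base weights, as in the Stanford Alpaca scheme. It uses plain Python lists in place of tensors, and `apply_weight_diff` is an illustrative name, not the interface of src/weight_diff.py, which operates on full model checkpoints and may perform additional checks.

```python
def apply_weight_diff(base_weights, diff_weights):
    """Recover tuned weights as base + diff, parameter by parameter."""
    if base_weights.keys() != diff_weights.keys():
        raise ValueError("base and diff must contain the same parameter names")
    recovered = {}
    for name, base in base_weights.items():
        diff = diff_weights[name]
        if len(base) != len(diff):
            raise ValueError(f"shape mismatch for {name}")
        # Assumption: diff = tuned - base, so tuned = base + diff.
        recovered[name] = [b + d for b, d in zip(base, diff)]
    return recovered


# Toy parameters standing in for real model tensors.
base = {"layer.weight": [0.5, -1.0], "layer.bias": [0.0]}
diff = {"layer.weight": [0.25, 0.5], "layer.bias": [-0.125]}
print(apply_weight_diff(base, diff))
```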
To use the model and try out gist caching, use the src/compress.py
script, e.g.
python -m src.compress --model_name_or_path jayelm/llama-7b-gist-1 --base_llama_path llama-7b \
--instruction "Name the top cities in France that should not be missed. Include the best aspects of each place as well."
Here, --instruction
is the prompt to be compressed and cached, and --input
is an (optional) input you can supply that is not compressed.
compress.py
is well documented; use the --help
flag for more details and browse the code to see how gist caching is done. If you're loading a FLAN-T5-XXL checkpoint, you do not need to supply --base_llama_path
.
[!NOTE] Gist compression is currently only supported for batch_size = 1. Larger batch sizes are mostly implemented for FLAN-T5-XXL, but I haven't checked correctness as carefully. For LLaMA-7B, larger batch sizes will require modifying the rotary position embeddings to account for gist offsets.
If you'd like to retrain the gist models, the command
python -m src.train \
training.gist.num_gist_tokens=2 training.gist.condition=gist wandb.tag=yourtaghere
trains a small model (FLAN-T5-base) on the Alpaca+ training dataset with 2 gist tokens and gist masking, while logging to wandb.
Change the number of gist tokens with num_gist_tokens. condition should be set to gist, pos_control (no masking), or neg_control (masking that simply masks out the instruction entirely).
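The three conditions can be pictured as different attention masks. The sketch below is a simplified illustration, not the repository's masking code. It assumes a decoder-only (causal) layout with the instruction first, then the gist tokens, then everything else, and it assumes that neg_control hides the gist tokens as well as the instruction. The function name and arguments are hypothetical.

```python
def attention_mask(seq_len, gist_start, gist_end, condition):
    """Return mask[q][k] == True when query position q may attend to key position k.

    Positions [0, gist_start) hold the instruction, [gist_start, gist_end) the
    gist tokens, and [gist_end, seq_len) everything that follows.
    """
    if condition not in {"gist", "pos_control", "neg_control"}:
        raise ValueError(f"unknown condition: {condition}")
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            allowed = k <= q  # ordinary causal masking
            if q >= gist_end:  # query comes after the gist tokens
                if condition == "gist" and k < gist_start:
                    allowed = False  # instruction is reachable only via the gist
                elif condition == "neg_control" and k < gist_end:
                    allowed = False  # assumption: instruction and gist both hidden
            row.append(allowed)
        mask.append(row)
    return mask


for condition in ("gist", "pos_control", "neg_control"):
    print(condition)
    for row in attention_mask(5, 2, 3, condition):
        print("".join("x" if allowed else "." for allowed in row))
```

With pos_control the mask is the plain causal mask, which is why it serves as the upper-bound control.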
For debugging, you may be interested in setting the +experiment=debug
flag, which runs a small model (FLAN-T5-small) on a tiny number of samples and evaluations, just to check that the train/eval loop is working.
[!NOTE] If you're not familiar with the CLI syntax, check out Hydra.
[!NOTE] If you receive a ConfigValueError, see this issue for a workaround.
[!NOTE] For VSCode users, some example launch configurations for debugging are in .vscode/launch.json.
To train the larger models in the paper (FLAN-T5-XXL, LLaMA-7B), multi-GPU training with DeepSpeed is required. ./run_deepspeed.sh contains an example, but the basic idea is:
deepspeed --num_gpus=4 --no_local_rank --module src.train \
+model={flan-t5-xxl,llama-7b} \
deepspeed=ds_configs/stage3.json \
training.gist.num_gist_tokens=2 \
training.gist.condition=gist
This trains either flan-t5-xxl or llama-7b with the same gist configuration as the first flan-t5-base command above, using the hyperparameters in the paper. See src/conf/{flan-t5-xxl,llama-7b}.yaml for the hyperparameter configurations.
These experiments all assume a machine with 4 A100 80GB GPUs and at least 400GB of CPU RAM. Other machine configurations will require changing the batch size and/or the DeepSpeed config settings.
Take a look at other model configs in src/conf/model
. In particular there's a
llama-debug.yaml
file which trains a small randomly initialized LLaMA model
for debugging.
Be sure to set your wandb
entity name correctly in src/conf/config.yaml
, if it is not your username.
By default this logs an experiment to wandb under a group name that begins with wandb.tag (i.e. in the example above, yourtaghere); check out src/conf/config.yaml to see the full group name. Metrics are also logged to stdout, but there's a lot of other noise in stdout/stderr.
The main metrics to pay attention to in the wandb interface are {seen,unseen,human}_{rouge1,rouge2,rougeL}
, which are the ROUGE scores for the seen/unseen/human splits, respectively.
The wandb group and run names define a directory which will save model checkpoints and outputs. By default it is exp/{wandb.group}/{wandb.run}
. For example, if you run with the +experiment=debug
setting, then the wandb group is set to debug-alpaca-plus
. Saving model checkpoints is disabled in the debug config, but model outputs are nevertheless saved to exp/debug-alpaca-plus/debug-alpaca-plus-run-42
. For example, exp/debug-alpaca-plus/debug-alpaca-plus-run-42/outputs-100-validation_seen.csv
contains model outputs (greedy decode) on the seen split after 100 steps of training. These are useful for manual inspection, and also are the input for ChatGPT evaluation (see below).
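The directory and file naming described above can be reproduced with a small helper, which is handy when collecting outputs for manual inspection. `output_csv_path` is a hypothetical name. The filename pattern is inferred from the example path in this section and is an assumption about other steps and splits.

```python
from pathlib import Path


def output_csv_path(group, run, step, split, root="exp"):
    """Build exp/{wandb.group}/{wandb.run}/outputs-{step}-validation_{split}.csv."""
    return Path(root) / group / run / f"outputs-{step}-validation_{split}.csv"


path = output_csv_path("debug-alpaca-plus", "debug-alpaca-plus-run-42", 100, "seen")
print(path.as_posix())
```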
Use sbatch run_deepspeed.sh
to submit multi-GPU DeepSpeed jobs to a SLURM cluster, or just run in an interactive session.
If you are not using DeepSpeed, i.e. training smaller models (e.g. FLAN-T5-base, LLaMA-debug) on a single GPU, you can also use hydra-submitit-launcher
to conveniently transform a local run into a SLURM batch job submission. Do pip install hydra-submitit-launcher
, then specify +launcher=slurm
via CLI to send a job to slurm. Use of -m
or --multirun
as a Hydra option is required for the SLURM launcher to be picked up. Configure slurm parameters (e.g. partition, account, etc) in src/conf/launcher/slurm.yaml
.
This is particularly useful with hydra's sweep functionality. E.g. the command
python -m src.train -m +experiment=flan-t5-base wandb.tag=sweep-demo \
training.gist.condition=gist,pos_control,neg_control
submits an array of 3 jobs to slurm, sweeping across the gist conditions.
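To see why this produces three jobs, the comma-separated override can be expanded by hand. The following is a rough illustration of the idea only: it is not Hydra's implementation, it ignores quoting and Hydra's other sweep syntax, and `expand_sweep` is a made-up helper.

```python
from itertools import product


def expand_sweep(overrides):
    """Expand overrides such as 'key=a,b' into one override list per job."""
    keys, choices = [], []
    for override in overrides:
        key, _, value = override.partition("=")
        keys.append(key)
        choices.append(value.split(","))
    # One job per element of the Cartesian product of all value lists.
    return [
        [f"{key}={value}" for key, value in zip(keys, combo)]
        for combo in product(*choices)
    ]


jobs = expand_sweep([
    "wandb.tag=sweep-demo",
    "training.gist.condition=gist,pos_control,neg_control",
])
for job in jobs:
    print(" ".join(job))
```

Sweeping a second key over several values would multiply the job count accordingly.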
ROUGE results are logged automatically during training above, but ChatGPT evaluation needs to be run manually.
Obtain filepaths to the predictions from two models you'd like to compare. Outputs used for evaluation in the paper are in data/results/{FLAN-T5,LLaMA}-{gist,pos_control,neg_control}-{1,2,5,10}
, sweeping over the model, gist condition, and number of gist tokens respectively.
For example, say we want to compare the LLaMA single gist token model to its pos control. You can use the scripts/eval_chatgpt.py
script:
python scripts/eval_chatgpt.py \
--a data/results/model_outputs/LLaMA-gist-1/outputs-3000-validation_human.csv \
--a_name LLaMA-gist-1 \
--b data/results/model_outputs/LLaMA-pos_control-1/outputs-3000-validation_human.csv \
--b_name LLaMA-pos-control-1 \
--results_file my_comparison.json
You will need a valid OpenAI key. Follow the OpenAI API setup instructions.
Occasionally ChatGPT will spit out something that cannot be parsed by the JSON parser. In these cases it will log to stderr and the json for the result will have a "COULD NOT PARSE JSON" message. You can search for these messages in the results file, manually fix the responses, and change the overall scores accordingly.
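Searching for these messages can be automated. The exact layout of the results file is not described here, so the sketch below makes no assumption about it beyond it being valid JSON: it walks the whole structure and reports the location of every string containing the marker. The function names are illustrative.

```python
import json

MARKER = "COULD NOT PARSE JSON"


def find_unparsed(node, path="$"):
    """Yield a path for every string value that contains MARKER."""
    if isinstance(node, dict):
        for key, value in node.items():
            yield from find_unparsed(value, f"{path}.{key}")
    elif isinstance(node, list):
        for index, value in enumerate(node):
            yield from find_unparsed(value, f"{path}[{index}]")
    elif isinstance(node, str) and MARKER in node:
        yield path


def report(results_file):
    """Return the paths of all unparsed entries in a results file."""
    with open(results_file, encoding="utf-8") as handle:
        return list(find_unparsed(json.load(handle)))
```

For instance, `report("my_comparison.json")` would list the entries to fix by hand after the comparison above.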
The ChatGPT comparisons reported in the paper are located in data/results/chatgpt
.
The training script has benchmarking functionality which is used to obtain the benchmarking results.
Benchmarking was done without DeepSpeed on a single A100 80GB GPU, though a 40GB GPU is likely fine too. An example command for benchmarking is (also available as a VSCode launch config):
python -m src.train \
training.do_train=false \
training.do_eval=true \
training.do_benchmarking=true \
training.do_benchmarking_sanity_checks=true \
training.gist.num_gist_tokens=1 \
training.gist.condition=gist \
model.model_name_or_path=YOUR_PATH_TO_PRETRAINED_GIST_MODEL \
training.benchmarking_profiler=pytorch \
training.benchmarking_output_file=my_benchmarking.csv
Some notes here:
Run ./zero_to_fp32.py . pytorch_model.bin in the checkpoint you'd like to benchmark.
training.benchmarking_profiler sets the profiler. The paper uses the PyTorch default profiler.
do_benchmarking_sanity_checks=true activates gist caching sanity checking, where we verify that model outputs and decodes are the same with and without gist caching.
[!NOTE] For the larger models, we actually found that we would often fail gist sanity checks due to floating point errors. If you run the larger models with sanity checking on, you will find some torch assertion errors where 99% of the model states are identical, except for one value here or there.
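When a sanity check fails this way, it can help to measure how many values actually differ and whether they agree within a tolerance. The snippet below is a generic sketch on plain Python floats, not the check used in the training script. The function name and the tolerance value are assumptions chosen for illustration.

```python
import math


def mismatch_fraction(reference, candidate, rel_tol=0.0, abs_tol=0.0):
    """Fraction of positions where two equal-length sequences of floats differ."""
    if len(reference) != len(candidate):
        raise ValueError("sequences must have the same length")
    if not reference:
        return 0.0
    mismatches = sum(
        not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(reference, candidate)
    )
    return mismatches / len(reference)


# Toy values standing in for model states with and without gist caching.
without_cache = [0.5, 1.25, -2.0, 3.0]
with_cache = [0.5, 1.25, -2.0, 3.0000001]
print(mismatch_fraction(without_cache, with_cache))                # exact comparison
print(mismatch_fraction(without_cache, with_cache, abs_tol=1e-5))  # tolerant comparison
```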
[!NOTE] We did not heavily optimize the gist caching implementation, so wall clock speedups (especially CPU times) are likely small or even non-existent due to the increased Python logic for gist caching. The main point of the gist caching implementation in this paper is to show it can be done and sanity check that the attention masking during training works for such caching behavior at inference time. The biggest gains from gist caching are likely to be achieved using custom, lower-level implementations of gist caching that optimize for inference latency.
Like the other sections, the benchmarking results used in the paper are available in the data/benchmarking
folder. See data/README.md
for more details.
The Alpaca+ data is located in data/alpaca_plus
. ChatGPT evaluations, raw model outputs, and benchmarking stats used for the paper are located in data/results
and data/benchmarking
.
The codebase is licensed Apache 2.0 (see LICENSE
). The data is a mixture of
Self-Instruct (Apache 2.0) and Stanford Alpaca (CC BY-NC 4.0). By training on a
mixture of the data you inherit both licenses.
Thanks to the Stanford Alpaca team for assistance with the Alpaca data and finetuning.
If you found this work useful, please cite
@article{mu2023learning,
title={Learning to Compress Prompts with Gist Tokens},
author={Jesse Mu and Xiang Lisa Li and Noah Goodman},
year={2023},
eprint={2304.08467},
archivePrefix={arXiv},
primaryClass={cs.CL}
}