This is a Hugging Face 🤗 transformers
implementation of k-nearest-neighbor-based language models and machine translation models,
designed to be easy and useful in research, and for experimenting with new ideas in kNN-based models.
All previous kNN-LM based implementations are implemented in the fairseq
library, and they forked/duplicated the library's entire codebase to implement their modification.
These include the official kNN-LM repository https://github.com/urvashik/knnlm, the kNN-MT repository https://github.com/urvashik/knnmt, the official RetoMaton repository https://github.com/neulab/retomaton, and others.
We implement the k-nearest-neighbor language model (kNN-LM) (Khandelwal et al., ICLR'2020) and k-nearest-neighbor machine translation (kNN-MT) (Khandelwal et al., ICLR'2021). This is also
an official implementation of the RetoMaton model described in:
Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval (ICML'2022). Most importantly, we implement these models in 🤗 transformers
, without modifying the transformers
library itself.
To use this repository, all you need to do is copy its *.py
files into your project.
You can load any language model from the Hugging Face Hub (such as gpt2
) with model = AutoModelForCausalLM.from_pretrained(...)
, build a datastore or download ours (this needs to be done only once), and then:
knn_wrapper = KNNWrapper(...) # or: RetomatonWrapper(...)
knn_wrapper.break_into(model)
That's it! The model now internally uses kNN-LM or RetoMaton (see a concrete example in run_clm.py).
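To give an intuition for how a wrapper can change a model's behavior without touching the library's source, here is a minimal, self-contained sketch of the general pattern. It is not the code in knnlm.py, and the actual wrappers may work differently: `ToyModel`, `ToyWrapper`, and the fixed `retrieved_probs` are hypothetical names invented for this illustration.

```python
class ToyModel:
    """Stand-in for a language model; returns next-token probabilities."""

    def forward(self, token_ids):
        return [0.5, 0.3, 0.2]


class ToyWrapper:
    """Hypothetical wrapper that post-processes a model's output in place."""

    def __init__(self, lmbda, retrieved_probs):
        self.lmbda = lmbda
        # Fixed for the example; a real wrapper would retrieve these per step.
        self.retrieved_probs = retrieved_probs
        self.model = None
        self.original_forward = None

    def break_into(self, model):
        self.model = model
        self.original_forward = model.forward

        def patched_forward(token_ids):
            lm_probs = self.original_forward(token_ids)
            return [
                self.lmbda * r + (1 - self.lmbda) * p
                for r, p in zip(self.retrieved_probs, lm_probs)
            ]

        # The instance attribute shadows the method defined on the class.
        model.forward = patched_forward

    def break_out(self):
        # Removing the instance attribute restores the class-level method.
        del self.model.forward
        self.model = None


model = ToyModel()
wrapper = ToyWrapper(lmbda=0.25, retrieved_probs=[0.0, 1.0, 0.0])
wrapper.break_into(model)
print(model.forward([1, 2, 3]))  # interpolated probabilities
wrapper.break_out()
print(model.forward([1, 2, 3]))  # original probabilities again
```

The point of the pattern is that the caller keeps using the same model object, so existing training and evaluation scripts do not need to change.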
The files knnlm.py
and retomaton.py
are standalone and can be copied to any project. The file run_clm.py
is a modified version of this example by huggingface which shows an example of how to load and run kNN-LM and RetoMaton.
The file run_translation.py
is a modified version of this translation example by huggingface, which shows how to use our code for kNN-MT and RetoMaton.
This repository is maintained by Uri Alon. Please let us know if anything is not working as expected, and feel free to create new issues or email ualon@cs.cmu.edu with any questions. Contributions are welcome!
The k-nearest neighbor language model takes an already-trained model, performs a single forward pass over the entire training set, and creates a datastore of (key,value)
pairs, where key
is a hidden representation of the trained model after reading a training example, and value
is the token that should be predicted next.
At test time, for every predicted token, the model performs a k-nearest neighbors search in the datastore, retrieves the (key,value)
pairs that are closest to the test hidden representation, and normalizes their distances using softmax. Finally, the model interpolates the base LM's probability with the probability formed by the retrieved nearest neighbors and their normalized distances.
For more details, see the paper by Khandelwal et al., ICLR'2020.
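The retrieval and interpolation steps described above can be sketched in plain Python. This is a toy illustration under stated assumptions, not the code in knnlm.py: the two-dimensional keys, the three-token vocabulary, and the function names `knn_probs` and `interpolate` are made up, a brute-force search stands in for the FAISS index, and the interpolation weight is assumed to multiply the kNN distribution, as in the kNN-LM paper.

```python
import math

# Toy datastore of (key, value) pairs. Keys are hidden states (2-d here for
# readability); values are the ids of the tokens that followed them.
DATASTORE = [
    ((0.0, 0.0), 1),
    ((0.1, 0.0), 1),
    ((1.0, 1.0), 2),
    ((5.0, 5.0), 0),
]
VOCAB_SIZE = 3


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn_probs(query, k=3, knn_temp=1.0):
    """Distribution over the vocabulary induced by the k nearest neighbors."""
    # Brute-force search; a real datastore would use an approximate index.
    neighbors = sorted(
        (squared_l2(query, key), value) for key, value in DATASTORE
    )[:k]
    # Softmax over negative distances, scaled by the temperature.
    logits = [-dist / knn_temp for dist, _ in neighbors]
    max_logit = max(logits)
    weights = [math.exp(logit - max_logit) for logit in logits]
    total = sum(weights)
    probs = [0.0] * VOCAB_SIZE
    for weight, (_, value) in zip(weights, neighbors):
        probs[value] += weight / total  # neighbors sharing a value add up
    return probs


def interpolate(lm_probs, query, lmbda=0.25, **knn_kwargs):
    """p(w) = lmbda * p_knn(w) + (1 - lmbda) * p_lm(w)."""
    p_knn = knn_probs(query, **knn_kwargs)
    return [lmbda * pk + (1 - lmbda) * pl for pk, pl in zip(p_knn, lm_probs)]


lm_probs = [0.5, 0.3, 0.2]  # hypothetical base LM distribution
mixed = interpolate(lm_probs, query=(0.0, 0.1))
print([round(p, 3) for p in mixed])
```

In this toy setup the two closest neighbors both store token 1, so the interpolated distribution shifts probability toward token 1 relative to the base LM.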
kNN-MT is similar to kNN-LM, but it addresses seq2seq/encoder-decoder models, mainly for machine translation. For more details, see the paper by Khandelwal et al., ICLR'2021.
RetoMaton extends kNN-LM and kNN-MT, by (1) saving a pointer in every datastore entry; and (2) clustering entries according to their keys. That is, every datastore entry is now a tuple (key,value,pointer)
, and it belongs to a cluster.
These two changes create an automaton from the datastore, where states are clusters, edges are pointers (shared among examples in the same cluster), and transition weights are the normalized distances between the test representation and each key.
At test time, the model traverses the automaton and follows the edges according to the token that was predicted. This saves up to 80% of the kNN searches by following pointers instead of performing the expensive search, or reduces perplexity without saving searches.
For more details, see the paper by Alon et al., ICML'2022.
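The pointer-following idea can be illustrated with a small pure-Python sketch. It is a simplification under stated assumptions, not the code in retomaton.py: clustering and transition weights are left out, the datastore is built from a single made-up sentence, and `follow_pointers`, `step`, and `fake_search` are hypothetical helpers. The `min_knns` argument mirrors the role of the `--min-knns` flag as this README describes it.

```python
# Toy datastore built from one training sequence: entry i stores the token
# that followed position i (its value) and a pointer to entry i + 1.
TRAIN_TOKENS = ["the", "cat", "sat", "on", "the", "mat", "."]
VALUES = TRAIN_TOKENS[1:]                          # value of entry i
POINTERS = list(range(1, len(VALUES))) + [None]    # last entry has no successor


def follow_pointers(current_entries, predicted_token):
    """Advance every entry whose value matches the token that was predicted."""
    next_entries = []
    for entry in current_entries:
        pointer = POINTERS[entry]
        if VALUES[entry] == predicted_token and pointer is not None:
            next_entries.append(pointer)
    return next_entries


def step(current_entries, predicted_token, full_search, min_knns=1):
    """Return (entries for the next step, whether a kNN search was needed)."""
    candidates = follow_pointers(current_entries, predicted_token)
    if len(candidates) >= min_knns:
        return candidates, False   # enough pointers survived: search saved
    return full_search(), True     # fall back to the expensive search


def fake_search():
    # Stand-in for an index query; always "retrieves" entries 0 and 4.
    return [0, 4]


entries, searched = step([0, 4], "cat", fake_search)
print(entries, searched)
entries, searched = step(entries, "dog", fake_search)
print(entries, searched)
```

Entries 0 and 4 both follow the word "the"; once "cat" is predicted, only entry 0 remains consistent, so its pointer is followed without a search. A very large `min_knns` forces a search at every step.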
kNN-LM and RetoMaton datastores depend on the LM that was used to create them. We fine-tuned a few gpt2
-based models on the training set of Wikitext-103 (because Wikitext-103 was not included in GPT2's pretraining data):
neulab/distilgpt2-finetuned-wikitext103
neulab/gpt2-finetuned-wikitext103
neulab/gpt2-med-finetuned-wikitext103
neulab/gpt2-large-finetuned-wikitext103
All these models are available at the Hugging Face Hub and can be loaded by (for example):
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained('neulab/gpt2-finetuned-wikitext103')
model = AutoModelForCausalLM.from_pretrained('neulab/gpt2-finetuned-wikitext103')
This project is not limited to these models, and can work with any language model or seq2seq model.
We fine-tuned all language models using:
python run_clm.py --model_name_or_path <base_model_name> \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--do_train --do_eval --output_dir finetune_gpt2_wikitext103/ \
--save_total_limit 5 --per_device_train_batch_size 2
Where <base_model_name>
is, for example, gpt2
, distilgpt2
, gpt2-medium
, gpt2-large
, or gpt2-xl
.
We have not yet released finetuned machine translation models, but the code in this repository works for machine translation as well, using the run_translation.py
script.
The exact results from the RetoMaton paper can be reproduced using the code at https://github.com/neulab/retomaton (based on fairseq
).
The following results were obtained using the code in this repository:
Base LM: | distilgpt2 | gpt2 | gpt2-medium |
---|---|---|---|
base perplexity | 18.25 | 14.84 | 11.55 |
kNN-LM | 15.03 | 12.57 | 10.59 |
RetoMaton | 14.71 | 12.44 | 10.59 |
And when varying the fraction of saved searches:
These are the results from the RetoMaton paper, on a model that was trained on Wikitext-103 from scratch:
On the validation set of --dataset_name wmt16 --dataset_config_name ro-en
.
Base model: | t5-small | t5-base |
---|---|---|
base BLEU | 26.15 | 27.70 |
+ kNN-MT | 26.42 | 27.92 |
 | --knn_temp=50 --k=32 --lmbda=0.25 | --knn_temp=200 --k=512 --lmbda=0.2 |
If you perform additional experiments with our code, we would love to learn more about your results and share them here!
git clone https://github.com/neulab/knn-transformers
cd knn-transformers
Run:
pip install -r requirements.txt
This project also depends on the faiss
library. On macOS, use the Anaconda installation instead:
conda install -c conda-forge faiss-cpu
To evaluate the fine-tuned model (for example, neulab/gpt2-finetuned-wikitext103
) on the validation set (without any retrieval):
MODEL=neulab/gpt2-finetuned-wikitext103
python -u run_clm.py \
--model_name_or_path ${MODEL} \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--output_dir checkpoints/${MODEL} \
--do_eval --eval_subset validation
You can either download our preprocessed Wikitext-103 datastores, or preprocess them yourself.
To download a datastore for Wikitext-103 that we created for the finetuned gpt2
model (neulab/gpt2-finetuned-wikitext103
):
wget -P checkpoints/gpt2/ https://knn-transformers.s3.amazonaws.com/gpt2/dstore_gpt2_116988150_768_vals.npy
Similarly, we created datastores using the distilgpt2-finetuned-wikitext103
, gpt2-med-finetuned-wikitext103
and gpt2-large-finetuned-wikitext103
.
For all available datastores, see: https://knn-transformers.s3.amazonaws.com/index.html
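The datastore file name above appears to encode the number of entries and the key dimension. The sketch below parses those two numbers and estimates how much space the matching keys would occupy. The naming pattern is inferred from the single file name shown here, and the two key dtypes are assumptions used only for the estimate; `parse_dstore_name` is a hypothetical helper, not part of this repository.

```python
import re

# Assumed pattern: dstore_<model>_<num_entries>_<dimension>_vals.npy
NAME_PATTERN = re.compile(
    r"dstore_(?P<model>.+)_(?P<size>\d+)_(?P<dim>\d+)_vals\.npy$"
)


def parse_dstore_name(file_name):
    """Extract (model, number of entries, key dimension) from a file name."""
    match = NAME_PATTERN.search(file_name)
    if match is None:
        raise ValueError(f"unrecognized datastore file name: {file_name}")
    return match["model"], int(match["size"]), int(match["dim"])


model_name, size, dim = parse_dstore_name("dstore_gpt2_116988150_768_vals.npy")
print(model_name, size, dim)

# Rough footprint of the keys under two assumed dtypes.
for dtype_name, bytes_per_value in (("float16", 2), ("float32", 4)):
    gib = size * dim * bytes_per_value / 2**30
    print(f"keys as {dtype_name}: about {gib:.0f} GiB")
```

An estimate like this is mainly useful for checking free disk space before downloading or building a datastore.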
To save a datastore, run:
MODEL=neulab/gpt2-finetuned-wikitext103
python -u run_clm.py \
--model_name_or_path ${MODEL} \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--do_eval --eval_subset train \
--output_dir checkpoints/${MODEL} \
--dstore_dir checkpoints/${MODEL} \
--save_knnlm_dstore
The FAISS index requires a training phase where it learns an index for accessing the keys quickly. This step does not require a GPU.
To download an index for the finetuned gpt2
model (neulab/gpt2-finetuned-wikitext103
):
wget -P checkpoints/gpt2/ https://knn-transformers.s3.amazonaws.com/gpt2/index_gpt2_116988150_768.indexed
Similarly, we trained faiss
indexes for distilgpt2-finetuned-wikitext103
, gpt2-med-finetuned-wikitext103
and gpt2-large-finetuned-wikitext103
, see: https://knn-transformers.s3.amazonaws.com/index.html
To build the FAISS index yourself (not needed if you already downloaded ours):
MODEL=neulab/gpt2-finetuned-wikitext103
python -u run_clm.py \
--model_name_or_path ${MODEL} \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--output_dir checkpoints/${MODEL} \
--dstore_dir checkpoints/${MODEL} \
--build_index
To evaluate kNN-LM and RetoMaton on the validation set:
MODEL=neulab/gpt2-finetuned-wikitext103
python -u run_clm.py \
--model_name_or_path ${MODEL} \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--output_dir checkpoints/${MODEL} \
--do_eval --eval_subset validation \
--dstore_dir checkpoints/${MODEL} --retomaton
To use kNN-LM, use the --knn
flag instead of --retomaton
.
To encourage the RetoMaton model to perform a full kNN search more frequently and thus increase accuracy and reduce perplexity, use a larger value of --min-knns
such as 100
. Using --min-knns 9999999
makes the model perform kNN search at every step (FoSS = 0
in Figure 3 of the paper), and achieves the best results at the cost of slower speed.
Additional possible test-time tunable hyperparameters are --lmbda
(the interpolation factor between the datastore and the base LM), --k
(the number of retrieved nearest neighbors) and --knn_temp
(the softmax temperature when converting the nearest-neighbor distances into a probability distribution).
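To see what the temperature does, the following sketch applies a softmax to a few made-up neighbor distances at different temperatures. The distances and the helper name `neighbor_weights` are illustrative assumptions; only the general form (softmax over negative distances divided by the temperature) follows the description above.

```python
import math


def neighbor_weights(distances, knn_temp):
    """Softmax over negative distances divided by the temperature."""
    logits = [-d / knn_temp for d in distances]
    max_logit = max(logits)
    exps = [math.exp(logit - max_logit) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]


distances = [10.0, 20.0, 40.0]  # made-up distances to three neighbors
for knn_temp in (1.0, 50.0, 200.0):
    weights = neighbor_weights(distances, knn_temp)
    print(knn_temp, [round(w, 3) for w in weights])
```

A low temperature puts almost all of the weight on the closest neighbor, while a high temperature spreads the weight more evenly across the retrieved neighbors.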
RetoMaton can work without clusters, in which case it utilizes only its pointers. Using clustering allows it to save more nearest-neighbor searches and further reduce perplexity.
To download our processed clusters for the finetuned gpt2
model (neulab/gpt2-finetuned-wikitext103
):
wget -P checkpoints/gpt2/ https://knn-transformers.s3.amazonaws.com/gpt2/members_gpt2_116988150_768_20000000_500000.pkl
Similarly, we also provide clusters for the distilgpt2
model (neulab/distilgpt2-finetuned-wikitext103
) at https://knn-transformers.s3.amazonaws.com/index.html.
To cluster similar keys for RetoMaton yourself:
MODEL=neulab/gpt2-finetuned-wikitext103
python -u run_clm.py \
--model_name_or_path ${MODEL} \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--output_dir checkpoints/${MODEL} \
--dstore_dir checkpoints/${MODEL}/ \
--cluster_dstore --num_clusters 500000 --sample_size 20000000
Once the clustering file is saved in the directory pointed to by --dstore_dir
, it will automatically be picked up when running evaluation (as in the previous step).
Optional clustering hyperparameters are --num_clusters
(typically 1/100
or 1/200
of the datastore size) and --sample_size
(ideally as high as possible, but higher values consume more memory and take longer to run).
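As a quick worked example of the rule of thumb for the number of clusters, the snippet below computes 1/100 and 1/200 of a datastore size. The size 116988150 is taken from the gpt2 datastore file names in this README and is assumed to be the number of datastore entries.

```python
# Assumed number of entries, taken from the gpt2 datastore file names.
DSTORE_SIZE = 116_988_150

for fraction in (100, 200):
    clusters = DSTORE_SIZE // fraction
    print(f"1/{fraction} of the datastore: {clusters:,} clusters")

# The example command above uses --num_clusters 500000.
ratio = DSTORE_SIZE / 500_000
print(f"500000 clusters is about 1/{ratio:.0f} of the datastore")
```

Under this assumption the rule of thumb suggests roughly 0.58 to 1.17 million clusters, so the 500000 used in the example command is slightly below that range.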
Using our code for machine translation and kNN-MT is very similar to language modeling, using the file run_translation.py
instead of run_clm.py
, and following the example instructions from huggingface: https://github.com/huggingface/transformers/tree/main/examples/pytorch/translation.
Importantly, the --knn_temp
flag should be used and tuned for kNN-MT. As shown in the kNN-MT paper, the optimal temperature for kNN-MT can be 10
to 100
.
The --lmbda
interpolation factor is also typically larger in kNN-MT, and can be 0.4
-0.8
.
MODEL=t5-small
python -u run_translation.py \
--model_name_or_path ${MODEL} \
--dataset_name wmt16 --dataset_config_name ro-en \
--per_device_eval_batch_size=4 \
--output_dir checkpoints-translation/${MODEL} \
--source_lang en --target_lang ro \
--do_eval \
--predict_with_generate \
--source_prefix "translate English to Romanian: "
Note that the flag --source_prefix "translate English to Romanian: "
is slightly different for every model and task, and may be unneeded for other models, as detailed at https://github.com/huggingface/transformers/tree/main/examples/pytorch/translation.
Example datastores for t5-small
and t5-base
on wmt16 en-ro
are available at https://knn-transformers.s3.amazonaws.com/index.html.
MODEL=t5-small
python -u run_translation.py \
--model_name_or_path ${MODEL} \
--dataset_name wmt16 --dataset_config_name ro-en \
--per_device_train_batch_size 4 --per_device_eval_batch_size=4 \
--output_dir checkpoints-translation/${MODEL} \
--source_lang en --target_lang ro \
--dstore_size 26565876 \
--dstore_dir checkpoints-translation/${MODEL} \
--save_knnlm_dstore --do_eval --eval_subset train \
--source_prefix "translate English to Romanian: "
MODEL=t5-small
python -u run_translation.py \
--model_name_or_path ${MODEL} \
--dataset_name wmt16 --dataset_config_name ro-en \
--per_device_train_batch_size 4 --per_device_eval_batch_size=4 \
--output_dir checkpoints-translation/${MODEL} \
--source_lang en --target_lang ro \
--dstore_size 26565876 \
--dstore_dir checkpoints-translation/${MODEL} \
--build_index
MODEL=t5-small
python -u run_translation.py \
--model_name_or_path ${MODEL} \
--dataset_name wmt16 --dataset_config_name ro-en \
--per_device_eval_batch_size=4 \
--output_dir checkpoints-translation/${MODEL} \
--source_lang en --target_lang ro \
--do_eval \
--predict_with_generate \
--source_prefix "translate English to Romanian: " \
--dstore_size 26565876 \
--dstore_dir checkpoints-translation/${MODEL} \
--knn_temp 50 --k 32 --lmbda 0.25 \
--retomaton
To use kNN-MT, use the --knn
flag instead of --retomaton
.
Datastores and indexes can be downloaded from https://knn-transformers.s3.amazonaws.com/index.html
All fine-tuned models are available on Hugging Face Hub: https://huggingface.co/neulab
faiss
CPU to perform retrieval. However, we added the flag --knnlm-gpu
that allows performing retrieval much faster on the GPU.
faiss
when performing search, or reconstructing the vectors from their index in RetoMaton, without loading the huge keys.npy
file into memory.
If you use our code for research, please cite:
Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval
@inproceedings{alon2022neuro,
title={Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval},
author={Alon, Uri and Xu, Frank and He, Junxian and Sengupta, Sudipta and Roth, Dan and Neubig, Graham},
booktitle={International Conference on Machine Learning},
pages={468--485},
year={2022},
organization={PMLR}
}
This repository also implements: Generalization through Memorization: Nearest Neighbor Language Models
@inproceedings{khandelwal20generalization,
title={{Generalization through Memorization: Nearest Neighbor Language Models}},
author={Khandelwal, Urvashi and Levy, Omer and Jurafsky, Dan and Zettlemoyer, Luke and Lewis, Mike},
booktitle={International Conference on Learning Representations (ICLR)},
year={2020}
}