Binary Passage Retriever (BPR) is an efficient neural retrieval model for open-domain question answering. BPR integrates a learning-to-hash technique into Dense Passage Retriever (DPR) to represent passage embeddings as compact binary codes rather than continuous vectors. It substantially reduces the memory footprint without a loss of accuracy when evaluated on the Natural Questions and TriviaQA datasets.
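The core idea can be sketched in a few lines: each continuous passage embedding is binarized with the sign function and packed into bits, so a 768-dimensional float32 vector (3,072 bytes) shrinks to 96 bytes. A minimal illustration of this step (the function name is illustrative, not BPR's actual API):

```python
import numpy as np

def binarize(embeddings: np.ndarray) -> np.ndarray:
    """Convert continuous embeddings to packed binary codes.

    Each dimension becomes one bit (1 if >= 0, else 0), so a
    768-dim float32 vector (3,072 bytes) becomes 96 bytes.
    """
    bits = (embeddings >= 0).astype(np.uint8)  # sign function -> {0, 1}
    return np.packbits(bits, axis=-1)          # pack 8 bits per byte

embeddings = np.random.randn(2, 768).astype(np.float32)
codes = binarize(embeddings)
print(codes.shape, codes.dtype)  # (2, 96) uint8
```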
BPR was originally developed to improve the computational efficiency of the Sōseki question answering system submitted to the Systems under 6GB track in the NeurIPS 2020 EfficientQA competition. Please refer to our ACL 2021 paper for further technical details.
We conducted a follow-up experiment using the new Natural Questions training data with improved negative examples. BPR generally achieved performance similar to that of DPR with a substantially reduced index size. The results for HNSW and HNSW-SQ were obtained using the DPR code (8bacf08).
| Model | Faiss index type | Index size | Top 1 | Top 20 | Top 100 | Query time |
|---|---|---|---|---|---|---|
| New BPR | Binary flat | 2 GB | 49.0 | 80.5 | 87.0 | 85 ms |
| New BPR | Binary hash | 2 GB | 49.0 | 80.5 | 87.0 | 38 ms |
| New DPR | Flat | 61 GB | 52.5 | 81.3 | 87.3 | 457 ms |
| New DPR | HNSW | 141 GB | 52.2 | 81.0 | 86.8 | 2 ms |
| New DPR | HNSW-SQ | 36 GB | 52.0 | 80.9 | 86.9 | 2 ms |
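The index sizes follow directly from the representation. A back-of-the-envelope check, assuming DPR's Wikipedia split of 21,015,324 passages and 768-dimensional embeddings:

```python
num_passages = 21_015_324  # passages in DPR's Wikipedia split
dim = 768                  # embedding dimension (BERT-base)

flat_gb = num_passages * dim * 4 / 1024**3    # float32: 4 bytes per dim
binary_gb = num_passages * dim / 8 / 1024**3  # binary: 1 bit per dim

print(f"flat float32 index: {flat_gb:.1f} GB")   # ~60.1 GB
print(f"binary code index:  {binary_gb:.1f} GB") # ~1.9 GB
```

This roughly matches the 61 GB and 2 GB figures above; the exact sizes also include index metadata.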
The fine-tuned checkpoint and index files are provided below.
BPR can be installed using Poetry:
poetry install
The virtual environment automatically created by Poetry can be activated with:
poetry shell
Alternatively, you can install required libraries using pip:
pip install -r requirements.txt
BPR fine-tuned on the Natural Questions dataset:
BPR fine-tuned on the Natural Questions dataset with improved negative examples:
BPR fine-tuned on the TriviaQA dataset:
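The fine-tuned models can be used as follows: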
>>> import faiss
>>> from bpr import BiEncoder, FaissBinaryIndex, InMemoryPassageDB, Retriever
# Load the model from the checkpoint file
>>> biencoder = BiEncoder.load_from_checkpoint("bpr_finetuned_nq.ckpt")
>>> biencoder.eval()
>>> biencoder.freeze()
# Load Wikipedia passages into memory
>>> passage_db = InMemoryPassageDB("psgs_w100.tsv")
# Load the index
>>> base_index = faiss.read_index_binary("bpr_finetuned_nq.idx")
>>> index = FaissBinaryIndex(base_index)
# Instantiate the Retriever
>>> retriever = Retriever(index, biencoder, passage_db)
# Encode queries into embeddings
>>> query_embeddings = retriever.encode_queries(["what is the tallest mountain in the world"])
# Get top-100 results
>>> retriever.search(query_embeddings, k=100)[0][0]
Candidate(id=525407, score=93.59397888183594, passage=Passage(id=525407, title='Mount Everest', text="Mount Everest Mount Everest, known in Nepali as Sagarmatha () and in Tibetan as Chomolungma (), is Earth's highest mountain above sea level, located in the Mahalangur Himal sub-range of the Himalayas. The international border between Nepal (Province No. 1) and China (Tibet Autonomous Region) runs across its summit point. The current official elevation of , recognized by China and Nepal, was established by a 1955 Indian survey and subsequently confirmed by a Chinese survey in 1975. In 2005, China remeasured the rock height of the mountain, with a result of 8844.43 m. There followed an argument between China and"))
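The index file loaded above is a regular Faiss binary index. For reference, here is a sketch of how such an index could be built with stock Faiss calls, assuming `codes` holds the packed uint8 binary codes produced by the embedding generation step (the "Binary hash" rows in the table above presumably use `faiss.IndexBinaryHash` instead):

```python
import faiss
import numpy as np

dim = 768  # dimension in bits; Faiss stores each code as dim // 8 bytes

# Placeholder for real packed codes of shape (num_passages, dim // 8)
codes = np.random.randint(0, 256, size=(1000, dim // 8), dtype=np.uint8)

index = faiss.IndexBinaryFlat(dim)  # exhaustive Hamming-distance search
index.add(codes)
faiss.write_index_binary(index, "example.idx")
```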
The Wikipedia passage data (`psgs_w100.tsv`) is available on the DPR website. At the time of writing, the file can be downloaded by cloning the DPR repository and running the following command:
python data/download_data.py --resource data.wikipedia_split.psgs_w100
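For reference, `psgs_w100.tsv` is a tab-separated file with a header row and three columns (`id`, `text`, `title`); a minimal sketch of reading it:

```python
import csv

with open("psgs_w100.tsv", newline="") as f:
    reader = csv.reader(f, delimiter="\t")
    next(reader)  # skip the "id / text / title" header row
    passage_id, text, title = next(reader)
    print(passage_id, title, text[:80])
```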
Before you start, you need to download the datasets available on the DPR website into `<DPR_DATASET_DIR>`.
The experimental results on the Natural Questions dataset can be reproduced by running the commands provided in this section. We used a server with eight NVIDIA Tesla V100 GPUs, each with 16 GB of memory, in the experiments. The results on the TriviaQA dataset can be reproduced by changing the input dataset file names to the corresponding ones (e.g., `nq-train.json` -> `trivia-train.json`).
1. Building passage database
python build_passage_db.py \
--passage_file=<DPR_DATASET_DIR>/wikipedia_split/psgs_w100.tsv \
--output_file=<PASSAGE_DB_FILE>
2. Training BPR
python train_biencoder.py \
--gpus=8 \
--distributed_backend=ddp \
--train_file=<DPR_DATASET_DIR>/retriever/nq-train.json \
--eval_file=<DPR_DATASET_DIR>/retriever/nq-dev.json \
--gradient_clip_val=2.0 \
--max_epochs=40 \
--binary
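The `--binary` flag enables the hash layer. Since the sign function used for binarization has zero gradient almost everywhere, BPR approximates it during training with a scaled tanh whose scale grows over training steps (a HashNet-style schedule, as described in the paper). A minimal sketch of such a layer; the names and the exact schedule here are illustrative:

```python
import torch

def hash_layer(embeddings: torch.Tensor, step: int, gamma: float = 0.1) -> torch.Tensor:
    """Differentiable surrogate for sign() used during training."""
    beta = (1.0 + gamma * step) ** 0.5    # scale grows with the training step
    return torch.tanh(beta * embeddings)  # approaches sign() as beta grows

def to_binary(embeddings: torch.Tensor) -> torch.Tensor:
    """Hard binarization used at inference time."""
    return torch.where(embeddings >= 0,
                       torch.ones_like(embeddings),
                       -torch.ones_like(embeddings))
```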
3. Building passage embeddings
python generate_embeddings.py \
--biencoder_file=<BPR_CHECKPOINT_FILE> \
--output_file=<EMBEDDING_FILE> \
--passage_db_file=<PASSAGE_DB_FILE> \
--batch_size=4096 \
--parallel
4. Evaluating BPR
python evaluate_retriever.py \
--binary_k=1000 \
--biencoder_file=<BPR_CHECKPOINT_FILE> \
--embedding_file=<EMBEDDING_FILE> \
--passage_db_file=<PASSAGE_DB_FILE> \
--qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-test.csv \
--parallel
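The `--binary_k` flag reflects BPR's two-stage search: the top `binary_k` candidates are first retrieved by Hamming distance between binary codes, and are then reranked by the inner product between the continuous query embedding and the candidates' unpacked, ±1-valued codes. A minimal numpy sketch of the reranking stage (variable names are illustrative):

```python
import numpy as np

def rerank(query_emb: np.ndarray, candidate_codes: np.ndarray, k: int = 100) -> np.ndarray:
    """Stage 2 of BPR's search: rerank Hamming-distance candidates.

    query_emb:       (dim,) continuous float32 query embedding
    candidate_codes: (binary_k, dim // 8) packed uint8 passage codes
    Returns the indices of the top-k candidates.
    """
    bits = np.unpackbits(candidate_codes, axis=-1).astype(np.float32)
    signed = bits * 2.0 - 1.0    # map {0, 1} bits back to {-1, +1}
    scores = signed @ query_emb  # inner product with the float query
    return np.argsort(-scores)[:k]
```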
5. Creating dataset for reader
python evaluate_retriever.py \
--binary_k=1000 \
--biencoder_file=<BPR_CHECKPOINT_FILE> \
--embedding_file=<EMBEDDING_FILE> \
--passage_db_file=<PASSAGE_DB_FILE> \
--qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-train.csv \
--output_file=<READER_TRAIN_FILE> \
--top_k=200 \
--parallel
python evaluate_retriever.py \
--binary_k=1000 \
--biencoder_file=<BPR_CHECKPOINT_FILE> \
--embedding_file=<EMBEDDING_FILE> \
--passage_db_file=<PASSAGE_DB_FILE> \
--qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-dev.csv \
--output_file=<READER_DEV_FILE> \
--top_k=200 \
--parallel
python evaluate_retriever.py \
--binary_k=1000 \
--biencoder_file=<BPR_CHECKPOINT_FILE> \
--embedding_file=<EMBEDDING_FILE> \
--passage_db_file=<PASSAGE_DB_FILE> \
--qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-test.csv \
--output_file=<READER_TEST_FILE> \
--top_k=200 \
--parallel
6. Training reader
python train_reader.py \
--gpus=8 \
--distributed_backend=ddp \
--train_file=<READER_TRAIN_FILE> \
--validation_file=<READER_DEV_FILE> \
--test_file=<READER_TEST_FILE> \
--learning_rate=2e-5 \
--max_epochs=20 \
--accumulate_grad_batches=4 \
--nq_gold_train_file=<DPR_DATASET_DIR>/gold_passages_info/nq_train.json \
--nq_gold_validation_file=<DPR_DATASET_DIR>/gold_passages_info/nq_dev.json \
--nq_gold_test_file=<DPR_DATASET_DIR>/gold_passages_info/nq_test.json \
--train_batch_size=1 \
--eval_batch_size=2 \
--gradient_clip_val=2.0
7. Evaluating reader
python evaluate_reader.py \
--gpus=8 \
--distributed_backend=ddp \
--checkpoint_file=<READER_CHECKPOINT_FILE> \
--eval_batch_size=1
This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
If you find this work useful, please cite the following paper:
Efficient Passage Retrieval with Hashing for Open-domain Question Answering
@inproceedings{yamada2021bpr,
title={Efficient Passage Retrieval with Hashing for Open-domain Question Answering},
author={Ikuya Yamada and Akari Asai and Hannaneh Hajishirzi},
booktitle={ACL},
year={2021}
}