studio-ousia / bpr

Binary Passage Retriever (BPR) - an efficient passage retriever for open-domain question answering

Binary Passage Retriever

Binary Passage Retriever (BPR) is an efficient neural retrieval model for open-domain question answering. BPR integrates a learning-to-hash technique into Dense Passage Retriever (DPR) to represent passage embeddings as compact binary codes rather than continuous vectors. It substantially reduces memory size without a loss of accuracy on the Natural Questions and TriviaQA datasets.
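To make the memory saving concrete, here is a minimal pure-Python sketch (not BPR's code) of turning a continuous embedding into a binary code with the sign function, packing it at one bit per dimension, and comparing codes by Hamming distance. The 768-dimensional size is an assumption based on a BERT-base encoder. In BPR the encoder is trained end-to-end together with the hashing step, so this only illustrates the final binarization and the storage arithmetic.

```python
from typing import List


def binarize(embedding: List[float]) -> List[int]:
    """Map a continuous embedding to a {-1, +1} code (0 is mapped to +1)."""
    return [1 if x >= 0 else -1 for x in embedding]


def pack_bits(code: List[int]) -> bytes:
    """Pack a {-1, +1} code into bytes, one bit per dimension (+1 -> 1, -1 -> 0)."""
    value = 0
    for c in code:
        value = (value << 1) | (1 if c > 0 else 0)
    return value.to_bytes((len(code) + 7) // 8, "big")


def hamming_distance(a: bytes, b: bytes) -> int:
    """Count the bits that differ between two packed codes of equal length."""
    return sum(bin(x ^ y).count("1") for x, y in zip(a, b))


dim = 768  # assumption: embedding size of a BERT-base encoder
float_bytes = dim * 4  # float32 vector: 4 bytes per dimension
binary_bytes = len(pack_bits(binarize([0.1] * dim)))
print(float_bytes, binary_bytes, float_bytes // binary_bytes)  # 3072 96 32
```

A 32x reduction per passage follows directly from storing 1 bit instead of 32 bits per dimension.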

BPR was originally developed to improve the computational efficiency of the Sōseki question answering system submitted to the Systems under 6GB track in the NeurIPS 2020 EfficientQA competition. Please refer to our ACL 2021 paper for further technical details.

News

June 6, 2021: New Model Trained Using Natural Questions with Improved Negative Examples

We conducted a follow-up experiment using the new Natural Questions training data with improved negative examples. BPR generally achieved performance similar to DPR with a substantially smaller index. The HNSW and HNSW-SQ results were obtained using the DPR code (commit 8bacf08).

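Model    | Faiss index type | Index size | Top 1 | Top 20 | Top 100 | Query time
-------- | ---------------- | ---------- | ----- | ------ | ------- | ----------
New BPR  | Binary flat      | 2GB        | 49.0  | 80.5   | 87.0    | 85 ms
New BPR  | Binary hash      | 2GB        | 49.0  | 80.5   | 87.0    | 38 ms
New DPR  | Flat             | 61GB       | 52.5  | 81.3   | 87.3    | 457 ms
New DPR  | HNSW             | 141GB      | 52.2  | 81.0   | 86.8    | 2 ms
New DPR  | HNSW-SQ          | 36GB       | 52.0  | 80.9   | 86.9    | 2 ms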

The fine-tuned checkpoint and index files are provided below.
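As a rough sanity check on the index sizes in the table, the sketch below estimates raw storage for a flat index. It assumes 21,015,324 passages (the size of the DPR Wikipedia split) and 768-dimensional embeddings, and it ignores Faiss overhead and the graph links that HNSW indexes also store.

```python
def index_size_gib(num_passages: int, dim: int, bytes_per_dim: float) -> float:
    """Approximate raw storage of a flat index in GiB, ignoring Faiss overhead."""
    return num_passages * dim * bytes_per_dim / 2**30


NUM_PASSAGES = 21_015_324  # assumption: passages in psgs_w100.tsv
DIM = 768  # assumption: BERT-base embedding size

float_gib = index_size_gib(NUM_PASSAGES, DIM, 4)  # float32: 4 bytes per dimension
binary_gib = index_size_gib(NUM_PASSAGES, DIM, 1 / 8)  # binary: 1 bit per dimension
print(f"float32: {float_gib:.1f} GiB, binary: {binary_gib:.1f} GiB")
```

The estimates (about 60 GiB and about 1.9 GiB) are in line with the 61GB flat DPR index and the 2GB BPR index reported above.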

Installation

BPR can be installed using Poetry:

poetry install

The virtual environment automatically created by Poetry can be activated by running poetry shell.

Alternatively, you can install the required libraries using pip:

pip install -r requirements.txt

Trained Models

BPR fine-tuned on the Natural Questions dataset:

BPR fine-tuned on the Natural Questions dataset with improved negative examples:

BPR fine-tuned on the TriviaQA dataset:

Example Usage

>>> import faiss
>>> from bpr import BiEncoder, FaissBinaryIndex, InMemoryPassageDB, Retriever
# Load the model from the checkpoint file
>>> biencoder = BiEncoder.load_from_checkpoint("bpr_finetuned_nq.ckpt")
>>> biencoder.eval()
>>> biencoder.freeze()
# Load Wikipedia passages into memory
>>> passage_db = InMemoryPassageDB("psgs_w100.tsv")
# Load the index
>>> base_index = faiss.read_index_binary("bpr_finetuned_nq.idx")
>>> index = FaissBinaryIndex(base_index)
# Instantiate the Retriever
>>> retriever = Retriever(index, biencoder, passage_db)
# Encode queries into embeddings
>>> query_embeddings = retriever.encode_queries(["what is the tallest mountain in the world"])
# Get the top-100 results; [0][0] is the highest-ranked candidate for the first query
>>> retriever.search(query_embeddings, k=100)[0][0]
Candidate(id=525407, score=93.59397888183594, passage=Passage(id=525407, title='Mount Everest', text="Mount Everest Mount Everest, known in Nepali as Sagarmatha () and in Tibetan as Chomolungma (), is Earth's highest mountain above sea level, located in the Mahalangur Himal sub-range of the Himalayas. The international border between Nepal (Province No. 1) and China (Tibet Autonomous Region) runs across its summit point. The current official elevation of , recognized by China and Nepal, was established by a 1955 Indian survey and subsequently confirmed by a Chinese survey in 1975. In 2005, China remeasured the rock height of the mountain, with a result of 8844.43 m. There followed an argument between China and"))
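As we understand the BPR paper, search works in two stages: candidates are first generated by Hamming distance between the binary query code and the binary passage codes, then reranked by the inner product between the query's continuous embedding and each candidate's binary code. The brute-force sketch below illustrates that idea only; it is not the bpr implementation (which relies on a Faiss binary index), and the function names are hypothetical. We assume the --binary_k option in the evaluation commands sets the number of stage-one candidates.

```python
from typing import List, Sequence, Tuple


def sign_code(vec: Sequence[float]) -> List[int]:
    # Sign binarization; 0 is mapped to +1 by convention.
    return [1 if x >= 0 else -1 for x in vec]


def hamming(a: Sequence[int], b: Sequence[int]) -> int:
    # Number of positions where two {-1, +1} codes disagree.
    return sum(1 for x, y in zip(a, b) if x != y)


def two_stage_search(
    query: Sequence[float],
    passage_codes: List[List[int]],
    binary_k: int,
    k: int,
) -> List[Tuple[int, float]]:
    """Return (passage index, score) pairs for the top-k passages."""
    q_code = sign_code(query)
    # Stage 1: keep the binary_k passages with the smallest Hamming distance.
    candidates = sorted(
        range(len(passage_codes)), key=lambda i: hamming(q_code, passage_codes[i])
    )[:binary_k]
    # Stage 2: rerank candidates by the inner product of the continuous
    # query embedding with each candidate's {-1, +1} code.
    scored = [
        (i, sum(q * c for q, c in zip(query, passage_codes[i]))) for i in candidates
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


query = [0.9, -0.1, 0.4, -0.7]
passages = [[1, -1, 1, -1], [1, 1, 1, -1], [-1, 1, -1, 1]]
print([i for i, _ in two_stage_search(query, passages, binary_k=2, k=1)])  # [0]
```

The cheap Hamming scan narrows the search space, while reranking with the continuous query recovers some of the precision lost by binarizing the query.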

The Wikipedia passage data (psgs_w100.tsv) is available on the DPR website. At the time of writing, the file can be downloaded by cloning the DPR repository and running the following command:

python data/download_data.py --resource data.wikipedia_split.psgs_w100
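For a quick look at the downloaded file, the sketch below loads it with Python's csv module. It assumes the tab-separated layout DPR uses (a header row, then id, text, title columns); it only illustrates the data and is not how InMemoryPassageDB is implemented. Loading all of the roughly 21 million passages this way needs a large amount of RAM.

```python
import csv
from typing import Dict, Tuple


def load_passages(path: str) -> Dict[int, Tuple[str, str]]:
    """Map passage id -> (title, text), assuming id/text/title TSV columns."""
    passages: Dict[int, Tuple[str, str]] = {}
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t")
        next(reader)  # skip the header row
        for passage_id, text, title in reader:
            passages[int(passage_id)] = (title, text)
    return passages
```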

Reproducing Experiments

Before you start, you need to download the datasets available on the DPR website into <DPR_DATASET_DIR>.

The experimental results on the Natural Questions dataset can be reproduced by running the commands provided in this section. We used a server with 8 NVIDIA Tesla V100 GPUs, each with 16GB of memory. The results on the TriviaQA dataset can be reproduced by replacing the input dataset file names with the corresponding TriviaQA ones (e.g., nq-train.json -> trivia-train.json).

1. Building passage database

python build_passage_db.py \
    --passage_file=<DPR_DATASET_DIR>/wikipedia_split/psgs_w100.tsv \
    --output_file=<PASSAGE_DB_FILE>

2. Training BPR

python train_biencoder.py \
   --gpus=8 \
   --distributed_backend=ddp \
   --train_file=<DPR_DATASET_DIR>/retriever/nq-train.json \
   --eval_file=<DPR_DATASET_DIR>/retriever/nq-dev.json \
   --gradient_clip_val=2.0 \
   --max_epochs=40 \
   --binary
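Because sign is not differentiable, learning-to-hash methods commonly train with a smooth surrogate such as tanh(beta * x) and increase beta over training so that the surrogate approaches sign(x); as we understand it, the BPR paper uses an approximation of this kind. The sketch below shows the effect. The schedule beta = sqrt(gamma * step + 1) and gamma = 0.1 are assumptions for illustration, and we have not checked what train_biencoder.py does internally when --binary is passed.

```python
import math
from typing import List


def beta_schedule(step: int, gamma: float = 0.1) -> float:
    # Hypothetical schedule: beta grows with the training step.
    return math.sqrt(gamma * step + 1)


def soft_hash(embedding: List[float], beta: float) -> List[float]:
    # Differentiable surrogate for sign(x); approaches sign(x) as beta grows.
    return [math.tanh(beta * x) for x in embedding]


embedding = [0.3, -0.05, 1.2]
for step in (0, 100, 10_000):
    beta = beta_schedule(step)
    print(step, [round(v, 3) for v in soft_hash(embedding, beta)])
```

Early in training the outputs stay close to the continuous values; late in training they are pushed toward -1 and +1.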

3. Building passage embeddings

python generate_embeddings.py \
   --biencoder_file=<BPR_CHECKPOINT_FILE> \
   --output_file=<EMBEDDING_FILE> \
   --passage_db_file=<PASSAGE_DB_FILE> \
   --batch_size=4096 \
   --parallel

4. Evaluating BPR

python evaluate_retriever.py \
    --binary_k=1000 \
    --biencoder_file=<BPR_CHECKPOINT_FILE> \
    --embedding_file=<EMBEDDING_FILE> \
    --passage_db_file=<PASSAGE_DB_FILE> \
    --qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-test.csv \
    --parallel
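We assume the Top 1 / Top 20 / Top 100 figures reported above are top-k retrieval accuracy as defined in DPR: the percentage of questions for which at least one of the top-k passages contains a gold answer. A simplified sketch of that metric follows. It uses normalized whole-word substring matching, which is an assumption and may differ from the tokenized matching that evaluate_retriever.py actually performs.

```python
import string
from typing import List


def normalize(text: str) -> str:
    # Lowercase, drop punctuation, collapse whitespace (simplified matching).
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return " ".join(text.split())


def has_answer(passage: str, answers: List[str]) -> bool:
    # Pad with spaces so that only whole-word matches count.
    padded = f" {normalize(passage)} "
    return any(normalize(a) and f" {normalize(a)} " in padded for a in answers)


def top_k_accuracy(
    ranked_passages: List[List[str]], answers: List[List[str]], k: int
) -> float:
    """Percentage of questions with a gold answer in their top-k passages."""
    hits = sum(
        any(has_answer(p, ans) for p in passages[:k])
        for passages, ans in zip(ranked_passages, answers)
    )
    return 100.0 * hits / len(answers)
```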

5. Creating dataset for reader

python evaluate_retriever.py \
    --binary_k=1000 \
    --biencoder_file=<BPR_CHECKPOINT_FILE> \
    --embedding_file=<EMBEDDING_FILE> \
    --passage_db_file=<PASSAGE_DB_FILE> \
    --qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-train.csv \
    --output_file=<READER_TRAIN_FILE> \
    --top_k=200 \
    --parallel

python evaluate_retriever.py \
    --binary_k=1000 \
    --biencoder_file=<BPR_CHECKPOINT_FILE> \
    --embedding_file=<EMBEDDING_FILE> \
    --passage_db_file=<PASSAGE_DB_FILE> \
    --qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-dev.csv \
    --output_file=<READER_DEV_FILE> \
    --top_k=200 \
    --parallel

python evaluate_retriever.py \
    --binary_k=1000 \
    --biencoder_file=<BPR_CHECKPOINT_FILE> \
    --embedding_file=<EMBEDDING_FILE> \
    --passage_db_file=<PASSAGE_DB_FILE> \
    --qa_file=<DPR_DATASET_DIR>/retriever/qas/nq-test.csv \
    --output_file=<READER_TEST_FILE> \
    --top_k=200 \
    --parallel

6. Training reader

python train_reader.py \
   --gpus=8 \
   --distributed_backend=ddp \
   --train_file=<READER_TRAIN_FILE> \
   --validation_file=<READER_DEV_FILE> \
   --test_file=<READER_TEST_FILE> \
   --learning_rate=2e-5 \
   --max_epochs=20 \
   --accumulate_grad_batches=4 \
   --nq_gold_train_file=<DPR_DATASET_DIR>/gold_passages_info/nq_train.json \
   --nq_gold_validation_file=<DPR_DATASET_DIR>/gold_passages_info/nq_dev.json \
   --nq_gold_test_file=<DPR_DATASET_DIR>/gold_passages_info/nq_test.json \
   --train_batch_size=1 \
   --eval_batch_size=2 \
   --gradient_clip_val=2.0
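With DDP, each GPU processes its own batch, and --accumulate_grad_batches sums gradients over several batches before each optimizer step. Assuming --train_batch_size is the per-GPU batch size (we have not verified this in train_reader.py), the command above performs one optimizer step per 32 training examples:

```python
def effective_batch_size(
    per_gpu_batch: int, num_gpus: int, accumulate_grad_batches: int
) -> int:
    """Training examples contributing to each optimizer step under DDP."""
    return per_gpu_batch * num_gpus * accumulate_grad_batches


print(effective_batch_size(per_gpu_batch=1, num_gpus=8, accumulate_grad_batches=4))  # 32
```

If you train on fewer GPUs, raising --accumulate_grad_batches by the same factor should keep this product, and thus the effective batch size, unchanged.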

7. Evaluating reader

python evaluate_reader.py \
    --gpus=8 \
    --distributed_backend=ddp \
    --checkpoint_file=<READER_CHECKPOINT_FILE> \
    --eval_batch_size=1
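Open-domain QA readers are commonly scored by exact match (EM) after SQuAD-style answer normalization. The sketch below shows that metric; we assume, without having checked evaluate_reader.py, that the accuracy it reports is computed in a similar way.

```python
import re
import string
from typing import List


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, strip punctuation and articles, collapse spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, gold_answers: List[str]) -> bool:
    # A prediction is correct if it matches any gold answer after normalization.
    return any(normalize_answer(prediction) == normalize_answer(g) for g in gold_answers)
```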

License

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

Citation

If you find this work useful, please cite the following paper:

Efficient Passage Retrieval with Hashing for Open-domain Question Answering

@inproceedings{yamada2021bpr,
  title={Efficient Passage Retrieval with Hashing for Open-domain Question Answering},
  author={Ikuya Yamada and Akari Asai and Hannaneh Hajishirzi},
  booktitle={ACL},
  year={2021}
}