facebookresearch / dpr-scale

Scalable training for dense retrieval models.

KeyError: 'positive_ctxs' when running run_retrieval.py for nq-test.jsonl #1

Closed mei16 closed 1 year ago

mei16 commented 2 years ago

When I run run_retrieval.py with nq-test.jsonl as the test file, I get KeyError: 'positive_ctxs', because nq-test.jsonl does not have a positive_ctxs field. Why do we need positive_ctxs at test time?

```
Traceback (most recent call last):
  File "/home/default/persistent_drive/dpr_scale/dpr_scale/run_retrieval.py", line 83, in main
    trainer.test(task, datamodule=datamodule)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 914, in test
    results = self.__test_given_model(model, test_dataloaders)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 972, in __test_given_model
    results = self.fit(model)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 498, in fit
    self.dispatch()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 539, in dispatch
    self.accelerator.start_testing(self)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 76, in start_testing
    self.training_type_plugin.start_testing(trainer)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 118, in start_testing
    self._results = trainer.run_test()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 785, in run_test
    eval_loop_results, _ = self.run_evaluation()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 711, in run_evaluation
    for batch_idx, batch in enumerate(dataloader):
  File "/usr/local/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/usr/local/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/default/persistent_drive/dpr_scale/dpr_scale/datamodule/dpr.py", line 138, in collate_test
    return self.collate(batch, "test")
  File "/home/default/persistent_drive/dpr_scale/dpr_scale/datamodule/dpr.py", line 203, in collate
    return self.dpr_transform(batch, stage)
  File "/usr/local/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/default/persistent_drive/dpr_scale/dpr_scale/transforms/dpr_transform.py", line 85, in forward
    contexts_pos = row["positive_ctxs"]
KeyError: 'positive_ctxs'
```

ccsasuke commented 2 years ago

Our DataLoader expects input in the DPR format (https://github.com/facebookresearch/DPR#retriever-input-data-format), where the positive_ctxs key is required for each query. If your data doesn't have labels, adding a dummy label to the input file should work, though you won't be able to evaluate the model accuracy of course.
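
For example, something like the following should work (a rough sketch, assuming your unlabeled file is JSON lines with one query per line; the file names here are placeholders):

```python
# Sketch: add empty positive_ctxs / hard_negative_ctxs to every row of a
# test file that only has questions and answers, so the DPR-format
# datamodule no longer raises KeyError. File names are placeholders.
import json

with open("nq-test.no_labels.jsonl") as fin, open("nq-test.jsonl", "w") as fout:
    for line in fin:
        row = json.loads(line)
        row.setdefault("positive_ctxs", [])        # dummy label
        row.setdefault("hard_negative_ctxs", [])   # dummy negatives
        fout.write(json.dumps(row, ensure_ascii=False) + "\n")
```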

mei16 commented 2 years ago

Since I want to compare the performance with DPR, I need to get the top-K accuracy on nq-test.jsonl with its 3610 questions. According to eval_dpr.py, positive_ctxs is not used to evaluate the retrieval accuracy, so adding dummy positive_ctxs won't affect the retrieval accuracy. Am I right?
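
(For context, the metric I'm computing is answer-based hit@K over the retrieved passages, roughly as in the sketch below. It uses plain lowercase substring matching instead of the tokenized matching in eval_dpr.py, and the results file path and field names are just placeholders for my own output format.)

```python
# Simplified sketch of answer-based top-K retrieval accuracy.
# Assumes a results file where each entry has the question's gold
# "answers" and the retrieved passages in ranked order under "ctxs".
import json

def hit_at_k(entry, k):
    # True if any gold answer string appears in one of the top-k passages.
    answers = [a.lower() for a in entry["answers"]]
    for ctx in entry["ctxs"][:k]:
        text = ctx["text"].lower()
        if any(ans in text for ans in answers):
            return True
    return False

with open("retrieval_results.json") as f:  # placeholder path
    results = json.load(f)

for k in (1, 5, 10, 20, 50, 100):
    acc = sum(hit_at_k(r, k) for r in results) / len(results)
    print(f"Top{k} accuracy: {acc:.4f}")
```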

ccsasuke commented 2 years ago

Yes, I believe that's correct.

mei16 commented 2 years ago

For nq-test.jsonl, with dummy positive_ctxs and hard_negative_ctxs, I got (3610/3610):

Top1 accuracy: 0.9520775623268698
Top5 accuracy: 0.9939058171745152
Top10 accuracy: 0.9991689750692521
Top20 accuracy: 1.0
Top50 accuracy: 1.0
Top100 accuracy: 1.0

While using the original DPR, I got:

Top1 accuracy: 0.45401662049861496
Top5 accuracy: 0.6703601108033241
Top10 accuracy: 0.739612188365651
Top20 accuracy: 0.7878116343490305
Top50 accuracy: 0.8340720221606648
Top100 accuracy: 0.8587257617728532

It seems the new accuracy is too high.

ccsasuke commented 2 years ago

Hi @mei16, we never ran into this problem. Did you make sure there was no error in the nq-test.jsonl file?

Here's a snippet of the nq-test.jsonl file we used (as I mentioned, we used dummy positive_ctxs):

{"question": "who got the first nobel prize in physics", "answers": ["Wilhelm Conrad Röntgen"], "positive_ctxs": [], "hard_negative_ctxs": []}
{"question": "when is the next deadpool movie being released", "answers": ["May 18 , 2018"], "positive_ctxs": [], "hard_negative_ctxs": []}
{"question": "which mode is used for short wave broadcast service", "answers": ["Olivia", "MFSK"], "positive_ctxs": [], "hard_negative_ctxs": []}
... ...
mei16 commented 2 years ago

Yes, I used the same format for nq-test.jsonl as in your snippet. Btw, I downloaded the paq_bert_base checkpoint and then fine-tuned it on nq-train.jsonl, following your instructions in the section "Fine-tuning DPR on downstream tasks/datasets":

```bash
DPR_ROOT=/home/xxxx/persistent_drive/dpr_scale
NAME=test
LR=1e-5
BSZ=16
MODEL="bert-base-uncased"
MAX_EPOCHS=40
WARMUP_STEPS=1000
NODES=1
PRETRAINED_CKPT_PATH=/home/xxxx/persistent_drive/dpr_scale_model/paq_bert_base.ckpt
EXP_DIR=/home/yunzhongliu/ephemeral_drive/dpr_scale

PYTHONPATH=. python dpr_scale/main.py -m \
    --config-dir ${DPR_ROOT}/dpr_scale/conf \
    --config-name nq.yaml \
    hydra.sweep.dir=${EXP_DIR} \
    trainer.num_nodes=${NODES} \
    trainer.max_epochs=${MAX_EPOCHS} \
    datamodule.num_negative=1 \
    datamodule.num_val_negative=25 \
    datamodule.num_test_negative=50 \
    +trainer.val_check_interval=150 \
    task.warmup_steps=${WARMUP_STEPS} \
    task.optim.lr=${LR} \
    task.pretrained_checkpoint_path=$PRETRAINED_CKPT_PATH \
    task.model.model_path=${MODEL} \
    datamodule.batch_size=${BSZ} > ${EXP_DIR}/logs/log.out 2> ${EXP_DIR}/logs/log.err &
```

In nq.yaml:

```yaml
datamodule:
  train_path: /home/xxxx/persistent_drive/DPR/downloads/data/retriever/nq-train.jsonl
  val_path: /home/xxxx/persistent_drive/DPR/downloads/data/retriever/nq-dev.jsonl
  test_path: /home/xxxx/persistent_drive/DPR/downloads/data/retriever/nq-dev.jsonl
```