facebookresearch / DPR

Dense Passage Retriever is a set of tools and models for the open-domain Q&A task.
Other
1.69k stars 298 forks

Question on retriever acc for NQ dataset #96

Closed jzhoubu closed 3 years ago

jzhoubu commented 3 years ago

Hi, thanks for sharing the work. I tried to reproduce the retriever performance on the NQ dataset (single-dataset training); below are my training command and results:

(base) bash-4.2$ CUDA_VISIBLE_DEVICES=6 bash train_dense_encoder.sh |& tee train_dense_encoder.log
### This is train_dense_encoder.sh ###
python train_dense_encoder.py \
    --encoder_model_type hf_bert \
    --pretrained_model_cfg bert-base-uncased \
    --train_file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-train.json \
    --dev_file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json \
    --output_dir /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq

### This is part of train_dense_encoder.log ###
Initialized host lccpu16.cse.ust.hk as d.rank -1 on device=cuda, n_gpu=1, world size=1
16-bits training: False 
 **************** CONFIGURATION **************** 
adam_betas                     -->   (0.9, 0.999)
adam_eps                       -->   1e-08
batch_size                     -->   2
checkpoint_file_name           -->   dpr_biencoder
dev_batch_size                 -->   4
dev_file                       -->   /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json
device                         -->   cuda
distributed_world_size         -->   1
do_lower_case                  -->   False
dropout                        -->   0.1
encoder_model_type             -->   hf_bert
eval_per_epoch                 -->   1
fix_ctx_encoder                -->   False
fp16                           -->   False
fp16_opt_level                 -->   O1
global_loss_buf_sz             -->   150000
gradient_accumulation_steps    -->   1
hard_negatives                 -->   1
learning_rate                  -->   1e-05
local_rank                     -->   -1
log_batch_step                 -->   100
max_grad_norm                  -->   1.0
model_file                     -->   None
n_gpu                          -->   1
no_cuda                        -->   False
num_train_epochs               -->   3.0
other_negatives                -->   0
output_dir                     -->   /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq
pretrained_file                -->   None
pretrained_model_cfg           -->   bert-base-uncased
projection_dim                 -->   0
seed                           -->   0
sequence_length                -->   512
shuffle_positive_ctx           -->   False
train_file                     -->   /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-train.json
train_files_upsample_rates     -->   None
train_rolling_loss_step        -->   100
val_av_rank_bsz                -->   128
val_av_rank_hard_neg           -->   30
val_av_rank_max_qs             -->   10000
val_av_rank_other_neg          -->   30
val_av_rank_start_epoch        -->   10000
warmup_steps                   -->   100
weight_decay                   -->   0.0
 **************** CONFIGURATION **************** 
***** Initializing components for training *****
Checkpoint files []
PyTorch version 1.3.0 available.
loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /homes/jzhoubu/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /homes/jzhoubu/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157
All model checkpoint weights were used when initializing HFBertEncoder.

All the weights of HFBertEncoder were initialized from the model checkpoint at bert-base-uncased.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use HFBertEncoder for predictions without further training.
loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /homes/jzhoubu/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /homes/jzhoubu/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157
All model checkpoint weights were used when initializing HFBertEncoder.

All the weights of HFBertEncoder were initialized from the model checkpoint at bert-base-uncased.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use HFBertEncoder for predictions without further training.
loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /homes/jzhoubu/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
Reading file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-train.json
Aggregated data size: 58880
Total cleaned data size: 58880
  Total iterations per epoch=29440
 Total updates=88320
  Eval step = 29440
***** Training *****
***** Epoch 0 *****
Epoch: 0: Step: 1/29440, loss=41.719326, lr=0.000000
Train batch 100
Avg. loss per last 100 batches: 6.146358
...
Epoch: 0: Step: 29401/29440, loss=6.202332, lr=0.000007
Validation: Epoch: 0 Step: 29440/29440
NLL validation ...
Reading file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json
Aggregated data size: 6515
Total cleaned data size: 6515
Eval step: 99 , used_time=17.715631 sec., loss=0.118019 
Eval step: 199 , used_time=35.508732 sec., loss=0.000000 
Eval step: 299 , used_time=53.322032 sec., loss=0.975496 
Eval step: 399 , used_time=71.156735 sec., loss=0.376352 
Eval step: 499 , used_time=89.044517 sec., loss=0.000088 
Eval step: 599 , used_time=106.960637 sec., loss=0.000000 
Eval step: 699 , used_time=124.844190 sec., loss=0.000639 
Eval step: 799 , used_time=142.751692 sec., loss=0.240759 
Eval step: 899 , used_time=160.649182 sec., loss=0.093718 
Eval step: 999 , used_time=178.555661 sec., loss=2.591503 
Eval step: 1099 , used_time=196.447839 sec., loss=0.003563 
Eval step: 1199 , used_time=214.350777 sec., loss=4.385298 
Eval step: 1299 , used_time=232.254511 sec., loss=1.744909 
Eval step: 1399 , used_time=250.164419 sec., loss=0.390905 
Eval step: 1499 , used_time=268.048638 sec., loss=2.424854 
Eval step: 1599 , used_time=285.939402 sec., loss=0.000078 
NLL Validation: loss = 0.964666. correct prediction ratio  5465/6516 ~  0.838705
Saved checkpoint at /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.0.29440
Saved checkpoint to /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.0.29440
New Best validation checkpoint /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.0.29440
NLL validation ...
Reading file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json
Aggregated data size: 6515
Total cleaned data size: 6515
Eval step: 99 , used_time=17.454361 sec., loss=0.118019 
Eval step: 199 , used_time=35.080369 sec., loss=0.000000 
Eval step: 299 , used_time=52.781430 sec., loss=0.975496 
Eval step: 399 , used_time=70.522686 sec., loss=0.376352 
Eval step: 499 , used_time=88.308615 sec., loss=0.000088 
Eval step: 599 , used_time=106.168578 sec., loss=0.000000 
Eval step: 699 , used_time=124.017505 sec., loss=0.000639 
Eval step: 799 , used_time=141.892720 sec., loss=0.240759 
Eval step: 899 , used_time=159.762845 sec., loss=0.093718 
Eval step: 999 , used_time=177.643622 sec., loss=2.591503 
Eval step: 1099 , used_time=195.507617 sec., loss=0.003563 
Eval step: 1199 , used_time=213.386917 sec., loss=4.385298 
Eval step: 1299 , used_time=231.270697 sec., loss=1.744909 
Eval step: 1399 , used_time=249.157400 sec., loss=0.390905 
Eval step: 1499 , used_time=267.017584 sec., loss=2.424854 
Eval step: 1599 , used_time=284.884477 sec., loss=0.000078 
NLL Validation: loss = 0.964666. correct prediction ratio  5465/6516 ~  0.838705
Saved checkpoint at /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.0.29440
Saved checkpoint to /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.0.29440
Av Loss per epoch=0.636933
epoch total correct predictions=51794

***** Epoch 1 *****
...
Epoch: 1: Step: 29401/29440, loss=0.008938, lr=0.000003
Validation: Epoch: 1 Step: 29440/29440
NLL validation ...
Reading file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json
Aggregated data size: 6515
Total cleaned data size: 6515
Eval step: 99 , used_time=17.723681 sec., loss=0.000000 
Eval step: 199 , used_time=35.485523 sec., loss=0.000000 
Eval step: 299 , used_time=53.261996 sec., loss=1.793509 
Eval step: 399 , used_time=71.101007 sec., loss=0.020972 
Eval step: 499 , used_time=88.970163 sec., loss=0.000000 
Eval step: 599 , used_time=106.833581 sec., loss=0.000002 
Eval step: 699 , used_time=124.678303 sec., loss=0.000000 
Eval step: 799 , used_time=142.544063 sec., loss=0.004837 
Eval step: 899 , used_time=160.408843 sec., loss=2.808296 
Eval step: 999 , used_time=178.280324 sec., loss=4.019991 
Eval step: 1099 , used_time=196.133461 sec., loss=0.003130 
Eval step: 1199 , used_time=214.001708 sec., loss=3.159004 
Eval step: 1299 , used_time=231.881492 sec., loss=2.635145 
Eval step: 1399 , used_time=249.744119 sec., loss=0.963694 
Eval step: 1499 , used_time=267.586157 sec., loss=1.437464 
Eval step: 1599 , used_time=285.444094 sec., loss=0.000374 
NLL Validation: loss = 0.899356. correct prediction ratio  5670/6516 ~  0.870166
Saved checkpoint at /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.1.29440
Saved checkpoint to /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.1.29440
New Best validation checkpoint /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.1.29440
NLL validation ...
Reading file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json
Aggregated data size: 6515
Total cleaned data size: 6515
Eval step: 99 , used_time=17.561712 sec., loss=0.000000 
Eval step: 199 , used_time=35.237448 sec., loss=0.000000 
Eval step: 299 , used_time=52.970388 sec., loss=1.793509 
Eval step: 399 , used_time=70.731166 sec., loss=0.020972 
Eval step: 499 , used_time=88.610038 sec., loss=0.000000 
Eval step: 599 , used_time=106.472266 sec., loss=0.000002 
Eval step: 699 , used_time=124.316287 sec., loss=0.000000 
Eval step: 799 , used_time=142.188653 sec., loss=0.004837 
Eval step: 899 , used_time=160.055734 sec., loss=2.808296 
Eval step: 999 , used_time=177.927337 sec., loss=4.019991 
Eval step: 1099 , used_time=195.791230 sec., loss=0.003130 
Eval step: 1199 , used_time=213.659289 sec., loss=3.159004 
Eval step: 1299 , used_time=231.534983 sec., loss=2.635145 
Eval step: 1399 , used_time=249.418472 sec., loss=0.963694 
Eval step: 1499 , used_time=267.261050 sec., loss=1.437464 
Eval step: 1599 , used_time=285.120215 sec., loss=0.000374 
NLL Validation: loss = 0.899356. correct prediction ratio  5670/6516 ~  0.870166
Saved checkpoint at /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.1.29440
Saved checkpoint to /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.1.29440
Av Loss per epoch=0.363017
epoch total correct predictions=55150

***** Epoch 2 *****
...

Epoch: 2: Step: 29401/29440, loss=0.000000, lr=0.000000
Validation: Epoch: 2 Step: 29440/29440
NLL validation ...
Reading file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json
Aggregated data size: 6515
Total cleaned data size: 6515
Eval step: 99 , used_time=17.715817 sec., loss=0.000010 
Eval step: 199 , used_time=35.478746 sec., loss=0.000000 
Eval step: 299 , used_time=53.269142 sec., loss=2.994547 
Eval step: 399 , used_time=71.110545 sec., loss=0.010040 
Eval step: 499 , used_time=88.981745 sec., loss=0.020422 
Eval step: 599 , used_time=106.849028 sec., loss=0.000000 
Eval step: 699 , used_time=124.698054 sec., loss=0.000000 
Eval step: 799 , used_time=142.573614 sec., loss=0.014013 
Eval step: 899 , used_time=160.445224 sec., loss=0.225973 
Eval step: 999 , used_time=178.331551 sec., loss=2.218931 
Eval step: 1099 , used_time=196.185302 sec., loss=0.000000 
Eval step: 1199 , used_time=214.055724 sec., loss=4.342945 
Eval step: 1299 , used_time=231.930175 sec., loss=1.739365 
Eval step: 1399 , used_time=249.802148 sec., loss=0.878143 
Eval step: 1499 , used_time=267.646685 sec., loss=1.899799 
Eval step: 1599 , used_time=285.503991 sec., loss=0.003260 
NLL Validation: loss = 0.924181. correct prediction ratio  5739/6516 ~  0.880755
Saved checkpoint at /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.2.29440
Saved checkpoint to /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.2.29440
NLL validation ...
Reading file /export/data/jzhoubu/workspace/DPR/data/data/retriever/nq-dev.json
Aggregated data size: 6515
Total cleaned data size: 6515
Eval step: 99 , used_time=17.545685 sec., loss=0.000010 
Eval step: 199 , used_time=35.214913 sec., loss=0.000000 
Eval step: 299 , used_time=52.955192 sec., loss=2.994547 
Eval step: 399 , used_time=70.729739 sec., loss=0.010040 
Eval step: 499 , used_time=88.584115 sec., loss=0.020422 
Eval step: 599 , used_time=106.462052 sec., loss=0.000000 
Eval step: 699 , used_time=124.322073 sec., loss=0.000000 
Eval step: 799 , used_time=142.211733 sec., loss=0.014013 
Eval step: 899 , used_time=160.098584 sec., loss=0.225973 
Eval step: 999 , used_time=177.995218 sec., loss=2.218931 
Eval step: 1099 , used_time=195.869485 sec., loss=0.000000 
Eval step: 1199 , used_time=213.755858 sec., loss=4.342945 
Eval step: 1299 , used_time=231.646454 sec., loss=1.739365 
Eval step: 1399 , used_time=249.538532 sec., loss=0.878143 
Eval step: 1499 , used_time=267.402575 sec., loss=1.899799 
Eval step: 1599 , used_time=285.275870 sec., loss=0.003260 
NLL Validation: loss = 0.924181. correct prediction ratio  5739/6516 ~  0.880755
Saved checkpoint at /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.2.29440
Saved checkpoint to /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.2.29440
Av Loss per epoch=0.246487
epoch total correct predictions=56370
Training finished. Best validation checkpoint /export/data/jzhoubu/workspace/DPR/output/dense_retriever_nq/dpr_biencoder.1.29440
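
The log above is consistent with the best checkpoint being chosen by the lowest NLL validation loss rather than the highest correct prediction ratio: epoch 2 has the best ratio (0.880755) but a higher loss than epoch 1 (0.924181 vs 0.899356), and training ends with dpr_biencoder.1.29440 as the best checkpoint. Below is a minimal Python sketch for extracting these numbers from such a log, assuming the line format printed here. The helper names (parse_nll_validation, best_by_loss, best_by_ratio, schedule) are hypothetical and not part of DPR. The sketch also checks the schedule arithmetic: 58880 examples / batch size 2 = 29440 iterations per epoch, and 3 epochs give 88320 updates. The ceil rounding is an assumption.

```python
import math
import re

# Matches validation summary lines such as:
# "NLL Validation: loss = 0.899356. correct prediction ratio  5670/6516 ~  0.870166"
NLL_LINE = re.compile(
    r"NLL Validation: loss = ([\d.]+)\. correct prediction ratio\s+(\d+)/(\d+)"
)


def parse_nll_validation(lines):
    """Return one (loss, correct, total) tuple per validation run, in log order."""
    results = []
    for line in lines:
        m = NLL_LINE.search(line)
        if not m:
            continue
        entry = (float(m.group(1)), int(m.group(2)), int(m.group(3)))
        # Each epoch's validation is printed twice in this log; drop consecutive repeats.
        if not results or results[-1] != entry:
            results.append(entry)
    return results


def best_by_loss(results):
    """Index of the run with the lowest NLL loss."""
    return min(range(len(results)), key=lambda i: results[i][0])


def best_by_ratio(results):
    """Index of the run with the highest correct-prediction ratio."""
    return max(range(len(results)), key=lambda i: results[i][1] / results[i][2])


def schedule(num_examples, batch_size, epochs, grad_accum=1):
    """Iterations per epoch and total optimizer updates (rounding is an assumption)."""
    iters_per_epoch = math.ceil(num_examples / batch_size)
    total_updates = (iters_per_epoch // grad_accum) * epochs
    return iters_per_epoch, total_updates


if __name__ == "__main__":
    base = [
        "NLL Validation: loss = 0.964666. correct prediction ratio  5465/6516 ~  0.838705",
        "NLL Validation: loss = 0.899356. correct prediction ratio  5670/6516 ~  0.870166",
        "NLL Validation: loss = 0.924181. correct prediction ratio  5739/6516 ~  0.880755",
    ]
    log_lines = [line for line in base for _ in range(2)]  # each run is logged twice
    runs = parse_nll_validation(log_lines)
    print("best epoch by loss:", best_by_loss(runs))    # 1
    print("best epoch by ratio:", best_by_ratio(runs))  # 2
    print("schedule:", schedule(58880, 2, 3))           # (29440, 88320)
```

The epoch index assumes one validation per epoch (eval_per_epoch = 1 in the configuration above).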

My questions are:
  1. The retriever accuracy is surprisingly 3 points higher than the one reported in the paper. Is this normal?
  2. I am using the default parameters on a single GPU. If I understand correctly, my batch size is 2, which is far smaller than yours (128), yet the results are still close. Is this normal?
  3. If 2 is normal, does it mean the in-batch negative samples are not that helpful? This seems contrary to Table 3. Have I missed anything?

vlad-karpukhin commented 3 years ago

Hi,

  1. What we reported in the paper is the top-k retrieval accuracy (k = 20 and 100) over ALL Wikipedia passages, not the "correct prediction ratio" over a small list of candidates (in your case 2 positives + 2 hard negatives from 2 questions = 4 candidates). To get the actual retrieval accuracy, you need to generate embeddings for all Wikipedia passages and then run the dense_retriever tool.
  2. A batch size of 2 will give you an inferior model, since the effective batch size is critical for bi-encoder training.
  3. See point 1.

krishanudb commented 3 years ago
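
To make point 1 concrete: the "correct prediction ratio" in the training log only checks whether each question's positive passage outscores the few other passages in the same batch. That is far easier than ranking it first among the 21,015,300 indexed Wikipedia passages. Below is a minimal pure-Python sketch of that in-batch scoring. It assumes dot-product similarity (as in the DPR paper) and a hypothetical context layout of one positive followed by its hard negatives for each question. The function names are illustrative, and this is not DPR's actual implementation. The sketch also relates to point 2: the number of candidates each question is ranked against, and therefore the number of in-batch negatives, grows linearly with batch size.

```python
import math


def num_candidates(batch_size, hard_negatives_per_question):
    """Passages each question is ranked against in one batch (positives + hard negatives)."""
    return batch_size * (1 + hard_negatives_per_question)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def in_batch_eval(q_vecs, ctx_vecs, positive_idx):
    """Mean NLL and number of correct top-1 predictions for one batch.

    Each question is scored against every passage in ctx_vecs by dot product;
    positive_idx[i] is the position of question i's gold passage in ctx_vecs.
    """
    total_nll, correct = 0.0, 0
    for q, pos in zip(q_vecs, positive_idx):
        scores = [dot(q, c) for c in ctx_vecs]
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # stable log-sum-exp
        total_nll += log_z - scores[pos]  # -log softmax(scores)[pos]
        if max(range(len(scores)), key=scores.__getitem__) == pos:
            correct += 1
    return total_nll / len(q_vecs), correct


if __name__ == "__main__":
    # Hypothetical 2-d vectors; context layout assumed: [pos_q0, hn_q0, pos_q1, hn_q1].
    questions = [[1.0, 0.0], [0.0, 1.0]]
    contexts = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.2, 0.1]]
    loss, correct = in_batch_eval(questions, contexts, positive_idx=[0, 2])
    print(f"candidates={num_candidates(2, 1)} correct={correct}/2 loss={loss:.3f}")
    print("candidates at batch size 128:", num_candidates(128, 1))  # 256
```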

Hello, I have a question along somewhat similar lines. In the paper, you report top-k accuracy, which the paper defines as "the percentage of top 20/100 retrieved passages that contain the answer."

My doubt is this: in that case, top-20 accuracy should be higher than top-100 accuracy, since the top-ranked documents should be more likely to contain answers, and the percentage should drop as more distant (in the embedding space) documents are added to the list. However, in the results, top-100 accuracy is higher than top-20 accuracy, which seems counterintuitive.

Can you please help me with this doubt? Thank you, Krishanu.

jzhoubu commented 3 years ago

Thanks, @vlad-karpukhin. Following your comments, I used dense_retriever to evaluate the pre-trained model; below is the result:

Total data indexed 21015300
Data indexing completed.
Total encoded queries tensor torch.Size([3610, 768])
index search time: 3383.498359 sec.
Reading data from: /export/data/jzhoubu/workspace/DPR/data/data/wikipedia_split/psgs_w100.tsv
Matching answers in top docs...
Per question validation results len=3610
Validation results: top k documents hits [1560, 1983, 2181, 2294, 2389, 2466, 2531, 2580, 2624, 2668, 2692, 2721, 2743, 2765, 2788, 2808, 2823, 2834, 2854, 2860, 2873, 2885, 2896, 2911, 2920, 2930, 2940, 2947, 2960, 2966, 2972, 2980, 2980, 2989, 2994, 2996, 2997, 3000, 3007, 3010, 3011, 3011, 3014, 3016, 3019, 3021, 3025, 3026, 3028, 3032, 3036, 3038, 3039, 3040, 3040, 3043, 3046, 3047, 3048, 3051, 3053, 3056, 3056, 3057, 3059, 3059, 3061, 3064, 3065, 3070, 3070, 3072, 3076, 3079, 3079, 3080, 3081, 3083, 3084, 3086, 3089, 3090, 3091, 3091, 3091, 3091, 3091, 3093, 3093, 3093, 3094, 3094, 3095, 3098, 3099, 3101, 3101, 3103, 3107, 3108]
Validation results: top k documents hits accuracy [0.43213296398891965, 0.5493074792243767, 0.6041551246537397, 0.6354570637119114, 0.6617728531855955, 0.6831024930747922, 0.7011080332409972, 0.7146814404432132, 0.7268698060941828, 0.7390581717451523, 0.7457063711911357, 0.7537396121883656, 0.7598337950138504, 0.7659279778393352, 0.7722991689750692, 0.7778393351800554, 0.781994459833795, 0.7850415512465374, 0.7905817174515235, 0.7922437673130194, 0.7958448753462604, 0.7991689750692521, 0.8022160664819945, 0.8063711911357341, 0.8088642659279779, 0.8116343490304709, 0.814404432132964, 0.8163434903047091, 0.8199445983379502, 0.821606648199446, 0.8232686980609418, 0.8254847645429363, 0.8254847645429363, 0.8279778393351801, 0.8293628808864266, 0.8299168975069252, 0.8301939058171746, 0.8310249307479224, 0.8329639889196676, 0.8337950138504155, 0.8340720221606648, 0.8340720221606648, 0.8349030470914127, 0.8354570637119113, 0.8362880886426592, 0.8368421052631579, 0.8379501385041551, 0.8382271468144045, 0.838781163434903, 0.8398891966759002, 0.8409972299168975, 0.8415512465373961, 0.8418282548476455, 0.8421052631578947, 0.8421052631578947, 0.8429362880886426, 0.8437673130193906, 0.8440443213296399, 0.8443213296398892, 0.8451523545706371, 0.8457063711911358, 0.8465373961218836, 0.8465373961218836, 0.846814404432133, 0.8473684210526315, 0.8473684210526315, 0.8479224376731302, 0.8487534626038781, 0.8490304709141274, 0.850415512465374, 0.850415512465374, 0.8509695290858725, 0.8520775623268698, 0.8529085872576178, 0.8529085872576178, 0.853185595567867, 0.8534626038781163, 0.8540166204986149, 0.8542936288088643, 0.8548476454293629, 0.8556786703601108, 0.8559556786703602, 0.8562326869806094, 0.8562326869806094, 0.8562326869806094, 0.8562326869806094, 0.8562326869806094, 0.856786703601108, 0.856786703601108, 0.856786703601108, 0.8570637119113573, 0.8570637119113573, 0.8573407202216067, 0.8581717451523546, 0.8584487534626039, 0.8590027700831024, 0.8590027700831024, 0.8595567867036011, 0.8606648199445983, 0.8609418282548477]
Saved results * scores  to /export/data/jzhoubu/workspace/DPR/checkpoint/retriever/multiset/bert-base-encoder.cp.test.result.top100

The top-100 retrieval accuracy is 0.861, which matches the reported value.
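
You can recompute that figure from the logged hit counts. Dividing each entry of the hits list by the 3610 test questions gives the matching entry of the accuracy list. Below is a tiny sketch for k = 1, 20 and 100 (hits_to_accuracy is a hypothetical helper name). The top-100 value, 3108/3610 ≈ 0.8609, is the 0.861 quoted above.

```python
def hits_to_accuracy(hits, num_questions):
    """Top-k accuracy = fraction of questions with at least one hit in the top k."""
    return [h / num_questions for h in hits]


if __name__ == "__main__":
    # Entries copied from the logged "top k documents hits" list at k = 1, 20 and 100.
    ks = [1, 20, 100]
    hits = [1560, 2860, 3108]
    num_questions = 3610  # "Per question validation results len=3610"
    for k, acc in zip(ks, hits_to_accuracy(hits, num_questions)):
        print(f"top-{k}: {acc:.4f}")
    # top-1: 0.4321
    # top-20: 0.7922
    # top-100: 0.8609
```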

vlad-karpukhin commented 3 years ago

Hi @krishanudb ,

"the percentage of top 20/100 retrieved passages that contain the answer." does not mean the relative share of passages that contain answers. Most questions have just a single true positive passage. Our accuracy metric actually means that there is AT LEAST ONE positive passage (which may be a false positive) among the top 20 (or top 100) results.
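
The "at least one" definition explains why top-k accuracy can only grow with k. A question that is a hit at k = 20 is still a hit at k = 100, and the extra 80 passages give the remaining questions more chances. For the same reason, the hits list logged earlier in this thread (1560, 1983, 2181, ...) never decreases. Below is a minimal sketch of the counting. It assumes each question comes with has-answer flags in rank order. top_k_hits is a hypothetical helper, not DPR's function, and the answer matching step (which can produce the false positives mentioned above) is left out.

```python
def top_k_hits(has_answer_per_question, max_k):
    """hits[k-1] = number of questions with at least one answer-bearing passage in the top k.

    has_answer_per_question[i] lists, in rank order, whether each retrieved passage
    for question i contains an answer string.
    """
    hits = [0] * max_k
    for flags in has_answer_per_question:
        # Rank (0-based) of the first passage that contains an answer, if any within max_k.
        first = next((i for i, found in enumerate(flags[:max_k]) if found), None)
        if first is not None:
            # The question counts as a hit for every k >= first + 1.
            for k in range(first, max_k):
                hits[k] += 1
    return hits


if __name__ == "__main__":
    ranked = [
        [False, True, False],   # first answer at rank 2
        [False, False, False],  # no answer retrieved
        [True, True, False],    # two answer-bearing passages still count once
    ]
    hits = top_k_hits(ranked, max_k=3)
    print(hits)                                        # [1, 2, 2]
    print([round(h / len(ranked), 3) for h in hits])   # [0.333, 0.667, 0.667]
```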

krishanudb commented 3 years ago

Hi @vlad-karpukhin, okay, it's clear now. Thanks for the explanation.

vlad-karpukhin commented 3 years ago

Looks like it can be closed now.