facebookresearch / DPR

Dense Passage Retriever is a set of tools and models for the open-domain Q&A task.

Cannot Reproduce Table 2: Different Result #108

Closed robinsongh381 closed 3 years ago

robinsongh381 commented 3 years ago

Hello! First of all, thank you for sharing this amazing work.

I have been trying to reproduce Table 2 from the paper, especially the results on the NQ and Trivia datasets. Below is a summary of the steps I have taken and the resulting output.

Step 1. Download dataset

Step 2. Retriever training

Initialized host brain-cluster-gpu8.dakao.io as d.rank -1 on device=cuda, n_gpu=8, world size=1
16-bits training: False
 **************** CONFIGURATION ****************
adam_betas                     -->   (0.9, 0.999)
adam_eps                       -->   1e-08
batch_size                     -->   64
checkpoint_file_name           -->   dpr_biencoder
dev_batch_size                 -->   8
dev_file                       -->   ./data/data/retriever/nq*dev.json
device                         -->   cuda
distributed_world_size         -->   1
do_lower_case                  -->   False
dropout                        -->   0.1
encoder_model_type             -->   hf_bert
eval_per_epoch                 -->   1
fix_ctx_encoder                -->   False
fp16                           -->   False
fp16_opt_level                 -->   O1
global_loss_buf_sz             -->   150000
gradient_accumulation_steps    -->   1
hard_negatives                 -->   1
learning_rate                  -->   1e-05
local_rank                     -->   -1
log_batch_step                 -->   100
max_grad_norm                  -->   1.0
model_file                     -->   None
n_gpu                          -->   8
no_cuda                        -->   False
num_train_epochs               -->   40.0
other_negatives                -->   0
output_dir                     -->   ./checkpoint/nq
pretrained_file                -->   None
pretrained_model_cfg           -->   bert-base-uncased
projection_dim                 -->   0
seed                           -->   0
sequence_length                -->   512
shuffle_positive_ctx           -->   False
train_file                     -->   ./data/data/retriever/nq*train.json
train_files_upsample_rates     -->   None
train_rolling_loss_step        -->   1000
val_av_rank_bsz                -->   128
val_av_rank_hard_neg           -->   30
val_av_rank_max_qs             -->   10000
val_av_rank_other_neg          -->   30
val_av_rank_start_epoch        -->   10000
warmup_steps                   -->   100
weight_decay                   -->   0.0
 **************** CONFIGURATION ****************
***** Initializing components for training *****
Checkpoint files []
Reading file ./data/data/retriever/nq-train.json
Aggregated data size: 58880
Total cleaned data size: 58880
  Total iterations per epoch=920
 Total updates=36800
  Eval step = 920
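
As a sanity check, the logged schedule is consistent with the config above (a minimal sketch of the arithmetic, assuming one update per batch):

    # 58880 cleaned samples, batch_size=64, num_train_epochs=40 (from the log)
    samples = 58880
    batch_size = 64
    epochs = 40

    iters_per_epoch = samples // batch_size   # 920, as logged
    total_updates = iters_per_epoch * epochs  # 36800, as logged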

Step 3. Retriever inference

Initialized host gpu-cloud-node36.dakao.io as d.rank -1 on device=cuda, n_gpu=8, world size=1
16-bits training: False
Reading saved model from ./checkpoint/nq/dpr_biencoder.37.920
model_state_dict keys odict_keys(['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', 'encoder_params'])
Overriding args parameter value from checkpoint state. Param = pretrained_model_cfg, value = bert-base-uncased
Overriding args parameter value from checkpoint state. Param = encoder_model_type, value = hf_bert
Overriding args parameter value from checkpoint state. Param = sequence_length, value = 512
 **************** CONFIGURATION ****************
batch_size                     -->   4096
ctx_file                       -->   ./data/data/wikipedia_split/psgs_w100.tsv
device                         -->   cuda
distributed_world_size         -->   1
do_lower_case                  -->   False
encoder_model_type             -->   hf_bert
fp16                           -->   False
fp16_opt_level                 -->   O1
local_rank                     -->   -1
model_file                     -->   ./checkpoint/nq/dpr_biencoder.37.920
n_gpu                          -->   8
no_cuda                        -->   False
num_shards                     -->   1
out_file                       -->   ./checkpoint/nq/embed_epoch_37
pretrained_file                -->   None
pretrained_model_cfg           -->   bert-base-uncased
projection_dim                 -->   0
sequence_length                -->   512
shard_id                       -->   0
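
With num_shards=1 and shard_id=0, this step writes a single file, embed_epoch_37_0.pkl (the shard id is appended to out_file; the file is read back in Step 4 below). A quick way to inspect it, assuming the usual DPR format of a pickled list of (passage_id, embedding) pairs:

    import pickle

    # Each shard is a pickled list of (passage_id, embedding_vector) pairs
    with open("./checkpoint/nq/embed_epoch_37_0.pkl", "rb") as f:
        rows = pickle.load(f)

    print(len(rows))        # number of encoded passages
    print(len(rows[0][1]))  # embedding dimension (768 for bert-base)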

Step 4. Retriever validation against the entire set of documents

Initialized host gpu-cloud-node36.dakao.io as d.rank -1 on device=cuda, n_gpu=8, world size=1
16-bits training: False
 **************** CONFIGURATION ****************
batch_size                     -->   128
ctx_file                       -->   data/data/wikipedia_split/psgs_w100.tsv
device                         -->   cuda
distributed_world_size         -->   1
do_lower_case                  -->   False
encoded_ctx_file               -->   checkpoint/nq/embed_epoch_37_0.pkl
encoder_model_type             -->   None
fp16                           -->   False
fp16_opt_level                 -->   O1
hnsw_index                     -->   False
index_buffer                   -->   50000
local_rank                     -->   -1
match                          -->   string
model_file                     -->   checkpoint/nq/dpr_biencoder.37.920
n_docs                         -->   100
n_gpu                          -->   8
no_cuda                        -->   False
out_file                       -->   checkpoint/nq/eval_dev_epoch_37_top_100.tsv
pretrained_file                -->   None
pretrained_model_cfg           -->   None
projection_dim                 -->   0
qa_file                        -->   data/data/retriever/qas/nq-dev.csv
save_or_load_index             -->   False
sequence_length                -->   512
validation_workers             -->   16
Total encoded queries tensor torch.Size([8757, 768])
index search time: 1273.607778 sec.
Reading data from: data/data/wikipedia_split/psgs_w100.tsv
Matching answers in top docs...
Per question validation results len=8757
Validation results: top k documents hits [1869, 2525, 2971, 3256, 3477, 3677, 3825, 3957, 4074, 4176,
4281, 4400, 4486, 4559, 4622, 4684, 4750, 4799, 4854, 4898, 4941, 4976, 5020, 5057, 5087, 5128, 5163,
5192, 5236, 5270, 5287, 5305, 5325, 5346, 5382, 5407, 5420, 5437, 5456, 5471, 5485, 5505, 5522, 5542,
5560, 5582, 5598, 5612, 5626, 5647, 5664, 5675, 5696, 5706, 5716, 5732, 5746, 5764, 5784, 5798, 5813,
5826, 5837, 5849, 5861, 5873, 5882, 5890, 5896, 5909, 5915, 5924, 5939, 5952, 5963, 5973, 5981, 5989,
5997, 6010, 6020, 6032, 6043, 6054, 6063, 6070, 6076, 6082, 6088, 6096, 6102, 6105, 6108, 6111, 6118,
6122, 6132, 6140, 6152, 6156]
Validation results: top k documents hits accuracy [0.21342925659472423, 0.28834075596665526,
0.33927143999086445, 0.37181683224848694, 0.3970537855429942, 0.41989265730272923, 0.4367934224049332,
0.4518670777663583, 0.46522781774580335, 0.47687564234326824, 0.48886605001712913, 0.5024551787141716,
0.5122758935708576, 0.5206120817631609, 0.5278063263674775, 0.5348863766129953, 0.5424232042937079,
0.548018727874843, 0.5542994176087701, 0.5593239693959119, 0.5642343268242549, 0.5682311293822085,
0.5732556811693502, 0.5774808724449012, 0.5809067032088615, 0.5855886719196072, 0.5895854744775608,
0.5928971108827223, 0.5979216626698641, 0.6018042708690191, 0.6037455749685966, 0.6058010734269728,
0.6080849606029463, 0.6104830421377184, 0.6145940390544707, 0.6174488980244376, 0.6189334246888204,
0.6208747287883979, 0.6230444216055727, 0.6247573369875529, 0.6263560580107342, 0.6286399451867077,
0.6305812492862852, 0.6328651364622587, 0.6349206349206349, 0.6374329108142058, 0.6392600205549845,
0.640858741578166, 0.6424574626013475, 0.6448555441361197, 0.6467968482356972, 0.6480529861824826,
0.6504510677172548, 0.6515930113052415, 0.6527349548932283, 0.6545620646340071, 0.6561607856571885,
0.6582162841155647, 0.6605001712915382, 0.6620988923147196, 0.6638118076966998, 0.6652963343610826,
0.666552472307868, 0.6679228046134521, 0.6692931369190362, 0.6706634692246203, 0.6716912184538084,
0.6726047733241978, 0.6732899394769898, 0.6747744661413726, 0.6754596322941646, 0.6764873815233527,
0.6782002969053329, 0.6796848235697156, 0.6809409615165011, 0.6820829051044879, 0.6829964599748772,
0.6839100148452666, 0.684823569715656, 0.6863080963800389, 0.6874500399680256, 0.6888203722736097,
0.6900765102203951, 0.6913326481671805, 0.6923603973963686, 0.6931597579079594, 0.6938449240607514,
0.6945300902135435, 0.6952152563663355, 0.6961288112367249, 0.696813977389517, 0.697156560465913,
0.697499143542309, 0.697841726618705, 0.6986410871302957, 0.6990978645654905, 0.7002398081534772,
0.7011533630238667, 0.7025236953294507, 0.7029804727646455]
Saved results * scores  to checkpoint/nq/eval_dev_epoch_37_top_100.tsv
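
For reference, the accuracy list is just the hits list divided by the number of dev questions (8757 here); a quick check in Python:

    # First three entries of the hits list above, over 8757 dev questions
    hits = [1869, 2525, 2971]
    n_questions = 8757
    print([h / n_questions for h in hits])
    # [0.21342925659472423, 0.28834075596665526, 0.33927143999086445]

So top-1 accuracy is about 21.3% and top-100 about 70.3%, well below the Table 2 numbers this issue is about.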

Issue

[Updated] I changed the encoder's forward call as follows (newer versions of transformers return a ModelOutput object rather than a tuple, so the original unpacking fails):

## Original code

            sequence_output, pooled_output = super().forward(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )

## Modified code

            output = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            sequence_output = output.last_hidden_state
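
Note that my modified version drops token_type_ids and never sets pooled_output, which the original code also returned. A sketch of a change that keeps both (assuming transformers >= 4.x, where forward returns a ModelOutput) would be:

            outputs = super().forward(
                input_ids=input_ids,
                token_type_ids=token_type_ids,   # kept, unlike my version above
                attention_mask=attention_mask,
                return_dict=True,
            )
            sequence_output = outputs.last_hidden_state
            pooled_output = outputs.pooler_output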

Please let me know if I did anything wrong. Thank you!

vlad-karpukhin commented 3 years ago

Hi @robinsongh381 ,

"as d.rank -1 on device=cuda, n_gpu=8, world size=1"

According to the training step above, you are using just DataParallel mode on 8 GPUs instead of DistributedDataParallel mode on 8 GPUs, like:

"python -m torch.distributed.launch --nproc_per_node=8 train_dense_encoder.py ..."

Although in principle the two should be more or less equivalent given the same number of GPUs, we trained the biencoder in DDP mode only and are not aware of which DataParallel-mode parameters would be needed to get on-par results.

robinsongh381 commented 3 years ago

Thank you for the reply. I will re-train with DDP and see how it goes!

robinsongh381 commented 3 years ago

@vlad-karpukhin As you suggested, I changed the training config as follows:

 **************** CONFIGURATION ****************
adam_betas                     -->   (0.9, 0.999)
adam_eps                       -->   1e-08
batch_size                     -->   16
checkpoint_file_name           -->   dpr_biencoder
dev_batch_size                 -->   16
dev_file                       -->   data/data/retriever/nq-dev.json
device                         -->   cuda:0
distributed_world_size         -->   8
do_lower_case                  -->   True
dropout                        -->   0.1
encoder_model_type             -->   hf_bert
eval_per_epoch                 -->   1
fix_ctx_encoder                -->   False
fp16                           -->   False
fp16_opt_level                 -->   O1
global_loss_buf_sz             -->   150000
gradient_accumulation_steps    -->   1
hard_negatives                 -->   1
learning_rate                  -->   2e-05
local_rank                     -->   0
log_batch_step                 -->   100
max_grad_norm                  -->   2.0
model_file                     -->   None
n_gpu                          -->   1
no_cuda                        -->   False
num_train_epochs               -->   40.0
other_negatives                -->   0
output_dir                     -->   ./checkpoint/nq_best
pretrained_file                -->   None
pretrained_model_cfg           -->   bert-base-uncased
projection_dim                 -->   0
seed                           -->   12345
sequence_length                -->   256
shuffle_positive_ctx           -->   False
train_file                     -->   data/data/retriever/nq-train.json
train_files_upsample_rates     -->   None
train_rolling_loss_step        -->   500
val_av_rank_bsz                -->   128
val_av_rank_hard_neg           -->   30
val_av_rank_max_qs             -->   10000
val_av_rank_other_neg          -->   30
val_av_rank_start_epoch        -->   30
warmup_steps                   -->   1237
weight_decay                   -->   0.0
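
For reference, assuming the same 58,880 cleaned training samples as in Step 2 (it is the same nq-train.json), this config implies roughly the following schedule:

    samples = 58880
    per_gpu_batch = 16
    world_size = 8
    epochs = 40

    iters_per_epoch = samples // (per_gpu_batch * world_size)  # 460
    total_updates = iters_per_epoch * epochs                   # 18400
    warmup_fraction = 1237 / total_updates                     # ~0.067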

After training with the above config, I ran the same procedures as in Step 3 and Step 4, which gave:

Validation results: top k documents hits accuracy [0.4442160557268471, 0.5512161699212059,
0.6067146282973621, 0.6409729359369647, 0.666209889231472, 0.6868790681740322, 0.7012675573826653,
0.7134863537741235, 0.7247915952951924, 0.7346123101518784, 0.7432910814205779, 0.7501427429484984,
0.7563092383236268, 0.7604202352403792, 0.7648738152335275, 0.7680712572798903, 0.7728674203494348,
0.776293251113395, 0.7799474705949526, 0.7824597464885235, 0.7845152449468996, 0.7873701039168665,
0.7901107685280347, 0.7920520726276122, 0.7947927372387804, 0.7966198469795592, 0.7992463172319287,
0.8010734269727076, 0.803014731072285, 0.8044992577366679, 0.8062121731186479, 0.8080392828594267,
0.8096380038826082, 0.8107799474705949, 0.8126070572113737, 0.8135206120817632, 0.8144341669521525,
0.816032887975334, 0.8168322484869247, 0.8184309695101062, 0.8194587187392943, 0.8204864679684823,
0.8214000228388718, 0.8220851889916638, 0.8234555212972479, 0.82414068745004, 0.8245974648852347,
0.8250542423204293, 0.8263103802672148, 0.8266529633436108, 0.8279091012903963, 0.8282516843667923,
0.8285942674431883, 0.829393627954779, 0.8301929884663698, 0.8304213771839671, 0.8308781546191618,
0.8319059038483498, 0.832591070001142, 0.833276236153934, 0.83361881923033, 0.8340755966655248,
0.8347607628183168, 0.8353317346123101, 0.8361310951239008, 0.8364736782002969, 0.836816261276693,
0.8377298161470823, 0.8379582048646796, 0.8386433710174718, 0.8392143428114651, 0.8400137033230558,
0.8402420920406531, 0.8403562863994518, 0.8405846751170493, 0.8410414525522439, 0.8411556469110426,
0.84138403562864, 0.8419550074226333, 0.8425259792166268, 0.8428685622930228, 0.8437821171634121,
0.8444672833162041, 0.8446956720338016, 0.8450382551101976, 0.8458376156217883, 0.8460660043393856,
0.8464085874157816, 0.8467511704921777, 0.8475505310037684, 0.8475505310037684, 0.8482356971565604,
0.8491492520269499, 0.8493776407445472, 0.8499486125385406, 0.8501770012561379, 0.8504053899737353,
0.8505195843325339, 0.8507479730501313]

This seems much better than before! However, I still have a doubt about the result.

Meanwhile, I have another question, if you do not mind.

I was wondering why the two figures are not the same, given that it looks as if both were trained with (presumably) a batch size of 128 + 1 hard negative.

vlad-karpukhin commented 3 years ago

If you used batch_size=16 on 8 GPUs in DDP mode, that actually means you had an effective batch size of 16 * 8 * 2 (positive + hard negative) passages, which corresponds to the last line in our Table 3. Also, our suggested best hyper-parameter settings should actually give better IR results than those reported in the paper (we simply never updated the paper's original November numbers with the better results obtained in April).
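
In code form, that effective-batch arithmetic is:

    questions_per_gpu = 16
    gpus = 8
    passages_per_question = 2  # 1 positive + 1 hard negative

    questions_per_step = questions_per_gpu * gpus                    # 128
    passages_per_step = questions_per_step * passages_per_question   # 256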