Open YiweiJiang2015 opened 2 years ago
Hi Yiwei, thank you for the question!
As indicated in the README, we used Facebook DPR, specifically the commit
git checkout 49e5838f94ffced8392be750ded2a8fa4a14b5cf
with the default configuration, including an effective batch size of 128. Please refer to their official scripts for more details.
We provide a script for creating the positive and negative examples for MultiDoc2Dial.
Please let us know if you have any questions. Thanks!
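For readers unfamiliar with what "positive and negative examples" look like on the DPR side, here is a minimal sketch of one training record. The field names (question, answers, positive_ctxs, negative_ctxs, hard_negative_ctxs) follow my recollection of DPR's bi-encoder JSON format and should be checked against the commit above; the helper make_dpr_example and the sample texts are hypothetical, and the MultiDoc2Dial script may emit additional fields.

```python
import json


def make_dpr_example(question, positive, hard_negatives):
    """Build one record in the JSON layout DPR's bi-encoder training is assumed to read."""
    def ctx(passage):
        # Each context is a passage with a title and its text.
        return {"title": passage["title"], "text": passage["text"]}

    return {
        "question": question,
        "answers": [],  # not needed for retrieval-only training (assumption)
        "positive_ctxs": [ctx(positive)],
        "negative_ctxs": [],  # matches other_negatives: 0 in the shared config
        "hard_negative_ctxs": [ctx(n) for n in hard_negatives],
    }


# Hypothetical sample data, for illustration only.
example = make_dpr_example(
    question="How do I renew my license?",
    positive={"title": "Renewals", "text": "You can renew online or by mail."},
    hard_negatives=[{"title": "Fees", "text": "The replacement fee is listed below."}],
)

# DPR training files are assumed to be a JSON list of such records.
print(json.dumps([example], indent=2))
```

With hard_negatives: 1 in the training config, only one hard negative per question would be sampled even if the file lists more.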
Here are some specs for the DPR bi-encoder training:
batch_size: 128
dev_batch_size: 64
adam_eps: 1e-8
adam_betas: (0.9, 0.999)
max_grad_norm: 2.0
log_batch_step: 1
train_rolling_loss_step: 20
weight_decay: 0.0
learning_rate: 2e-5
# Linear warmup over warmup_steps.
warmup_steps: 200
# Number of update steps to accumulate before performing a backward/update pass.
gradient_accumulation_steps: 1
# Total number of training epochs to perform.
num_train_epochs: 50
eval_per_epoch: 1
hard_negatives: 1
other_negatives: 0
val_av_rank_hard_neg: 30
val_av_rank_other_neg: 30
val_av_rank_bsz: 128
val_av_rank_max_qs: 10000
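To make the learning_rate and warmup_steps settings above concrete, here is a minimal sketch of a linear-warmup schedule. It assumes the rate decays linearly to zero after warmup, which is my reading of DPR's scheduler and should be verified against the checked-out commit. The function name lr_at_step and the total_steps value are hypothetical; the real total depends on the training set size, batch size, and num_train_epochs.

```python
def lr_at_step(step, total_steps, base_lr=2e-5, warmup_steps=200):
    """Learning rate at a given update step under linear warmup, then linear decay (assumed)."""
    if step < warmup_steps:
        # Ramp up from 0 to base_lr over the first warmup_steps updates.
        scale = step / max(1, warmup_steps)
    else:
        # Decay linearly from base_lr to 0 over the remaining updates.
        scale = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return base_lr * scale


# total_steps=1000 is an illustrative value, not the real training length.
for step in (0, 100, 200, 600, 1000):
    print(step, lr_at_step(step, total_steps=1000))
```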
Thanks for sharing the parameters. Now I can reproduce the results in Table 5.
Is it possible to reproduce the results with the sivasankalpp/dpr-multidoc2dial-structure-question and sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder checkpoints?
Using the following code leads to significantly lower results:
tokenizer_contex = AutoTokenizer.from_pretrained("sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder")
tokenizer_question = AutoTokenizer.from_pretrained("sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder")
model_context = DPRContextEncoder.from_pretrained("sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder")
model_question = AutoModel.from_pretrained("sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder")
input_ids = tokenizer_question(inp_q, return_tensors='pt', truncation=True)["input_ids"]
query_emb = model_question(input_ids).pooler_output
input_ids = tokenizer_contex(inp_p, truncation=True, return_tensors='pt', )["input_ids"]
passage_emb = model_context(input_ids).pooler_output
torch.dot(query_emb, passage_emb)
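Two things stand out in the snippet above: the question tokenizer and question model are both loaded from the ctx-encoder checkpoint, and torch.dot expects 1-D tensors while pooler_output carries a batch dimension. Below is a minimal sketch that loads each side from its own checkpoint and scores a single pair. It assumes the checkpoints load with transformers' DPRQuestionEncoder and DPRContextEncoder classes; embed_pair and dot_score are hypothetical helper names, the checkpoint names are passed in rather than hard-coded, and I have not verified that this reproduces the reported numbers.

```python
from typing import List, Sequence, Tuple


def dot_score(query_vec: Sequence[float], passage_vec: Sequence[float]) -> float:
    """Inner-product relevance score between two 1-D embeddings of equal length."""
    if len(query_vec) != len(passage_vec):
        raise ValueError("embeddings must have the same dimension")
    return sum(q * p for q, p in zip(query_vec, passage_vec))


def embed_pair(question_ckpt: str, ctx_ckpt: str,
               question: str, passage: str) -> Tuple[List[float], List[float]]:
    """Encode one question and one passage with separate DPR encoders."""
    # Imported here so dot_score stays usable without torch/transformers installed.
    import torch
    from transformers import AutoTokenizer, DPRContextEncoder, DPRQuestionEncoder

    # Each side uses its own tokenizer and its own encoder checkpoint.
    q_tok = AutoTokenizer.from_pretrained(question_ckpt)
    c_tok = AutoTokenizer.from_pretrained(ctx_ckpt)
    q_model = DPRQuestionEncoder.from_pretrained(question_ckpt).eval()
    c_model = DPRContextEncoder.from_pretrained(ctx_ckpt).eval()

    with torch.no_grad():
        q_inputs = q_tok(question, return_tensors="pt", truncation=True)
        c_inputs = c_tok(passage, return_tensors="pt", truncation=True)
        # pooler_output has shape (batch, hidden); [0] drops the batch dimension.
        q_emb = q_model(**q_inputs).pooler_output[0]
        p_emb = c_model(**c_inputs).pooler_output[0]
    return q_emb.tolist(), p_emb.tolist()
```

Usage would look like q, p = embed_pair(question_checkpoint, context_checkpoint, inp_q, inp_p) followed by dot_score(q, p), with the two checkpoint names taken from the question above. The input formatting used at training time (for example, how the dialogue history is flattened into the query) also has to match, which this sketch does not cover.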
Hi,
This is a really nice dataset. I am wondering how to reproduce the results in Table 5 of your paper, i.e. the retrieval results on the validation set. I searched your codebase but found no script for that (or did I miss something?).
Thanks for any help.
Yiwei