IBM / kgi-slot-filling

This is the code for our KILT leaderboard submissions (KGI + Re2G models).

Reproduce the results on the TriviaQA dataset #12

Open HansiZeng opened 9 months ago

HansiZeng commented 9 months ago

When reproducing Re2G on the TriviaQA dataset, I couldn't reproduce the results of the second-stage generation model. In the second stage, the generation model only uses the passages retrieved by the trained DPR, very similar to the KGI paper. I used the provided command for training:

python generation/kgi_train.py \
  --kilt_data ${dataset}_training \
  --output models/RAG/${dataset}_dpr_rag \
  --corpus_endpoint kilt_passages_${dataset} \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/${dataset}_dpr_rag_init \
  --warmup_fraction 0.05  --num_train_epochs 2 \
  --learning_rate 3e-5 --full_train_batch_size 128 --gradient_accumulation_steps 64
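
As a quick sanity check on the two batch flags (assuming a single GPU, and assuming kgi_train.py derives the per-pass batch this way, which I have not verified):

full_train_batch_size = 128       # examples contributing to each optimizer step
gradient_accumulation_steps = 64  # forward/backward passes per optimizer step
per_pass_batch = full_train_batch_size // gradient_accumulation_steps
print(per_pass_batch)             # -> 2 examples per forward pass
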
I got the following performance:

                        R-Prec   Recall@5   Accuracy   F1      KILT-AC   KILT-F1
My run                  57.82    62.13      37.51      57.35   26.14     38.52
KGI_0 (from the paper)  60.49    63.54      60.99      66.55   42.85     46.08

The retrieval metrics (R-Prec, Recall@5) are close to the KGI model's, but the generation metrics (Accuracy, F1, KILT-AC, KILT-F1) are far worse.
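
For context on why the combined scores drop so much: as far as I understand, the KILT-combined metrics only award a point when an instance's provenance is retrieved with R-Precision = 1 and its answer is also correct, so a drop in plain Accuracy drags KILT-AC down even when R-Prec and Recall@5 look fine. A minimal sketch of that relationship (not the official KILT scorer; the per-instance keys are hypothetical):

# Minimal sketch, not the official KILT scorer. Each instance is assumed to
# carry two hypothetical fields: 'r_precision' for its provenance and
# 'exact_match' (0/1) for its generated answer.
def kilt_accuracy(instances):
    # A point is awarded only when provenance is perfect AND the answer matches;
    # the denominator is still the full dataset.
    points = sum(
        inst["exact_match"]
        for inst in instances
        if inst["r_precision"] == 1.0
    )
    return 100.0 * points / len(instances)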

m-r-g commented 9 months ago

I think the DPR model is not as good. Here's the run command I used for DPR on TriviaQA:

python ${PYTHONPATH}/dpr/biencoder_trainer.py \
--train_dir /data/KILT/qa/${DS}/dpr_training_data \
--output_dir /data/KILT/qa/${DS}/models/dpr_e3 \
--num_train_epochs 3 \
--num_instances 89273 \
--encoder_gpu_train_limit 16 \
--max_grad_norm 1.0 --learning_rate 5e-5 \
--full_train_batch_size 128

I think the only difference from the defaults is training for 3 epochs.
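
In case it helps with debugging the second stage: kgi_train.py above loads models/RAG/${dataset}_dpr_rag_init, i.e. a RAG checkpoint whose question encoder has been swapped for the trained DPR question encoder. The repo has its own tooling for producing that init model; purely as an illustration of what that step amounts to (using the Hugging Face transformers API; the paths are placeholders, not this repo's layout), it is roughly:

from transformers import DPRQuestionEncoder, RagTokenForGeneration

# Illustration only, NOT the repo's initialization script; paths are placeholders.
rag = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
trained_qry_encoder = DPRQuestionEncoder.from_pretrained("models/dpr_e3/qry_encoder")

# Swap the pretrained NQ question encoder for the freshly trained DPR one,
# then save the result as the *_dpr_rag_init model passed via --model_path.
rag.rag.question_encoder = trained_qry_encoder
rag.save_pretrained("models/RAG/triviaqa_dpr_rag_init")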