IBM / kgi-slot-filling

This is the code for our KILT leaderboard submissions (KGI + Re2G models).
Apache License 2.0

Reproduce the results on KILT Wizard of Wikipedia #6

Open: heya5 opened this issue 2 years ago

heya5 commented 2 years ago
Hello, I got similar results on the development set after running kgi_train.py:

| R-Prec | Recall@5 | ROUGE-L | F1 | KILT-RL | KILT-F1 |
|---|---|---|---|---|---|
| 0.502947 | 0.690242 | 0.160686 | 0.182609 | 0.095659 | 0.108410 |
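R-Prec and Recall@5 are page-level retrieval metrics. The sketch below shows simplified versions for a single query. It assumes one gold provenance set of Wikipedia page IDs per query and a ranking already collapsed from passages to pages; the official KILT evaluation additionally handles multiple gold provenance sets and averages over queries, so treat the function names and simplifications as illustrative only.

```python
from typing import Sequence, Set


def r_precision(ranked_ids: Sequence[str], gold_ids: Set[str]) -> float:
    """Fraction of the top-R results that are gold, where R = number of gold pages."""
    r = len(gold_ids)
    if r == 0:
        return 0.0
    # Set intersection so a page repeated in the ranking is only counted once.
    hits = len(set(ranked_ids[:r]) & gold_ids)
    return hits / r


def recall_at_k(ranked_ids: Sequence[str], gold_ids: Set[str], k: int = 5) -> float:
    """Fraction of gold pages that appear anywhere in the top k results."""
    if not gold_ids:
        return 0.0
    found = set(ranked_ids[:k]) & gold_ids
    return len(found) / len(gold_ids)


if __name__ == "__main__":
    # Hypothetical ranking for one query whose single gold page is ranked second.
    ranking = ["Page_B", "Page_A", "Page_C", "Page_D", "Page_E"]
    gold = {"Page_A"}
    print(r_precision(ranking, gold))  # gold page is not at rank 1, so R-Prec is 0
    print(recall_at_k(ranking, gold))  # gold page is within the top 5, so recall is 1
```

This is why a reranker can move R-Prec much more than Recall@5: promoting a gold page from rank 2 to rank 1 changes R-Prec but not Recall@5.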

However, after running reranker_train.py and rerank_apply.py, I get results that seem worse than those in Table 2 of the Re2G paper:

| | R-Prec | Recall@5 |
|---|---|---|
| My experiments | 47.38 | 72.04 |
| Re2G paper | 55.50 | 74.98 |

I think Re2G is solid work. Could you please give me some advice on reproducing the results?

I ran the following command to train the reranker:

python reranker/reranker_train.py \
  --model_type bert --model_name_or_path nboost/pt-bert-base-uncased-msmarco --do_lower_case \
  --positive_pids ${dataset}/train_positive_pids.jsonl \
  --initial_retrieval  predictions/dpr_bm25/wow_train.jsonl  \
  --num_train_epochs 2 \
  --output_dir models/reranker_stage1

Fu-Dayuan commented 1 year ago

I have just solved this problem on TriviaQA. In the paper's setting, the reranker is trained for 1 epoch with 10% warmup, so the second-to-last line of the command should be:

  --num_train_epochs 1 --warmup_fraction 0.1 \

With that command I reproduced the TriviaQA results and got similar numbers (even higher than those in the paper). I hope this helps!
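As a rough illustration of what 1 epoch with a 0.1 warmup fraction means, here is a sketch of a linear warmup-then-decay learning-rate schedule. It assumes the training script converts the warmup fraction into warmup steps as a fraction of total optimizer steps and then decays linearly to zero, a common pattern (for example get_linear_schedule_with_warmup in Hugging Face transformers) that has not been checked against reranker_train.py. The helper names, dataset size, batch size, and learning rate below are hypothetical.

```python
import math
from typing import Tuple


def warmup_and_total_steps(num_instances: int, batch_size: int,
                           num_epochs: int, warmup_fraction: float) -> Tuple[int, int]:
    """Return (warmup_steps, total_steps) for a linear-warmup schedule (hypothetical helper)."""
    steps_per_epoch = math.ceil(num_instances / batch_size)
    total_steps = steps_per_epoch * num_epochs
    warmup_steps = int(warmup_fraction * total_steps)
    return warmup_steps, total_steps


def linear_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Ramp linearly from 0 to base_lr during warmup, then decay linearly back to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


if __name__ == "__main__":
    # Hypothetical setup: 80,000 training queries, batch size 32, base LR 3e-5.
    warm, total = warmup_and_total_steps(80_000, 32, num_epochs=1, warmup_fraction=0.1)
    print(f"{warm} warmup steps out of {total} total")
    for step in (0, warm // 2, warm, total // 2, total):
        print(step, linear_lr(step, 3e-5, warm, total))
```

Because the warmup is a fraction of total steps, changing the number of epochs also changes the absolute length of warmup and the slope of the decay, so both flags affect the schedule.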

HansiZeng commented 9 months ago

@Fu-Dayuan When you reproduced the TriviaQA results, do you remember the results of the RAG model in the second stage? (In this stage, the generation model uses only the passages retrieved by DPR, very similar to KGI.) The training command is:

python generation/kgi_train.py \
  --kilt_data ${dataset}_training \
  --output models/RAG/${dataset}_dpr_rag \
  --corpus_endpoint kilt_passages_${dataset} \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/${dataset}_dpr_rag_init \
  --warmup_fraction 0.05  --num_train_epochs 2 \
  --learning_rate 3e-5 --full_train_batch_size 128 --gradient_accumulation_steps 64
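The two batch flags in this command interact. Assuming --full_train_batch_size is the effective number of instances per optimizer step and the script divides it by the number of GPUs and by --gradient_accumulation_steps to get the per-GPU micro-batch (an assumption about kgi_train.py, not verified against its source), the arithmetic is sketched below. The function name is hypothetical.

```python
def per_gpu_batch_size(full_train_batch_size: int,
                       gradient_accumulation_steps: int,
                       n_gpu: int = 1) -> int:
    """Instances each GPU processes per forward/backward pass (hypothetical helper)."""
    denom = gradient_accumulation_steps * n_gpu
    if full_train_batch_size % denom != 0:
        raise ValueError(f"{full_train_batch_size} is not divisible by {denom}")
    return full_train_batch_size // denom


if __name__ == "__main__":
    # Values from the command above: 128 instances per optimizer step, 64 accumulation steps.
    for n_gpu in (1, 2):
        size = per_gpu_batch_size(128, 64, n_gpu)
        print(f"{n_gpu} GPU(s): micro-batch of {size}")
```

If a reproduction runs on a different number of GPUs, keeping micro-batch × GPUs × accumulation steps equal to 128 preserves the effective batch size.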
I got the following performance:

| R-Prec | Recall@5 | Accuracy | F1 | KILT-AC | KILT-F1 |
|---|---|---|---|---|---|
| .578 | .621 | .375 | .573 | .261 | .385 |

The retrieval metrics are close to KGI's, but the generation metrics are far worse than KGI's.
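The KILT-AC and KILT-F1 columns are stricter than Accuracy and F1: in the KILT benchmark, an instance's downstream score is counted only when its R-precision is 1 (the gold provenance was ranked at the top), and counts as 0 otherwise. The sketch below shows that aggregation with hypothetical per-instance scores; it simplifies KILT's handling of multiple provenance sets, and the class and function names are illustrative, not from the KILT codebase.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class InstanceScores:
    r_precision: float  # page-level R-Prec for this query
    f1: float           # token-level F1 of the generated answer


def kilt_average(instances: List[InstanceScores]) -> Tuple[float, float]:
    """Return (mean F1, mean KILT-F1); KILT-F1 keeps F1 only where R-Prec == 1."""
    n = len(instances)
    plain = sum(x.f1 for x in instances) / n
    gated = sum(x.f1 if x.r_precision == 1.0 else 0.0 for x in instances) / n
    return plain, gated


if __name__ == "__main__":
    # Hypothetical scores for four queries.
    data = [
        InstanceScores(1.0, 0.8),
        InstanceScores(0.0, 0.9),  # good answer, but provenance missed: 0 under KILT-F1
        InstanceScores(1.0, 0.4),
        InstanceScores(0.5, 0.6),
    ]
    f1, kilt_f1 = kilt_average(data)
    print(f"F1={f1:.3f} KILT-F1={kilt_f1:.3f}")
```

Under this gating, the KILT variants depend on both retrieval and generation, while Accuracy and F1 reflect generation alone.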