ArvinZhuang / DSI-QG

The official repository for "Bridging the Gap Between Indexing and Retrieval for Differentiable Search Index with Query Generation", Shengyao Zhuang, Houxing Ren, Linjun Shou, Jian Pei, Ming Gong, Guido Zuccon and Daxin Jiang.
MIT License

Question when reproducing experiment #10

Closed: gcalabria closed this issue 1 year ago

gcalabria commented 1 year ago

First of all, I would like to thank you for making the code for your work available, and I would also like to say that I really liked your paper. It is very interesting.

I am currently writing my master's thesis and I would like to use part of your code to build my own mathematical IR system. So the first thing I did was try to run your scripts and check whether my results match yours.

What I did was run the get_data.sh script and then the scripts for steps 2 and 3 of the README file. The model is still training, but I am skeptical about the results I have gotten so far.

Can you please confirm if the graphs below are in accordance with your results?

Thank you for your attention :)

[Three screenshots of training plots, 2023-05-10]

ArvinZhuang commented 1 year ago

Hi @g-lopes, thank you for your kind words! I'm glad I can help :)

It looks like your plots are converging much more slowly than mine. May I ask exactly which setting you are running? For example, the MS MARCO or the XOR QA dataset? Are you running DSI-QG or the original DSI? How many GPUs and what batch size are you using?

gcalabria commented 1 year ago

That was quick :)

First, I generated the queries using the following command:

python3 -m torch.distributed.launch --nproc_per_node=8 run.py \
        --task generation \
        --model_name castorini/doc2query-t5-large-msmarco \
        --per_device_eval_batch_size 32 \
        --run_name docTquery-XORQA-generation \
        --max_length 256 \
        --valid_file data/xorqa_data/100k/xorqa_corpus.tsv \
        --output_dir temp \
        --dataloader_num_workers 10 \
        --report_to wandb \
        --logging_steps 100 \
        --num_return_sequences 10
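After the generation step, one quick sanity check is to count the generated queries for each document before starting training. The sketch below assumes that the output file (the `xorqa_corpus.tsv.q10.docTquery` that is passed as `--train_file` in the next command) contains one JSON object per line, with a `text_id` (document id) and a `text` (generated query) field. That format is an assumption, so check it against your own output first. `queries_per_doc` is a hypothetical helper, not part of the repository.

```python
import json
from collections import Counter
from typing import Iterable


def queries_per_doc(lines: Iterable[str]) -> Counter:
    """Count generated queries per document id.

    Assumes (hypothetically) one JSON object per line with the keys
    "text_id" (document id) and "text" (generated query).
    """
    counts: Counter = Counter()
    for line in lines:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        item = json.loads(line)
        counts[item["text_id"]] += 1
    return counts
```

Usage: `with open("data/xorqa_data/100k/xorqa_corpus.tsv.q10.docTquery", encoding="utf-8") as f: counts = queries_per_doc(f)`. With `--num_return_sequences 10`, you would expect `set(counts.values())` to be `{10}`.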

After that, I started training the model running:

python3 -m torch.distributed.launch --nproc_per_node=8 run.py \
        --task "DSI" \
        --model_name "google/mt5-base" \
        --run_name "XORQA-100k-mt5-base-DSI-QG" \
        --max_length 32 \
        --train_file data/xorqa_data/100k/xorqa_corpus.tsv.q10.docTquery \
        --valid_file data/xorqa_data/100k/xorqa_DSI_dev_data.json \
        --output_dir "models/XORQA-100k-mt5-base-DSI-QG" \
        --learning_rate 0.0005 \
        --warmup_steps 100000 \
        --per_device_train_batch_size 32 \
        --per_device_eval_batch_size 32 \
        --evaluation_strategy steps \
        --eval_steps 1000 \
        --max_steps 500000 \
        --save_strategy steps \
        --dataloader_num_workers 10 \
        --save_steps 1000 \
        --save_total_limit 2 \
        --load_best_model_at_end \
        --gradient_accumulation_steps 1 \
        --report_to wandb \
        --logging_steps 100 \
        --dataloader_drop_last False \
        --metric_for_best_model Hits@10 \
        --greater_is_better True \
        --remove_prompt True

I am running DSI-QG on 8 Nvidia A100 GPUs with a per-device batch size of 32.
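When comparing runs, the number that usually matters is the effective (global) batch size. Under standard data-parallel training with the Hugging Face Trainer, this is the per-device batch size multiplied by the number of processes and by the gradient accumulation steps. The helper below is a hypothetical illustration of that arithmetic, using the flags from the training command above:

```python
def effective_batch_size(per_device: int, num_devices: int, grad_accum: int = 1) -> int:
    # Each optimizer step sees `per_device` examples on each of `num_devices`
    # GPUs, accumulated over `grad_accum` forward/backward passes.
    return per_device * num_devices * grad_accum


# --per_device_train_batch_size 32, --nproc_per_node=8, --gradient_accumulation_steps 1
print(effective_batch_size(32, 8, 1))  # 256
```

If you run on fewer GPUs, raising `--gradient_accumulation_steps` is one way to keep this number comparable.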

ArvinZhuang commented 1 year ago

Hi @g-lopes, judging from your scripts, I think that in step 2 you made the mistake of using castorini/doc2query-t5-large-msmarco, which is an English-only QG model, to generate cross-lingual queries for the XORQA dataset. You need to do step 1 to train a cross-lingual QG model for step 2. Otherwise, you may want to test with the MS MARCO dataset instead.

Alternatively, you can try our trained cross-lingual QG model, which I recently uploaded to Hugging Face: https://huggingface.co/ielabgroup/xor-tydi-docTquery-mt5-large. However, I used a slightly different prompt to train this model (see the model card on the page above). If you want to use this model directly in step 2, you need to change the prompt in this line to match the prompt in the model card.

gcalabria commented 1 year ago

😃 Thank you very much for your help. I will try to follow your instructions :) Cheers!

gcalabria commented 1 year ago

@ArvinZhuang I've double-checked the scripts I was using and achieved much better results.

[Screenshot of training plot, 2023-05-23]

As you can see, I've achieved a Hits@10 score of 0.8, and the training is not finished yet 😄 The results above were obtained on the MS MARCO dataset with 100k data points, with 10 queries per document generated by the castorini/doc2query-t5-large-msmarco model.

I am trying to convert my Hits@10 scores to absolute values so that I can compare my results with the ones in your paper. Can you please explain how the Hits scores in Table 1 were computed?

I am multiplying my score by the number of documents in the evaluation dataset. In my case, the evaluation dataset has 6980 of them and the score is roughly 0.8, so this results in a score of 5504, which makes no sense to me.

Thanks for your help one more time :)

ArvinZhuang commented 1 year ago

@g-lopes Hi, the code for computing hits scores is in this function https://github.com/ArvinZhuang/DSI-QG/blob/479d8d74f9b6c83de99a6723d769e23f1402a8ba/run.py#L43

If you want to run inference with saved model checkpoints, you can try something like this: https://github.com/ArvinZhuang/DSI-QG/issues/1#issuecomment-1343063412
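For readers who don't follow the link, the idea behind Hits@k is the fraction of evaluation queries whose correct docid appears among the model's top-k predicted docids. The sketch below is a simplified, hypothetical re-implementation of that idea in plain Python. It is not the repository's function, and the names `hits_at_k`, `ranked_docids` and `labels` were chosen only for illustration.

```python
from typing import Sequence


def hits_at_k(ranked_docids: Sequence[Sequence[str]], labels: Sequence[str], k: int) -> float:
    """Fraction of queries whose gold docid is among the first k predictions."""
    if len(ranked_docids) != len(labels):
        raise ValueError("need exactly one ranked list per query")
    if not labels:
        return 0.0
    hits = sum(1 for preds, gold in zip(ranked_docids, labels) if gold in preds[:k])
    # Dividing by the number of queries yields a value in [0, 1], not a count.
    return hits / len(labels)


preds = [["d3", "d1", "d7"], ["d2", "d5", "d9"], ["d4", "d8", "d6"]]
gold = ["d1", "d2", "d0"]
print(hits_at_k(preds, gold, k=1))  # 1 of 3 queries hit -> 0.333...
print(hits_at_k(preds, gold, k=3))  # 2 of 3 queries hit -> 0.666...
```

The division by the number of queries at the end is what turns a raw hit count into a score between 0 and 1.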

gcalabria commented 1 year ago

I saw this function. My idea of multiplying by the number of documents in the evaluation dataset came from line 61: https://github.com/ArvinZhuang/DSI-QG/blob/479d8d74f9b6c83de99a6723d769e23f1402a8ba/run.py#L61

Did the results from Table 1 of your paper come directly from this function? Or have you also multiplied by the size of the evaluation set?

Which test dataset did you use to generate these results?

ArvinZhuang commented 1 year ago

Ah, I see what you mean. Yes, my results come directly from this function. The numbers in the table are simply percentages, so you just need to multiply the numbers returned by the function by 100. @g-lopes So I think you have successfully reproduced my results :)
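To make the difference concrete: the function returns a fraction, and Table 1 reports that same fraction as a percentage. Multiplying by the size of the evaluation set produces something else entirely, namely the number of queries that were hit. The helpers below are hypothetical and use the rounded numbers from this thread (a Hits@10 of about 0.8 on 6980 evaluation examples). The exact count depends on the unrounded score.

```python
def as_table_percentage(score: float) -> float:
    # Table 1 numbers are the returned fraction expressed as a percentage.
    return round(score * 100, 2)


def num_hit_queries(score: float, num_eval: int) -> int:
    # Multiplying by the evaluation-set size recovers a count of queries,
    # which is not comparable to the percentages reported in the paper.
    return round(score * num_eval)


print(as_table_percentage(0.8))    # 80.0
print(num_hit_queries(0.8, 6980))  # 5584
```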

gcalabria commented 1 year ago

Aha! Great! Thank you very much :D