UKPLab / gpl

Powerful unsupervised domain adaptation method for dense retrieval. Requires only unlabeled corpus and yields massive improvement: "GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval" https://arxiv.org/abs/2112.07577
Apache License 2.0

TSDAE + GPL and TAS-B + GPL #7

Open Yuan0320 opened 2 years ago

Yuan0320 commented 2 years ago

@kwang2049 Hi, thanks for your amazing work!

I wonder whether the TSDAE + GPL mentioned in the paper refers to fine-tuning the distilbert-base PLM in the following order: (1) TSDAE on ${dataset} -> (2) GPL on ${dataset}?

Thx.

kwang2049 commented 2 years ago

Hi @Yuan0320, thanks for your attention!

Sorry that the description of this in the paper is a bit misleading. It is composed of three training stages: (1) TSDAE on ${dataset} -> (2) MarginMSE on MS MARCO -> (3) GPL on ${dataset}.

It can also be understood as the TSDAE baseline (which was also trained on MS MARCO) in Table 1 + GPL training.

Actually, omitting step (2) makes very little difference (cf. Table 4 in the paper); from my observation, the largest gap is on Robust04 (around -1.0 nDCG@10 points).

Yuan0320 commented 2 years ago

@kwang2049 Thanks for the detailed response!

Is there any public code in this repo for (1) TSDAE on ${dataset} -> (2) MarginMSE on MS MARCO?

Thx.

kwang2049 commented 2 years ago

Sorry that there is currently no one-step solution for this. To reproduce it, please run these steps one by one:

(1) Train a TSDAE model on the target corpus: https://github.com/UKPLab/sentence-transformers/tree/master/examples/unsupervised_learning/TSDAE. Note that if one wants to start from distilbert-base-uncased (i.e. the setting in the GPL paper), one needs a plug-in, since sadly the original DistilBERT does not support being used as a decoder: https://github.com/UKPLab/sentence-transformers/issues/962#issuecomment-991603084. (A minimal sketch of this step is included at the end of this comment.)

(2) Continue training with the MarginMSE loss. This can be done with our GPL repo:

wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/msmarco.zip
unzip msmarco.zip
export dataset="msmarco"
python -m gpl.train \
    --path_to_generated_data "./$dataset" \
    --base_ckpt "YOUR_TSDAE_CHECKPOINT" \
    --gpl_score_function "dot" \
    --batch_size_gpl 32 \
    --gpl_steps 140000 \
    --output_dir "output/$dataset" \
    --evaluation_data "./$dataset" \
    --evaluation_output "evaluation/$dataset" \
    --generator "BeIR/query-gen-msmarco-t5-base-v1" \
    --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
    --retriever_score_functions "cos_sim" "cos_sim" \
    --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
    --use_train_qrels
    # --use_amp   # Use this for efficient training if the machine supports AMP

# One can run `python -m gpl.train --help` for information on all the arguments
# Notice: Please do not set `qgen_prefix` here; leave it as None (the default).

So basically, it will skip the query-generation step (because of `--use_train_qrels`) and use the existing train qrels from MS MARCO.
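
For step (1), here is a minimal sketch adapted from the linked sentence-transformers TSDAE example (not an official script of this repo); the corpus loading and hyperparameters are illustrative, and remember that distilbert-base-uncased needs the decoder plug-in from the linked issue:

# Minimal TSDAE sketch (step 1), adapted from the sentence-transformers TSDAE example.
# Paths and hyperparameters are illustrative, not the exact paper settings.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, models, datasets, losses

model_name = "distilbert-base-uncased"  # needs the DistilBERT-as-decoder plug-in from the linked issue

# Build a SentenceTransformer with CLS pooling, as in the TSDAE example
word_embedding_model = models.Transformer(model_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Unlabeled sentences/passages from the target corpus (e.g. the BeIR corpus.jsonl)
train_sentences = ["..."]  # fill with your target-domain texts

# The dataset adds noise (token deletion) on the fly; the loss reconstructs the original input
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    weight_decay=0,
    scheduler="constantlr",
    optimizer_params={"lr": 3e-5},
    show_progress_bar=True,
)
model.save("output/tsdae-target-corpus")  # pass this path as --base_ckpt in step (2)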

Yuan0320 commented 2 years ago

@kwang2049 Thanks for the detailed pipeline for TSDAE + GPL!

BTW, if you have a backup of the gpl-training-data.tsv file for MarginMSE on MS MARCO, could you share it with me if possible?

Thanks.

kwang2049 commented 2 years ago

You can use this file https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz instead. The format is different from our gpl-training-data.tsv, but thankfully @nreimers has wrapped the code for training this zero-shot baseline in a single file: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder_margin-mse.py. I also used similar code at the very beginning of the GPL project.
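
To illustrate what that script boils down to, here is a rough MarginMSE training sketch with sentence-transformers (not the exact linked script; deriving the (query, positive, negative, margin) tuples from the hard-negatives file and the cross-encoder scores is omitted, and your_training_tuples is a hypothetical placeholder):

# Rough MarginMSE sketch (step 2), assuming (query, positive, negative, margin) tuples
# have already been built from the hard-negatives file + cross-encoder scores,
# as done in the linked train_bi-encoder_margin-mse.py.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("YOUR_TSDAE_CHECKPOINT")  # output of step (1)

# margin = CE(query, positive) - CE(query, negative), from the MS MARCO cross-encoder scores
train_samples = [
    InputExample(texts=[query, positive, negative], label=margin)
    for query, positive, negative, margin in your_training_tuples  # hypothetical placeholder
]
train_dataloader = DataLoader(train_samples, batch_size=32, shuffle=True, drop_last=True)
train_loss = losses.MarginMSELoss(model)  # MSE between bi-encoder margin and cross-encoder margin

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=1000,
    show_progress_bar=True,
)
model.save("output/tsdae-marginmse-msmarco")  # pass this to GPL as --base_ckpt in step (3)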

Yuan0320 commented 2 years ago

@kwang2049 Thanks!

Regarding the best method GPL/${dataset}-distilbert-tas-b-gpl-self_miner, is the following script correct? (Compared with the usage in your repo README.md, I only changed base_ckpt and retrievers to 'msmarco-distilbert-base-tas-b'.)

export dataset="MYDATASET"
python -m gpl.train \
    --path_to_generated_data "./$dataset" \
    --base_ckpt "msmarco-distilbert-base-tas-b" \
    --gpl_score_function "dot" \
    --batch_size_gpl 32 \
    --gpl_steps 140000 \
    --new_size -1 \
    --queries_per_passage -1 \
    --output_dir "output/$dataset" \
    --evaluation_data "./$dataset" \
    --evaluation_output "evaluation/$dataset" \
    --generator "BeIR/query-gen-msmarco-t5-base-v1" \
    --retrievers "msmarco-distilbert-base-tas-b"  \
    --retriever_score_functions "cos_sim" "cos_sim" \
    --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
    --qgen_prefix "qgen" \
    --do_evaluation

kwang2049 commented 2 years ago

I think you just need to change --retriever_score_functions "cos_sim" "cos_sim" to --retriever_score_functions "dot", since you only have one negative miner and the miner is the TAS-B model, which was trained with dot product. The other parts look good to me.
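
For completeness, the same corrected configuration via the Python entry point would look roughly like this (assuming gpl.train accepts keyword arguments mirroring the CLI flags, as in the repo README; MYDATASET is a placeholder as in the shell script above):

# Rough sketch of the TAS-B self-miner setting via the Python API; argument names
# are assumed to mirror the CLI flags shown above.
import gpl

dataset = "MYDATASET"  # placeholder
gpl.train(
    path_to_generated_data=f"./{dataset}",
    base_ckpt="msmarco-distilbert-base-tas-b",
    gpl_score_function="dot",
    batch_size_gpl=32,
    gpl_steps=140000,
    new_size=-1,
    queries_per_passage=-1,
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-tas-b"],  # single negative miner (self-mining)
    retriever_score_functions=["dot"],             # TAS-B was trained with dot product
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    do_evaluation=True,
)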

ArtemisDicoTiar commented 1 year ago

Hello, @kwang2049! I appreciate your lovely work. While reading the GPL paper, I was quite confused about what TSDAE and TSDAE + {something} mean in Table 9 of the paper, so I searched for TSDAE in this repo and found this issue. Even after reading this thread, I am still a bit confused about TSDAE itself. Is the 'TSDAE' mentioned in Table 9 trained with the following pipeline: (1) TSDAE on ${dataset} -> (2) MarginMSE on MSMARCO, so that the TSDAE in Table 9 corresponds to Table 1's TSDAE (target → MS-MARCO)?

kwang2049 commented 1 year ago

Hi @ArtemisDicoTiar, your understanding is correct. "TSDAE" in Tables 1 and 9 refers to the same method, i.e. TSDAE (target → MS-MARCO). "Target" here means a certain dataset from the target domain (and we just trained with TSDAE on the corresponding unlabeled corpus).

GuodongFan commented 1 year ago

In the paper, I found that TSDAE is used to train the retrievers (domain adaptation for dense retrieval). Should the TSDAE checkpoint be used for --base_ckpt, for --retrievers, or can it be both? Thanks!