facebookresearch / tart

Code and model release for the paper "Task-aware Retrieval with Instructions" by Asai et al.

Unable to reproduce BEIR results #3

Open theartpiece opened 1 year ago

theartpiece commented 1 year ago

Hi @AkariAsai

I really love the work. It's so neat and inspirational.

I'm running into reproducibility issues.

I ran the following command:

    python eval_beir.py --dataset fiqa --output_dir eval_results/ --model_name_or_path facebook/contriever-msmarco --ce_model facebook/tart-full-flan-t5-xl --prompt "Find financial web article paragraph to answer"

But I get NDCG@10 = 30.4, while the paper reports 32.9 (Contriever-MS) and 41.8 (TART-full FLAN-T5).

Can you guide me to where I'm going wrong?

AkariAsai commented 1 year ago

Hi, thank you for reporting! When I evaluated tart-full-flan-t5-xl, I used the eval_cross_task.py script because inference with the eval_beir.py script is slow. I will run the evaluation with the eval_beir.py script and try to reproduce the reported issue. Thank you so much!
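
In the meantime, one way to sanity-check the cross-encoder itself is to score a few (instruction, query, passage) pairs with tart-full-flan-t5-xl directly. A minimal sketch, assuming the EncT5 wrapper classes in this repo's src/ directory and the "{instruction} [SEP] {query}" input format from the README; module/class names and example texts are placeholders, not the exact evaluation-script code:

    # Sketch: joint (cross-encoder) scoring of instruction+query against passages.
    import torch
    import torch.nn.functional as F
    from src.modeling_enc_t5 import EncT5ForSequenceClassification
    from src.tokenization_enc_t5 import EncT5Tokenizer

    model_name = "facebook/tart-full-flan-t5-xl"
    tokenizer = EncT5Tokenizer.from_pretrained(model_name)
    model = EncT5ForSequenceClassification.from_pretrained(model_name)
    model.eval()

    instruction = "Find financial web article paragraph to answer"
    query = "What counts as a business expense on a business trip?"
    passages = [
        "Travel costs such as airfare and lodging incurred for business purposes are generally deductible.",
        "The weather in Paris is mild in spring, with occasional rain showers.",
    ]

    # The instruction is concatenated with the query, and each passage is scored
    # jointly with that string (true cross-encoder scoring).
    features = tokenizer(
        [f"{instruction} [SEP] {query}"] * len(passages),
        passages,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**features).logits
        scores = F.softmax(logits, dim=1)[:, 1].tolist()  # P(relevant) per passage
    print(scores)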

DaehanKim commented 1 year ago

Hi, I think the eval_beir.py script is a bit misleading, since it allows a cross-encoder setting, but under the hood it encodes the query and document separately and uses a dot product to compute scores (not exactly a cross-encoder!). That's why the same results are not reproduced with the eval_beir.py script.
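
To make the distinction concrete, here is a minimal sketch of what the dense (bi-encoder) path computes, using facebook/contriever-msmarco with mean pooling as described in the Contriever model card; it is illustrative, not the exact code in eval_beir.py:

    # Bi-encoder scoring: queries and documents are encoded independently and
    # compared with a dot product. There is no joint query-document encoding
    # anywhere, which is why this behaves like TART-dual rather than TART-full.
    import torch
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
    encoder = AutoModel.from_pretrained("facebook/contriever-msmarco")
    encoder.eval()

    def mean_pool(last_hidden, attention_mask):
        # Average token embeddings, ignoring padding positions.
        mask = attention_mask.unsqueeze(-1).float()
        return (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def embed(texts):
        batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            out = encoder(**batch)
        return mean_pool(out.last_hidden_state, batch["attention_mask"])

    query_emb = embed(["What counts as a business expense on a business trip?"])
    doc_embs = embed([
        "Travel costs such as airfare and lodging incurred for business purposes are generally deductible.",
        "The weather in Paris is mild in spring, with occasional rain showers.",
    ])
    scores = query_emb @ doc_embs.T  # dot-product relevance scores
    print(scores)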

theartpiece commented 1 year ago

@AkariAsai I'm still not able to reproduce the results on many datasets. Even when I make suitable changes in the evaluation code (such as using evaluate_model_ce(.) instead of evaluate_model(.)), I in fact get the results of TART-dual, which is very strange.

AkariAsai commented 1 year ago

Hi, so sorry for my slow response! I'd really appreciate it if you could clarify a few things about the issue.

  1. @theartpiece Are you using the eval_beir.py script or eval_cross_task.py? I think the eval_beir.py script doesn't support ce_model, so if you are using eval_beir.py with ce_model, the ce_model value may not be passed.
  2. @theartpiece Could you share the code snippet showing how you call evaluate_model_ce()? One tricky thing about evaluate_model_ce is that you have to pass the instruction as ce_prompt, not prompt (ref); see the sketch after this list.
  3. @DaehanKim Are you using the eval_beir.py script to load a cross-encoder model via model_name_or_path? To load and run a cross-encoder, you need to use a different function, evaluate_model_ce(). If you use a cross-encoder model as a bi-encoder, performance can get really low, as the model is not trained to encode passages and queries independently.
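
For point 2, a minimal sketch of the intended call (encoder/tokenizer setup elided, argument values are placeholders; the key part is where the instruction goes):

    # Placeholder sketch: the instruction goes to ce_prompt (used by the
    # cross-encoder reranker), not to prompt (used only for query encoding).
    metrics = src.beir_utils.evaluate_model_ce(
        query_encoder=query_encoder,   # bi-encoder used for first-stage retrieval
        doc_encoder=doc_encoder,
        tokenizer=tokenizer,
        dataset="fiqa",
        batch_size=64,
        split="test",
        score_function="dot",
        beir_dir="BEIR/datasets",
        ce_model_path="facebook/tart-full-flan-t5-xl",
        ce_prompt="Find financial web article paragraph to answer",  # instruction here
    )
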
theartpiece commented 1 year ago

Hi @AkariAsai, this is the API call from eval_beir.py. Note that on the command line I pass the prompt via args.prompt, but it is not passed to the prompt argument of evaluate_model_ce; it goes to ce_prompt (as expected). I use Contriever for first-stage retrieval and your full-flan-t5 for reranking. I downloaded Contriever from your UW webpage and full-flan-t5 from the Hugging Face Hub.

            metrics = src.beir_utils.evaluate_model_ce(
                query_encoder=query_encoder,
                doc_encoder=doc_encoder,
                tokenizer=tokenizer,
                dataset=args.dataset,
                batch_size=args.per_gpu_batch_size,
                norm_query=args.norm_query,
                norm_doc=args.norm_doc,
                is_main=src.dist_utils.is_main(),
                split="dev" if args.dataset == "msmarco" else "test",
                score_function=args.score_function,
                beir_dir=args.beir_dir,
                save_results_path=args.save_results_path,
                lower_case=args.lower_case,
                normalize_text=args.normalize_text,
                ce_model_path=args.ce_model_name_or_path,
                ce_prompt=args.prompt
            )

and this is the definition of evaluate_model_ce:

def evaluate_model_ce(
    query_encoder,
    doc_encoder,
    tokenizer,
    dataset,
    batch_size=128,
    add_special_tokens=True,
    norm_query=False,
    norm_doc=False,
    is_main=True,
    split="test",
    score_function="dot",
    beir_dir="BEIR/datasets",
    save_results_path=None,
    lower_case=False,
    normalize_text=False,
    prompt=None,
    ce_prompt=None,
    ce_model_path=None,
    load_retrieval_results=False
):

    metrics = defaultdict(list)  # store final results

    if hasattr(query_encoder, "module"):
        query_encoder = query_encoder.module
    query_encoder.eval()

    if doc_encoder is not None:
        if hasattr(doc_encoder, "module"):
            doc_encoder = doc_encoder.module
        doc_encoder.eval()
    else:
        doc_encoder = query_encoder

    dmodel = DenseRetrievalExactSearch(
        DenseEncoderModel(
            query_encoder=query_encoder,
            doc_encoder=doc_encoder,
            tokenizer=tokenizer,
            add_special_tokens=add_special_tokens,
            norm_query=norm_query,
            norm_doc=norm_doc,
            lower_case=lower_case,
            normalize_text=normalize_text,
            prompt=prompt,
        ),
        batch_size=batch_size,
    )
    retriever = EvaluateRetrieval(dmodel, score_function=score_function)
    data_path = os.path.join(beir_dir, dataset)
    # cross_encoder_model = CrossEncoder(model_path='/checkpoint/akariasai/ranker/bert_base_st_ranker_manual_all_with_instructions_hard_negatives_instructions', num_labels=2)
    # reranker = Rerank('/checkpoint/akariasai/ranker/bert_base_st_ranker_manual_all_with_instructions_hard_negatives_instructions_instruction_unfollowing/checkpoint-50000/', batch_size=100)
    reranker = Rerank(ce_model_path, batch_size=100)

    if not os.path.isdir(data_path) and is_main:
        url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
        data_path = beir.util.download_and_unzip(url, beir_dir)
    dist_utils.barrier()

    if not dataset == "cqadupstack":
        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
        if load_retrieval_results is True:
            results = json.load(open("retriever_results_contriever_{}.json".format(dataset)))
        else:
            results = retriever.retrieve(corpus, queries)
            with open("retriever_results_contriever_{}.json".format(dataset), "w") as outfile:
                json.dump(results, outfile)

        print("start reranking")
        rerank_results = reranker.rerank(corpus, queries, results, top_k=100, prompt=ce_prompt)

        if is_main:
            ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, retriever.k_values)
            for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
                if isinstance(metric, str):
                    metric = retriever.evaluate_custom(qrels, rerank_results, retriever.k_values, metric=metric)
                for key, value in metric.items():
                    metrics[key].append(value)
            if save_results_path is not None:
                torch.save(rerank_results, f"{save_results_path}")

    elif dataset == "cqadupstack":  # compute macroaverage over datasetds
        paths = glob.glob(os.path.join(data_path, "*"))  # one sub-directory per cqadupstack forum
        for path in paths:
            corpus, queries, qrels = GenericDataLoader(data_folder=path).load(split=split)
            results = retriever.retrieve(corpus, queries)
            rerank_results = reranker.rerank(corpus, queries, results, top_k=100, prompt=ce_prompt)
            if is_main:
                ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, retriever.k_values)
                for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
                    if isinstance(metric, str):
                        metric = retriever.evaluate_custom(qrels, rerank_results, retriever.k_values, metric=metric)
                    for key, value in metric.items():
                        metrics[key].append(value)
        for key, value in metrics.items():
            assert (
                len(value) == 12
            ), f"cqadupstack includes 12 datasets, only {len(value)} values were compute for the {key} metric"

    metrics = {key: 100 * np.mean(value) for key, value in metrics.items()}

    return metrics
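
One way to check whether the reranker is actually changing the ranking: evaluate the first-stage results that the function dumps to retriever_results_contriever_{dataset}.json on their own and compare them with the reranked metrics it returns. If NDCG@10 is identical, the cross-encoder is effectively a no-op, which would explain TART-dual-like numbers. A rough sketch (paths are illustrative):

    # Compare first-stage (bi-encoder) metrics against the reranked metrics
    # returned by evaluate_model_ce; identical numbers mean the reranker did
    # not alter the ranking.
    import json
    from beir.datasets.data_loader import GenericDataLoader
    from beir.retrieval.evaluation import EvaluateRetrieval

    dataset = "fiqa"
    corpus, queries, qrels = GenericDataLoader(
        data_folder=f"BEIR/datasets/{dataset}"
    ).load(split="test")

    first_stage = json.load(open(f"retriever_results_contriever_{dataset}.json"))
    evaluator = EvaluateRetrieval()
    ndcg, _map, recall, precision = evaluator.evaluate(qrels, first_stage, [10, 100])
    print("first-stage NDCG:", ndcg)  # compare with the reranked NDCG@10
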
theartpiece commented 1 year ago

Hi @AkariAsai, were you able to find any errors?

theartpiece commented 1 year ago

Hi @AkariAsai, following up. Kind regards
