[Open] theartpiece opened this issue 1 year ago
Hi, thank you for reporting! When I was evaluating `tart-full-flan-t5-xl`, I used the `eval_cross_task.py` script due to the slow inference of the `eval_beir.py` script. I will run the evaluation using the `eval_beir.py` script and try to reproduce the reported issue. Thank you so much!
Hi, I think the `eval_beir.py` script is a bit misleading, since it accepts a cross-encoder setting. But under the hood it encodes the query and the document separately and uses a dot product to compute scores (so it is not actually a cross-encoder!). That's why the same results are not reproduced with the `eval_beir.py` script.
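Concretely, the scoring that this bi-encoder path performs looks roughly like the sketch below (a minimal illustration using `facebook/contriever-msmarco` and the standard mean-pooling recipe; the query and passage strings are made-up examples, not from the repo). A true cross-encoder such as TART-full would instead feed the instruction, query, and passage jointly through the model and read off a relevance score, rather than taking a dot product of two independently computed embeddings.

```python
# Rough sketch of bi-encoder (dot-product) scoring; not the repo's actual code path.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
model = AutoModel.from_pretrained("facebook/contriever-msmarco")
model.eval()

def mean_pooling(token_embeddings, mask):
    # Average token embeddings over non-padding positions.
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
    return token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]

query = "What is considered a business expense on a business trip?"       # made-up example
passage = "Business expenses are the costs of carrying on a trade or business."  # made-up example

with torch.no_grad():
    q_inputs = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    p_inputs = tokenizer([passage], padding=True, truncation=True, return_tensors="pt")
    q_emb = mean_pooling(model(**q_inputs).last_hidden_state, q_inputs["attention_mask"])
    p_emb = mean_pooling(model(**p_inputs).last_hidden_state, p_inputs["attention_mask"])

# Bi-encoder score: dot product of two independent encodings.
# A cross-encoder would instead score the (instruction + query, passage) pair jointly.
score = (q_emb @ p_emb.T).item()
print(score)
```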
@AkariAsai I'm still not able to reproduce the results on many datasets. Even when I make suitable changes in the evaluation code (such as using `evaluate_model_ce()` instead of `evaluate_model()`), I am in fact getting the results of TART-DUAL, which is very strange.
Hi, so sorry for my slow response! I'd really appreciate it if you could give me some clarifications on the issue.
- Are you using the `eval_beir.py` script or `eval_cross_task.py`? I think the `eval_beir.py` script doesn't support `ce_model`. If you are using `eval_beir.py` with `ce_model`, the `ce_model` value may not be passed.
- Are you using `evaluate_model_ce()`? Regarding the `evaluate_model_ce` function, one tricky thing is that you have to pass the instruction as `ce_prompt`, not `prompt` (ref).
- Are you using the `eval_beir.py` script to load a cross-encoder model via `model_name_or_path`? To load and run a cross-encoder, you need to use a different function, `evaluate_model_ce()`. If you try to use a cross-encoder model as a bi-encoder, the model performance can get really low, as the model is not trained to independently encode passages and queries.

Hi @AkariAsai, this is the API call from `eval_beir.py`. Note that on the command line I pass the prompt via the `args.prompt` attribute, but it is not passed in the `prompt` argument of `evaluate_model_ce`; it goes into `ce_prompt` (as expected). I use Contriever for first-stage retrieval (filtering) and your full-flan-t5 model for reranking. I downloaded Contriever from your UW webpage and full-flan-t5 from the Hugging Face Hub.
metrics = src.beir_utils.evaluate_model_ce(
    query_encoder=query_encoder,
    doc_encoder=doc_encoder,
    tokenizer=tokenizer,
    dataset=args.dataset,
    batch_size=args.per_gpu_batch_size,
    norm_query=args.norm_query,
    norm_doc=args.norm_doc,
    is_main=src.dist_utils.is_main(),
    split="dev" if args.dataset == "msmarco" else "test",
    score_function=args.score_function,
    beir_dir=args.beir_dir,
    save_results_path=args.save_results_path,
    lower_case=args.lower_case,
    normalize_text=args.normalize_text,
    ce_model_path=args.ce_model_name_or_path,
    ce_prompt=args.prompt,
)
And this is the definition of `evaluate_model_ce`:
def evaluate_model_ce(
    query_encoder,
    doc_encoder,
    tokenizer,
    dataset,
    batch_size=128,
    add_special_tokens=True,
    norm_query=False,
    norm_doc=False,
    is_main=True,
    split="test",
    score_function="dot",
    beir_dir="BEIR/datasets",
    save_results_path=None,
    lower_case=False,
    normalize_text=False,
    prompt=None,
    ce_prompt=None,
    ce_model_path=None,
    load_retrieval_results=False,
):
    metrics = defaultdict(list)  # store final results

    if hasattr(query_encoder, "module"):
        query_encoder = query_encoder.module
    query_encoder.eval()

    if doc_encoder is not None:
        if hasattr(doc_encoder, "module"):
            doc_encoder = doc_encoder.module
        doc_encoder.eval()
    else:
        doc_encoder = query_encoder

    dmodel = DenseRetrievalExactSearch(
        DenseEncoderModel(
            query_encoder=query_encoder,
            doc_encoder=doc_encoder,
            tokenizer=tokenizer,
            add_special_tokens=add_special_tokens,
            norm_query=norm_query,
            norm_doc=norm_doc,
            lower_case=lower_case,
            normalize_text=normalize_text,
            prompt=prompt,
        ),
        batch_size=batch_size,
    )
    retriever = EvaluateRetrieval(dmodel, score_function=score_function)
    data_path = os.path.join(beir_dir, dataset)
    # cross_encoder_model = CrossEncoder(model_path='/checkpoint/akariasai/ranker/bert_base_st_ranker_manual_all_with_instructions_hard_negatives_instructions', num_labels=2)
    # reranker = Rerank('/checkpoint/akariasai/ranker/bert_base_st_ranker_manual_all_with_instructions_hard_negatives_instructions_instruction_unfollowing/checkpoint-50000/', batch_size=100)
    reranker = Rerank(ce_model_path, batch_size=100)

    if not os.path.isdir(data_path) and is_main:
        url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
        data_path = beir.util.download_and_unzip(url, beir_dir)
    dist_utils.barrier()

    if not dataset == "cqadupstack":
        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
        if load_retrieval_results is True:
            results = json.load(open("retriever_results_contriever_{}.json".format(dataset)))
        else:
            results = retriever.retrieve(corpus, queries)
            with open("retriever_results_contriever_{}.json".format(dataset), "w") as outfile:
                json.dump(results, outfile)
        print("start reranking")
        rerank_results = reranker.rerank(corpus, queries, results, top_k=100, prompt=ce_prompt)
        if is_main:
            ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, retriever.k_values)
            for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
                if isinstance(metric, str):
                    metric = retriever.evaluate_custom(qrels, rerank_results, retriever.k_values, metric=metric)
                for key, value in metric.items():
                    metrics[key].append(value)
            if save_results_path is not None:
                torch.save(rerank_results, f"{save_results_path}")
    elif dataset == "cqadupstack":  # compute macro-average over datasets
        paths = glob.glob(data_path)
        for path in paths:
            corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
            results = retriever.retrieve(corpus, queries)
            rerank_results = reranker.rerank(corpus, queries, results, top_k=100, prompt=ce_prompt)
            if is_main:
                ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, retriever.k_values)
                for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
                    if isinstance(metric, str):
                        metric = retriever.evaluate_custom(qrels, rerank_results, retriever.k_values, metric=metric)
                    for key, value in metric.items():
                        metrics[key].append(value)
        for key, value in metrics.items():
            assert (
                len(value) == 12
            ), f"cqadupstack includes 12 datasets, only {len(value)} values were computed for the {key} metric"

    metrics = {key: 100 * np.mean(value) for key, value in metrics.items()}

    return metrics
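Note that in this function the instruction only influences the second stage: `prompt` is handed to the first-stage `DenseEncoderModel`, while `ce_prompt` is forwarded to the reranker via `rerank(..., prompt=ce_prompt)`. Below is a minimal sketch of a call that passes the instruction where the cross-encoder will see it, assuming the same repo context as the `eval_beir.py` snippet above (`query_encoder`, `doc_encoder`, and `tokenizer` already loaded for `facebook/contriever-msmarco`; the dataset name is just an example):

```python
# Sketch only: assumes the TART/Contriever repo is on the path and that
# query_encoder, doc_encoder, and tokenizer are the already-loaded
# facebook/contriever-msmarco bi-encoder, exactly as in eval_beir.py above.
import src.beir_utils

instruction = "Find financial web article paragraph to answer"

metrics = src.beir_utils.evaluate_model_ce(
    query_encoder=query_encoder,
    doc_encoder=doc_encoder,
    tokenizer=tokenizer,
    dataset="fiqa",
    split="test",
    score_function="dot",
    beir_dir="BEIR/datasets",
    prompt=None,                 # first-stage dense retrieval runs without the instruction
    ce_prompt=instruction,       # the instruction reaches the TART-full reranker here
    ce_model_path="facebook/tart-full-flan-t5-xl",
)
print(metrics)  # e.g. NDCG@10, MAP@k, Recall@k, ... reported as percentages
```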
Hi @AkariAsai, were you able to find any error?
Hi @AkariAsai, following up. Kind regards
Hi @AkariAsai, following up. Kind regards
Hi @AkariAsai,
I really love the work. It's so neat and inspirational.
I'm running into reproducibility issues. I ran the following command:
`python eval_beir.py --dataset fiqa --output_dir eval_results/ --model_name_or_path facebook/contriever-msmarco --ce_model facebook/tart-full-flan-t5-xl --prompt "Find financial web article paragraph to answer"`
But I get NDCG@10 = 30.4, while the paper reports 32.9 (for Contriever-MS) and 41.8 (for TART-full FLAN-T5).
Can you guide me to where I'm going wrong?
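As a quick sanity check on the argument plumbing (the command above uses the `--ce_model` flag, while the call quoted earlier reads `args.ce_model_name_or_path` and `args.prompt`), a hypothetical debug print placed just before that call would confirm what actually reaches `evaluate_model_ce`:

```python
# Hypothetical debug snippet, not part of the repo: place immediately before the
# src.beir_utils.evaluate_model_ce(...) call in eval_beir.py.
# All attribute names below appear in the call quoted earlier in this thread.
print("dataset      :", args.dataset)
print("ce_model_path:", args.ce_model_name_or_path)  # expected: facebook/tart-full-flan-t5-xl
print("ce_prompt    :", args.prompt)                 # expected: the instruction string
```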