beir-cellar / beir

A Heterogeneous Benchmark for Information Retrieval. Easy to use: evaluate your models across 15+ diverse IR datasets.
http://beir.ai
Apache License 2.0

MonoT5.predict gives NaN #126

Open maxmatical opened 1 year ago

maxmatical commented 1 year ago

I'm trying to use MonoT5 3B on some custom reranking tasks; the gist of the code is:

from beir.reranking.models import MonoT5

model_name = "castorini/monot5-3b-msmarco-10k"
cross_encoder_model = MonoT5(model_name, token_false='▁false', token_true='▁true')

query, doc = "some query", "some doc"
cross_encoder_model.predict([(query, doc)], batch_size=1)

and it returns [nan].

Is this expected? Should we treat it as a score of 0 in this case?
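
In the meantime I'm guarding against it like this (my own stopgap, assuming a NaN really can be treated as 0):

import math

scores = cross_encoder_model.predict([(query, doc)], batch_size=1)
# Assumption: replacing NaN with 0.0 is acceptable for downstream ranking.
scores = [0.0 if math.isnan(s) else s for s in scores]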

rahmanidashti commented 1 year ago

Hi, have you found a solution for this?

soyoung97 commented 11 months ago

Hi, it's not exactly the same issue, but I've noticed similar behavior on the reranking task. For example, I'm using the following code:

def do_evaluation(queries, qrels, corpus, results=None):
    from beir.retrieval.evaluation import EvaluateRetrieval
    from beir.reranking.models import MonoT5
    from beir.reranking import Rerank

    k_values = [1, 5, 10, 20, 50, 100]
    retriever = EvaluateRetrieval()
    model_name = "castorini/monot5-3b-msmarco-10k"
    cross_encoder_model = MonoT5(model_name, token_false='▁false', token_true='▁true')
    print(f"Loading cross-encoder model from: {cross_encoder_model.model.config._name_or_path}")
    reranker = Rerank(cross_encoder_model, batch_size=256)
    results = reranker.rerank(corpus, queries, results, top_k=100) # outputs nan scores to results
    results = remove_nan(results) # manually assign scores due to the bug
    ndcg, _map, recall, precision = retriever.evaluate(qrels, results, k_values)
    return ndcg, _map, recall, precision
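
For reference, the shapes that rerank consumes here follow the standard beir format; a toy illustration (the ids and score are made up):

corpus = {"doc_a": {"title": "", "text": "some doc"}}  # doc_id -> {"title": ..., "text": ...}
queries = {"q1": "some query"}                         # query_id -> query text
results = {"q1": {"doc_a": 12.3}}                      # query_id -> {doc_id: first-stage score}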

It may be too late, but the order inside the results is preserved, so the following code can be used as a quick workaround:

def remove_nan(results):
    # rerank preserves the ranked order of documents for each query, so
    # replace the NaN scores with descending dummy scores (100, 99, ...)
    # that encode the same order for the evaluator.
    new_res = {}
    for query_key in results.keys():
        out = {}
        for i, corpus_key in enumerate(results[query_key].keys()):
            out[corpus_key] = 100 - i
        new_res[query_key] = out
    return new_res
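
For example, with a toy results dict where both scores are NaN, the preserved order becomes descending dummy scores:

toy = {"q1": {"doc_a": float("nan"), "doc_b": float("nan")}}
print(remove_nan(toy))
# -> {'q1': {'doc_a': 100, 'doc_b': 99}}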

Using this code, it outputs correct scores for nDCG, recall, and so on.