UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

InformationRetrievalEvaluator finding the absolute rank of first relevant_docs found or calculating mrr per group with evaluator #2814

Open · maayansharon10 opened this issue 2 months ago

maayansharon10 commented 2 months ago

Hello,

I'm using the all-mpnet-base-v2 model with InformationRetrievalEvaluator. My queries and relevant docs are associated with groups, and my objective is to calculate mrr@k and recall@k per group.
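
To illustrate my setup, here is a toy sketch (the data and the group mapping are just stand-ins; the evaluator itself has no notion of groups):

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator

model = SentenceTransformer("all-mpnet-base-v2")

# Toy stand-ins for my real queries / corpus / relevance judgements
queries = {"q1": "query from group a", "q2": "query from group b"}
corpus = {"d1": "document relevant to q1", "d2": "document relevant to q2"}
relevant_docs = {"q1": {"d1"}, "q2": {"d2"}}
# Each query belongs to a group; this mapping is my own bookkeeping,
# the evaluator does not accept it as an argument
query_groups = {"q1": "group_a", "q2": "group_b"}

evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="toy-example",
)
results = evaluator(model)  # aggregate recall@k / mrr@k only, nothing per group or per query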

The evaluator yields recall@k and mrr@k with cosine or dot-product similarity, but I can't find a way to get the actual rank that these metrics are computed from, i.e. the rank of the best-scored document from relevant_docs within the corpus for a given query, per k.

The only way I found is to run the inference step again on the data and compute it myself, but that seems like a computational waste, since it is already being computed during evaluation.

Is there a way to do this more elegantly? If so, could you please direct me to an example? Many thanks

tomaarsen commented 2 months ago

Hello!

I think your best bet might be to create a subclass of the InformationRetrievalEvaluator which also outputs the rank of the best scored "true relevant" document. I think you'll get far by overriding the compute_metrics method:


from datasets import load_dataset, Dataset
import numpy as np
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformer
from typing import List
import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)

class CustomInformationRetrievalEvaluator(InformationRetrievalEvaluator):
    def compute_metrics(self, queries_result_list: List[object]):
        # Init score computation values
        num_hits_at_k = {k: 0 for k in self.accuracy_at_k}
        precisions_at_k = {k: [] for k in self.precision_recall_at_k}
        recall_at_k = {k: [] for k in self.precision_recall_at_k}
        min_rank_at_k = {k: [] for k in self.precision_recall_at_k}
        MRR = {k: 0 for k in self.mrr_at_k}
        ndcg = {k: [] for k in self.ndcg_at_k}
        AveP_at_k = {k: [] for k in self.map_at_k}

        # Compute scores on results
        for query_itr in range(len(queries_result_list)):
            query_id = self.queries_ids[query_itr]

            # Sort scores
            top_hits = sorted(queries_result_list[query_itr], key=lambda x: x["score"], reverse=True)
            query_relevant_docs = self.relevant_docs[query_id]

            # Accuracy@k - We count the result correct, if at least one relevant doc is across the top-k documents
            for k_val in self.accuracy_at_k:
                for hit in top_hits[0:k_val]:
                    if hit["corpus_id"] in query_relevant_docs:
                        num_hits_at_k[k_val] += 1
                        break

            # Precision and Recall@k
            for k_val in self.precision_recall_at_k:
                num_correct = 0
                for hit in top_hits[0:k_val]:
                    if hit["corpus_id"] in query_relevant_docs:
                        num_correct += 1

                precisions_at_k[k_val].append(num_correct / k_val)
                recall_at_k[k_val].append(num_correct / len(query_relevant_docs))

            # CUSTOM: Rank of the first (i.e. best-scored) relevant doc in the top-k
            for k_val in self.precision_recall_at_k:
                # Edge case: if no relevant doc appears in the top-k, fall back to rank k
                rank = k_val
                for idx, hit in enumerate(top_hits[:k_val]):
                    if hit["corpus_id"] in query_relevant_docs:
                        rank = idx
                        break  # stop at the first hit, so we keep the minimum rank

                min_rank_at_k[k_val].append(rank)

            # MRR@k
            for k_val in self.mrr_at_k:
                for rank, hit in enumerate(top_hits[0:k_val]):
                    if hit["corpus_id"] in query_relevant_docs:
                        MRR[k_val] += 1.0 / (rank + 1)
                        break

            # NDCG@k
            for k_val in self.ndcg_at_k:
                predicted_relevance = [
                    1 if top_hit["corpus_id"] in query_relevant_docs else 0 for top_hit in top_hits[0:k_val]
                ]
                true_relevances = [1] * len(query_relevant_docs)

                ndcg_value = self.compute_dcg_at_k(predicted_relevance, k_val) / self.compute_dcg_at_k(
                    true_relevances, k_val
                )
                ndcg[k_val].append(ndcg_value)

            # MAP@k
            for k_val in self.map_at_k:
                num_correct = 0
                sum_precisions = 0

                for rank, hit in enumerate(top_hits[0:k_val]):
                    if hit["corpus_id"] in query_relevant_docs:
                        num_correct += 1
                        sum_precisions += num_correct / (rank + 1)

                avg_precision = sum_precisions / min(k_val, len(query_relevant_docs))
                AveP_at_k[k_val].append(avg_precision)

        # Compute averages
        for k in num_hits_at_k:
            num_hits_at_k[k] /= len(self.queries)

        for k in precisions_at_k:
            precisions_at_k[k] = np.mean(precisions_at_k[k])

        for k in recall_at_k:
            recall_at_k[k] = np.mean(recall_at_k[k])

        for k in min_rank_at_k:
            min_rank_at_k[k] = np.mean(min_rank_at_k[k])

        for k in ndcg:
            ndcg[k] = np.mean(ndcg[k])

        for k in MRR:
            MRR[k] /= len(self.queries)

        for k in AveP_at_k:
            AveP_at_k[k] = np.mean(AveP_at_k[k])

        return {
            "accuracy@k": num_hits_at_k,
            "precision@k": precisions_at_k,
            "recall@k": recall_at_k,
            "min_rank@k": min_rank_at_k, # <- Custom metric
            "ndcg@k": ndcg,
            "mrr@k": MRR,
            "map@k": AveP_at_k,
        }

    def output_scores(self, scores):
        super().output_scores(scores)

        for k in scores["min_rank@k"]:
            logger.info("min. rank@{}: {:.4f}".format(k, scores["min_rank@k"][k]))

# Load a dataset; only the held-out evaluation split is used here
dataset = load_dataset("sentence-transformers/natural-questions", split="train")
dataset = dataset.add_column("id", range(len(dataset)))
train_dataset: Dataset = dataset.select(range(99_000))
eval_dataset: Dataset = dataset.select(range(99_000, len(dataset))) # <- 1_231 samples

queries = dict(zip(eval_dataset["id"], eval_dataset["query"]))
corpus = {cid: dataset[cid]["answer"] for cid in range(10_000)} | {cid: dataset[cid]["answer"] for cid in eval_dataset["id"]}
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = CustomInformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="natural-questions-dev",
    batch_size=8,
)

model = SentenceTransformer("all-MiniLM-L6-v2")
dev_evaluator(model)
2024-07-08 18:11:42,911 - Queries: 1231
2024-07-08 18:11:42,911 - Corpus: 11231

2024-07-08 18:11:43,017 - Score-Function: cosine
2024-07-08 18:11:43,017 - Accuracy@1: 75.22%
2024-07-08 18:11:43,017 - Accuracy@3: 93.99%
2024-07-08 18:11:43,018 - Accuracy@5: 96.99%
2024-07-08 18:11:43,018 - Accuracy@10: 99.03%
2024-07-08 18:11:43,018 - Precision@1: 75.22%
2024-07-08 18:11:43,018 - Precision@3: 31.33%
2024-07-08 18:11:43,018 - Precision@5: 19.40%
2024-07-08 18:11:43,018 - Precision@10: 9.90%
2024-07-08 18:11:43,018 - Recall@1: 75.22%
2024-07-08 18:11:43,019 - Recall@3: 93.99%
2024-07-08 18:11:43,019 - Recall@5: 96.99%
2024-07-08 18:11:43,019 - Recall@10: 99.03%
2024-07-08 18:11:43,019 - MRR@10: 0.8479
2024-07-08 18:11:43,019 - NDCG@10: 0.8836
2024-07-08 18:11:43,019 - MAP@100: 0.8483
2024-07-08 18:11:43,019 - min. rank@1: 0.2478
2024-07-08 18:11:43,019 - min. rank@3: 0.4184
2024-07-08 18:11:43,020 - min. rank@5: 0.4850
2024-07-08 18:11:43,020 - min. rank@10: 0.5516
2024-07-08 18:11:43,020 - Score-Function: dot
2024-07-08 18:11:43,020 - Accuracy@1: 75.22%
2024-07-08 18:11:43,020 - Accuracy@3: 93.99%
2024-07-08 18:11:43,020 - Accuracy@5: 97.08%
2024-07-08 18:11:43,021 - Accuracy@10: 99.03%
2024-07-08 18:11:43,021 - Precision@1: 75.22%
2024-07-08 18:11:43,021 - Precision@3: 31.33%
2024-07-08 18:11:43,021 - Precision@5: 19.42%
2024-07-08 18:11:43,021 - Precision@10: 9.90%
2024-07-08 18:11:43,021 - Recall@1: 75.22%
2024-07-08 18:11:43,021 - Recall@3: 93.99%
2024-07-08 18:11:43,022 - Recall@5: 97.08%
2024-07-08 18:11:43,022 - Recall@10: 99.03%
2024-07-08 18:11:43,022 - MRR@10: 0.8479
2024-07-08 18:11:43,022 - NDCG@10: 0.8836
2024-07-08 18:11:43,022 - MAP@100: 0.8484
2024-07-08 18:11:43,022 - min. rank@1: 0.2478
2024-07-08 18:11:43,022 - min. rank@3: 0.4184
2024-07-08 18:11:43,023 - min. rank@5: 0.4833
2024-07-08 18:11:43,023 - min. rank@10: 0.5500

So, for this script, the average rank in the top-10 case is about 0.55, which is rather good (i.e. low). These were the ranks:

[0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 6, 0, 0, 0, 0, 1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 1, 1, 0, 0, 3, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 2, 0, 0, 0, 2, 0, 5, 0, 1, 0, 0, 0, 0, 1, 3, 0, 1, 0, 0, 0, 2, 1, 1, 0, 1, 3, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 5, 3, 0, 1, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 3, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 10, 0, 1, 2, 6, 1, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 10, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 10, 0, 0, 1, 0, 1, 0, 0, 0, 2, 7, 1, 0, 0, 0, 0, 4, 2, 0, 1, 0, 1, 1, 0, 3, 0, 0, 0, 0, 0, 6, 0, 0, 2, 0, 1, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 1, 2, 3, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 7, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 7, 0, 0, 0, 0, 0, 0, 7, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 5, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 3, 2, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 1, 0, 0, 0, 1, 5, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 1, 0, 0, 0, 2, 0, 3, 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 4, 0, 0, 1, 0, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 2, 1, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, 2, 1, 0, 1, 0, 0, 0, 0, 0, 0, 3, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 3, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 1, 0, 0, 1, 2, 2, 0, 0, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 8, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 2, 0, 0, 1, 2, 0, 1, 0, 2, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 1, 7, 1, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 10, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 1, 0, 10, 0, 1, 0, 0, 0, 0, 5, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 2, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 3, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 
2, 1, 0, 0, 5, 0, 0, 1, 3, 2, 1, 1, 1, 2, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 6, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0]

As you can see, the answer that indeed corresponds with the question was usually at rank 0, i.e. the top hit.

Something to consider is the edge case of "what if none of the true relevant documents are in the top k?". In this script I've set the rank to k in that case. You could also update the script to not report results per k, but just report the mean minimum rank instead. Perhaps that's more useful, but it all depends on what you'd like. You can also compute the median, the 25th & 75th percentiles, etc.
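
For example, a quick sketch of such summary statistics over the raw per-query ranks (here just the first few @10 ranks from the list above; inside compute_metrics you could keep the full lists by copying min_rank_at_k onto an attribute of your own naming before the averaging loop overwrites it with the means):

import numpy as np

# First few of the @10 ranks printed above (truncated for brevity)
ranks_at_10 = [0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 6, 0, 0, 0, 0, 1, 0]

print("mean:   ", np.mean(ranks_at_10))
print("median: ", np.median(ranks_at_10))
print("p25/p75:", np.percentile(ranks_at_10, [25, 75]))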

maayansharon10 commented 2 months ago

Thank you! That's a good idea, very helpful. I appreciate it a lot.

For my other question, regarding calculating mrr@k (or the other metrics) per group: the evaluator object expects certain inputs, and as far as I understand I cannot add a group column, or can I? I'm not sure how to solve this by overriding the compute_metrics function, as I don't understand how to pass a group parameter per query to the evaluator.

Currently I'm calculating it manually for each line in the db. Filtering the dataset per group and then running the evaluator on each subset does not produce the same results, since the corpus differs per subset (a toy sketch of what I mean is below). Any thoughts or ideas that could help me with that?
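
To make it concrete, this is roughly the per-group attempt (toy sketch; the group mapping and names are just for illustration):

from collections import defaultdict
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator

model = SentenceTransformer("all-mpnet-base-v2")

# Toy stand-ins for my real data; `query_groups` is my own mapping
queries = {"q1": "query one", "q2": "query two", "q3": "query three"}
corpus = {"d1": "doc one", "d2": "doc two", "d3": "doc three"}
relevant_docs = {"q1": {"d1"}, "q2": {"d2"}, "q3": {"d3"}}
query_groups = {"q1": "a", "q2": "a", "q3": "b"}

# Group the queries
queries_by_group = defaultdict(dict)
for qid, group in query_groups.items():
    queries_by_group[group][qid] = queries[qid]

for group, group_queries in queries_by_group.items():
    # Restricting the corpus to the group's documents is what changes the numbers:
    # every query is now ranked against a smaller corpus than in the global run.
    group_corpus = {did: corpus[did] for qid in group_queries for did in relevant_docs[qid]}
    group_evaluator = InformationRetrievalEvaluator(
        queries=group_queries,
        corpus=group_corpus,
        relevant_docs={qid: relevant_docs[qid] for qid in group_queries},
        name=f"group-{group}",
    )
    print(group, group_evaluator(model))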

Thank you again for your help, M