This is the code for our paper "BMRetriever: Tuning Large Language Models as Better Biomedical Text Retrievers".
Paper scores not matching #1

Closed: kamalkraj closed this issue 1 week ago

kamalkraj commented 1 week ago

Script used for evaluation:

# bm_embedding.py
from typing import List, Dict, Union, Tuple
import logging

import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # If every sequence ends with a real token, the batch is left-padded (or unpadded)
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        # Right padding: gather the hidden state of the last non-padding token
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

def get_detailed_instruct_query(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery: {query}'

def get_detailed_instruct_passage(passage: str) -> str:
    return f'Represent this passage\npassage: {passage}'

logger = logging.getLogger(__name__)

class SentenceBERT:
    def __init__(self, model_path: Union[str, Tuple] = "BMRetriever/BMRetriever-7B", sep: str = " ", **kwargs):
        self.sep = sep
        self.task = 'Given a scientific claim, retrieve documents that support or refute the claim'

        self.model = AutoModel.from_pretrained(model_path, device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.max_length = 512

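    def encode(self, input_texts, **kwargs):
        # Encode one text at a time (batch_size arrives via kwargs but is not used here)
        embeddings = []
        with torch.no_grad():
            for input_text in tqdm(input_texts):
                # Tokenize (truncating to max_length) and move the inputs to the model's device
                batch_dict = self.tokenizer(input_text, max_length=self.max_length, truncation=True,
                                            return_tensors="pt").to(self.model.device)
                outputs = self.model(**batch_dict)
                embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
                # Drop the batch dimension of 1 so the stacked result is (num_texts, hidden_size)
                embeddings.append(embedding.squeeze(0))
        embeddings = torch.stack(embeddings)
        return embeddings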

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        queries = [get_detailed_instruct_query(self.task, query) for query in queries]
        return self.encode(queries, batch_size=batch_size, **kwargs)

    def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        if type(corpus) is dict:
            # Column-oriented corpus: {"title": [...], "text": [...]}
            sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
        else:
            # Row-oriented corpus: [{"title": ..., "text": ...}, ...]
            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        sentences = [get_detailed_instruct_passage(passage) for passage in sentences]
        return self.encode(sentences, batch_size=batch_size, **kwargs)
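To see what the `last_token_pool` helper above does with each padding side, here is a minimal NumPy mirror of the same logic. The name `last_token_pool_np` and the toy arrays are illustrative only, not part of the BMRetriever code. Since `encode()` above tokenizes one text at a time, the attention mask contains no padding, so in this script the first (left-padding) branch is the one that fires.

```python
import numpy as np


def last_token_pool_np(last_hidden_states: np.ndarray,
                       attention_mask: np.ndarray) -> np.ndarray:
    """NumPy mirror of last_token_pool: return the hidden state of the last real token."""
    # If every sequence has a real token in the final slot, the batch is
    # left-padded (or unpadded), so the last position is the last real token.
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    # Right padding: index of the last 1 in each row of the mask.
    sequence_lengths = attention_mask.sum(axis=1) - 1
    batch_indices = np.arange(last_hidden_states.shape[0])
    return last_hidden_states[batch_indices, sequence_lengths]


# Toy batch: 2 sequences, 4 positions, hidden size 1; each hidden state equals its position.
hidden = np.arange(4, dtype=float).reshape(1, 4, 1).repeat(2, axis=0)
right_padded = np.array([[1, 1, 1, 1],
                         [1, 1, 0, 0]])
left_padded = np.array([[1, 1, 1, 1],
                        [0, 0, 1, 1]])

print(last_token_pool_np(hidden, right_padded).ravel())  # [3. 1.]
print(last_token_pool_np(hidden, left_padded).ravel())   # [3. 3.]
```

With right padding the second sequence's embedding comes from position 1 (its last real token), not from the padded final slot; without the `else` branch, the original function would return `None` for such a batch.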

# Evaluation script (imports SentenceBERT from bm_embedding.py)
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from bm_embedding import SentenceBERT

import logging
import pathlib, os

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

#### Download scifact.zip dataset and unzip the dataset
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)

#### Provide the data_path where scifact has been downloaded and unzipped
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

#### Load the BMRetriever model via the SentenceBERT wrapper and retrieve using dot-product scores
model = DRES(SentenceBERT(model_path="BMRetriever/BMRetriever-2B"), batch_size=1)
retriever = EvaluateRetrieval(model, score_function="dot") # or "cos_sim" for cosine similarity
results = retriever.retrieve(corpus, queries)

#### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K  where k = [1,3,5,10,100,1000] 
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
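The script passes `score_function="dot"`, and the inline comment offers `"cos_sim"` as the alternative. When embeddings are not L2-normalised, the two can rank documents differently. The toy NumPy sketch below (hypothetical `rank_documents` helper, made-up vectors) shows one case where they disagree; which scoring function BMRetriever was evaluated with is not stated in this thread.

```python
import numpy as np


def rank_documents(query: np.ndarray, docs: np.ndarray, score_function: str) -> list:
    """Return document indices sorted from best to worst under the given score function."""
    if score_function == "cos_sim":
        # Cosine similarity is the dot product of L2-normalised vectors.
        query = query / np.linalg.norm(query)
        docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    elif score_function != "dot":
        raise ValueError(f"unknown score function: {score_function}")
    scores = docs @ query
    return np.argsort(-scores).tolist()


query = np.array([1.0, 0.0])
docs = np.array([
    [0.9, 0.1],  # almost the same direction as the query, small norm
    [3.0, 3.0],  # 45 degrees away from the query, large norm
])

print(rank_documents(query, docs, "dot"))      # [1, 0]
print(rank_documents(query, docs, "cos_sim"))  # [0, 1]
```

Running the evaluation once with each setting could be a quick way to check whether embedding norms are dominating the scores.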


2024-06-23 11:06:57 - Loaded 5183 TEST Documents.
2024-06-23 11:06:57 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities...', 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.'}
In the paper, the reported score is 0.760, but with the script above I only get NDCG@10: 0.3686.

What am I missing? Please help. Thanks!
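When a score is this far from the reported one, it can help to recompute the metric independently of the retrieval code. Below is a minimal pure-Python NDCG@k sketch over the dictionary shapes BEIR uses (`qrels` as `{query_id: {doc_id: relevance}}`, `results` as `{query_id: {doc_id: score}}`). It assumes linear gains, as in trec_eval's `ndcg_cut`, and does not break score ties the way trec_eval does, so it may differ slightly from `retriever.evaluate` (which uses pytrec_eval). The `ndcg_at_k` name is hypothetical; this is not the BMRetriever `eval.py`.

```python
import math
from typing import Dict


def ndcg_at_k(qrels: Dict[str, Dict[str, int]],
              results: Dict[str, Dict[str, float]],
              k: int = 10,
              ignore_identical_ids: bool = True) -> float:
    """Mean NDCG@k over queries that have at least one relevant document (linear gains)."""
    per_query = []
    for qid, judged in qrels.items():
        # Ideal ranking: the k highest positive relevance labels, descending.
        ideal_gains = sorted((rel for rel in judged.values() if rel > 0), reverse=True)[:k]
        if not ideal_gains:
            continue
        scored = results.get(qid, {})
        if ignore_identical_ids:
            # Like BEIR's default, drop a retrieved document whose id equals the query id.
            scored = {doc_id: s for doc_id, s in scored.items() if doc_id != qid}
        # Top-k documents by score (ties keep dictionary order).
        ranking = sorted(scored, key=scored.get, reverse=True)[:k]
        dcg = sum(max(judged.get(doc_id, 0), 0) / math.log2(rank + 2)
                  for rank, doc_id in enumerate(ranking))
        idcg = sum(gain / math.log2(rank + 2) for rank, gain in enumerate(ideal_gains))
        per_query.append(dcg / idcg)
    return sum(per_query) / len(per_query) if per_query else 0.0


qrels = {"q1": {"d1": 1, "d3": 1}}
results = {"q1": {"d1": 0.9, "d2": 0.8, "d3": 0.7}}
print(round(ndcg_at_k(qrels, results), 4))  # 0.9197
```

Calling `ndcg_at_k(qrels, results)` on the real `qrels` and `results` from the script should land close to the logged NDCG@10; a large gap would point at the evaluation step rather than the embeddings.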


kamalkraj commented 1 week ago


2024-06-24 03:43:32 - Loading Corpus...
100%|███████████████████████████████████████████████████████████████████████████████████████████| 5183/5183 [00:00<00:00, 173629.26it/s]
2024-06-24 03:43:32 - Loaded 5183 TEST Documents.
2024-06-24 03:43:32 - Loaded 5183 TEST Documents.
2024-06-24 03:43:32 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter...', 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.'}
2024-06-24 03:43:32 - Loading Queries...
2024-06-24 03:43:32 - Loaded 300 TEST Queries.
2024-06-24 03:43:32 - Query Example: 0-dimensional biomaterials show inductive properties.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-06-24 03:43:42 - Encoding Queries...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:07<00:00, 40.87it/s]
2024-06-24 03:43:50 - Sorting Corpus by document length (Longest first)...
2024-06-24 03:43:50 - Encoding Corpus in batches... Warning: This might take a while!
2024-06-24 03:43:50 - Scoring Function: Dot Product (dot)
2024-06-24 03:43:50 - Encoding Batch 1/1...
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 5183/5183 [02:03<00:00, 42.09it/s]
2024-06-24 03:45:53 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-06-24 03:45:53 - 

2024-06-24 03:45:53 - NDCG@1: 0.3967
2024-06-24 03:45:53 - NDCG@3: 0.4353
2024-06-24 03:45:53 - NDCG@5: 0.4560
2024-06-24 03:45:53 - NDCG@10: 0.4717
2024-06-24 03:45:53 - NDCG@100: 0.5002
2024-06-24 03:45:53 - NDCG@1000: 0.5204
2024-06-24 03:45:53 - 

2024-06-24 03:45:53 - MAP@1: 0.3807
2024-06-24 03:45:53 - MAP@3: 0.4191
2024-06-24 03:45:53 - MAP@5: 0.4315
2024-06-24 03:45:53 - MAP@10: 0.4386
2024-06-24 03:45:53 - MAP@100: 0.4448
2024-06-24 03:45:53 - MAP@1000: 0.4455
2024-06-24 03:45:53 - 

2024-06-24 03:45:53 - Recall@1: 0.3807
2024-06-24 03:45:53 - Recall@3: 0.4634
2024-06-24 03:45:53 - Recall@5: 0.5137
2024-06-24 03:45:53 - Recall@10: 0.5603
2024-06-24 03:45:53 - Recall@100: 0.6879
2024-06-24 03:45:53 - Recall@1000: 0.8497
2024-06-24 03:45:53 - 

2024-06-24 03:45:53 - P@1: 0.3967
2024-06-24 03:45:53 - P@3: 0.1678
2024-06-24 03:45:53 - P@5: 0.1140
2024-06-24 03:45:53 - P@10: 0.0637
2024-06-24 03:45:53 - P@100: 0.0079
2024-06-24 03:45:53 - P@1000: 0.0010


2024-06-24 03:46:49 - Loading Corpus...
100%|████████████████████████████████████████████████████████████████████████████████████████████| 5183/5183 [00:00<00:00, 176718.92it/s]
2024-06-24 03:46:49 - Loaded 5183 TEST Documents.
2024-06-24 03:46:49 - Loaded 5183 TEST Documents.
2024-06-24 03:46:49 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter...', 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.'}
2024-06-24 03:46:49 - Loading Queries...
2024-06-24 03:46:49 - Loaded 300 TEST Queries.
2024-06-24 03:46:49 - Query Example: 0-dimensional biomaterials show inductive properties.
config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████| 753/753 [00:00<00:00, 241kB/s]
model.safetensors: 100%|████████████████████████████████████████████████████████████████████████████| 3.64G/3.64G [01:06<00:00, 54.9MB/s]
tokenizer_config.json: 100%|████████████████████████████████████████████████████████████████████████| 4.79k/4.79k [00:00<00:00, 2.56MB/s]
tokenizer.json: 100%|███████████████████████████████████████████████████████████████████████████████| 2.31M/2.31M [00:00<00:00, 10.3MB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-06-24 03:47:57 - Encoding Queries...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:05<00:00, 59.54it/s]
2024-06-24 03:48:02 - Sorting Corpus by document length (Longest first)...
2024-06-24 03:48:02 - Encoding Corpus in batches... Warning: This might take a while!
2024-06-24 03:48:02 - Scoring Function: Dot Product (dot)
2024-06-24 03:48:02 - Encoding Batch 1/1...
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 5183/5183 [01:40<00:00, 51.59it/s]
2024-06-24 03:49:43 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-06-24 03:49:43 - 

2024-06-24 03:49:43 - NDCG@1: 0.0100
2024-06-24 03:49:43 - NDCG@3: 0.0221
2024-06-24 03:49:43 - NDCG@5: 0.0285
2024-06-24 03:49:43 - NDCG@10: 0.0402
2024-06-24 03:49:43 - NDCG@100: 0.0957
2024-06-24 03:49:43 - NDCG@1000: 0.1494
2024-06-24 03:49:43 - 

2024-06-24 03:49:43 - MAP@1: 0.0100
2024-06-24 03:49:43 - MAP@3: 0.0183
2024-06-24 03:49:43 - MAP@5: 0.0218
2024-06-24 03:49:43 - MAP@10: 0.0268
2024-06-24 03:49:43 - MAP@100: 0.0365
2024-06-24 03:49:43 - MAP@1000: 0.0382
2024-06-24 03:49:43 - 

2024-06-24 03:49:43 - Recall@1: 0.0100
2024-06-24 03:49:43 - Recall@3: 0.0333
2024-06-24 03:49:43 - Recall@5: 0.0483
2024-06-24 03:49:43 - Recall@10: 0.0828
2024-06-24 03:49:43 - Recall@100: 0.3560
2024-06-24 03:49:43 - Recall@1000: 0.7876
2024-06-24 03:49:43 - 

2024-06-24 03:49:43 - P@1: 0.0100
2024-06-24 03:49:43 - P@3: 0.0111
2024-06-24 03:49:43 - P@5: 0.0100
2024-06-24 03:49:43 - P@10: 0.0087
2024-06-24 03:49:43 - P@100: 0.0039
2024-06-24 03:49:43 - P@1000: 0.0009


2024-06-24 03:52:37 - Loaded 5183 TEST Documents.
2024-06-24 03:52:37 - Loaded 5183 TEST Documents.
2024-06-24 03:52:37 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter...', 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.'}
2024-06-24 03:52:37 - Loading Queries...
2024-06-24 03:52:37 - Loaded 300 TEST Queries.
2024-06-24 03:52:37 - Query Example: 0-dimensional biomaterials show inductive properties.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 6/6 [00:03<00:00,  1.86it/s]
2024-06-24 03:52:41 - Encoding Queries...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:21<00:00, 14.28it/s]
2024-06-24 03:53:02 - Sorting Corpus by document length (Longest first)...
2024-06-24 03:53:02 - Encoding Corpus in batches... Warning: This might take a while!
2024-06-24 03:53:02 - Scoring Function: Dot Product (dot)
2024-06-24 03:53:02 - Encoding Batch 1/1...
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5183/5183 [10:01<00:00,  8.61it/s]
2024-06-24 04:03:04 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-06-24 04:03:04 - 

2024-06-24 04:03:04 - NDCG@1: 0.3533
2024-06-24 04:03:04 - NDCG@3: 0.4226
2024-06-24 04:03:04 - NDCG@5: 0.4416
2024-06-24 04:03:04 - NDCG@10: 0.4601
2024-06-24 04:03:04 - NDCG@100: 0.4812
2024-06-24 04:03:04 - NDCG@1000: 0.4902
2024-06-24 04:03:04 - 

2024-06-24 04:03:04 - MAP@1: 0.3291
2024-06-24 04:03:04 - MAP@3: 0.3942
2024-06-24 04:03:04 - MAP@5: 0.4060
2024-06-24 04:03:04 - MAP@10: 0.4145
2024-06-24 04:03:04 - MAP@100: 0.4191
2024-06-24 04:03:04 - MAP@1000: 0.4194
2024-06-24 04:03:04 - 

2024-06-24 04:03:04 - Recall@1: 0.3291
2024-06-24 04:03:04 - Recall@3: 0.4764
2024-06-24 04:03:04 - Recall@5: 0.5243
2024-06-24 04:03:04 - Recall@10: 0.5780
2024-06-24 04:03:04 - Recall@100: 0.6748
2024-06-24 04:03:04 - Recall@1000: 0.7459
2024-06-24 04:03:04 - 

2024-06-24 04:03:04 - P@1: 0.3533
2024-06-24 04:03:04 - P@3: 0.1744
2024-06-24 04:03:04 - P@5: 0.1160
2024-06-24 04:03:04 - P@10: 0.0650
2024-06-24 04:03:04 - P@100: 0.0076
2024-06-24 04:03:04 - P@1000: 0.0008
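To compare the three runs above side by side, the metric lines can be pulled out of the raw logs. A small sketch, assuming the `<timestamp> - <METRIC>@<k>: <value>` line format shown above and that each run starts with a `Loading Queries...` line; `parse_runs` is a hypothetical helper, and the sample lines are copied from the logs above.

```python
import re
from typing import Dict, List

# Matches BEIR metric lines such as "2024-06-24 03:45:53 - NDCG@10: 0.4717".
METRIC_LINE = re.compile(r" - (NDCG|MAP|Recall|P)@(\d+): ([0-9.]+)\s*$")


def parse_runs(log_text: str) -> List[Dict[str, float]]:
    """Split a BEIR log into runs (one per 'Loading Queries...' line) and collect metrics."""
    runs: List[Dict[str, float]] = []
    for line in log_text.splitlines():
        if line.endswith("Loading Queries..."):
            runs.append({})
            continue
        match = METRIC_LINE.search(line)
        if match and runs:
            name, k, value = match.groups()
            runs[-1][f"{name}@{k}"] = float(value)
    return runs


sample_log = """\
2024-06-24 03:43:32 - Loading Queries...
2024-06-24 03:45:53 - NDCG@10: 0.4717
2024-06-24 03:46:49 - Loading Queries...
2024-06-24 03:49:43 - NDCG@10: 0.0402
2024-06-24 03:52:37 - Loading Queries...
2024-06-24 04:03:04 - NDCG@10: 0.4601
"""

for i, run in enumerate(parse_runs(sample_log), start=1):
    print(f"run {i}: NDCG@10 = {run['NDCG@10']:.4f}")
```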

ritaranx commented 1 week ago

Hello, thanks for pointing this out. After looking into it further, we found some differences between your implementation and ours:

Additionally, we've uploaded a demo file, eval.py, to our GitHub repository. This script has been tested on a different machine and reproduces the results with minimal discrepancy (within 0.001).

Let us know if you still have difficulty reproducing our results.



kamalkraj commented 1 week ago
