terrierteam / pyterrier_colbert


Colbert PRF as a textual reranker #62

Open cmacdonald opened 1 year ago

cmacdonald commented 1 year ago

Maik Frobe requested ColBERT PRF as a textual reranker.

I think the code should look like this:

colbert = ColBERTModelOnlyFactory(checkpoint)
bm25 = pt.BatchRetrieve(sparse_index, wmodel='BM25', metadata=['docno', 'text'])
cprf_reranker = (
    bm25 
    >> colbert.text_encoder() 
    >> ColbertPRF(colbert, k=64, fb_embs=10, beta=1, fb_docs=10, return_docs=True) 
    >> colbert.scorer()
)

but: the only thing the index is used for is the token-level IDF, so we'd need to work around that... https://github.com/terrierteam/pyterrier_colbert/blob/main/pyterrier_colbert/ranking.py#L1020-L1024
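One possible workaround (a rough, untested sketch): precompute the token-level IDF of each BERT wordpiece id directly from the collection text, so the PRF component never needs the dense index. build_token_idf below is a hypothetical helper, and its IDF formula would need to match whatever the linked lines in ranking.py actually compute:

import math
from collections import Counter

def build_token_idf(corpus_iter, tokenizer):
    # corpus_iter yields document texts; tokenizer is the BERT wordpiece
    # tokenizer (e.g. the query_tokenizer.tok object used elsewhere in ranking.py)
    df = Counter()
    num_docs = 0
    for text in corpus_iter:
        num_docs += 1
        # count each token id at most once per document
        df.update(set(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))))
    return {tid: math.log(num_docs / df_t) for tid, df_t in df.items()}

The resulting dict (token id -> IDF) could then stand in for the FAISS-derived token statistics.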

Cc/ @seanmacavaney

Xiao0728 commented 2 months ago

I have used the pipeline BM25 >> ColBERT-PRF, as follows:

import numpy as np
import pandas as pd
import torch
from sklearn.cluster import KMeans
from pyterrier.transformer import TransformerBase

# Assumes the following are defined elsewhere (cf. pyterrier_colbert.ranking):
# fnt (the FaissNNTerm giving access to token statistics), get_nearest_tokens_for_emb,
# and the weighting dictionaries id2meancos, idfdict, ictfdict and probIDF.
class ColBERTPRF_docencoded(TransformerBase):
    def __init__(self, k, exp_terms, beta=1, r=42, mean_cos_weight=False, idf_weight=False,
                 ictf_weight=False, probIDF_weight=False, return_docs=False, fb_docs=10,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k                  # number of KMeans clusters
        self.exp_terms = exp_terms  # number of expansion embeddings to keep
        self.beta = beta            # weight applied to the expansion embeddings
        self.mean_cos_weight = mean_cos_weight
        self.idf_weight = idf_weight
        self.probIDF_weight = probIDF_weight
        self.return_docs = return_docs
        self.fb_docs = fb_docs      # number of feedback documents
        self.r = r                  # KMeans random seed
        self.ictf_weight = ictf_weight
        assert self.k > self.exp_terms, "exp_terms should be smaller than the number of clusters"

    def _get_prf_embs(self, df, num_docs):
        # concatenate the token embeddings of the top num_docs feedback documents
        return torch.cat(df.head(num_docs).doc_embs.values.tolist(), dim=0)

    def transform_query(self, topic_and_res):
        topic_and_res = topic_and_res.sort_values('rank')
        prf_embs = self._get_prf_embs(topic_and_res, self.fb_docs)

        # cluster the feedback document embeddings
        kmn = KMeans(self.k, random_state=self.r)
        kmn.fit(prf_embs)

        emb_and_score = []
        for cluster in range(self.k):
            # take the centroid; it needs to be float32 for the FAISS lookup
            centroid = np.float32(kmn.cluster_centers_[cluster])
            tok2freq = get_nearest_tokens_for_emb(fnt, centroid)
            if len(tok2freq) == 0:
                continue
            most_likely_tok = max(tok2freq, key=tok2freq.get)
            tid = fnt.inference.query_tokenizer.tok.convert_tokens_to_ids(most_likely_tok)

            if self.mean_cos_weight:
                # meanCos score without normalisation
                emb_and_score.append((centroid, most_likely_tok, tid, id2meancos[tid]))
            elif self.idf_weight:
                # idf score without normalisation
                # (use idfGN[tid] instead for idf with global normalisation)
                emb_and_score.append((centroid, most_likely_tok, tid, idfdict[tid]))
            elif self.ictf_weight:
                emb_and_score.append((centroid, most_likely_tok, tid, ictfdict[tid]))
            elif self.probIDF_weight:
                # probIDF score without normalisation
                # (use probIDFGN[tid] instead for probIDF with global normalisation)
                emb_and_score.append((centroid, most_likely_tok, tid, probIDF[tid]))
            else:
                # no weighting selected: keep the centroid with a unit score
                emb_and_score.append((centroid, most_likely_tok, tid, 1.0))

        # keep the exp_terms highest-scoring expansion embeddings
        sorted_by_second = sorted(emb_and_score, key=lambda tup: -tup[3])

        toks = []
        scores = []
        exp_embds = []
        for i in range(min(self.exp_terms, len(sorted_by_second))):
            emb, tok, tid, score = sorted_by_second[i]
            toks.append(tok)
            # for per-query normalisation, divide by the largest score:
            # score = score / sorted_by_second[0][3]
            scores.append(score)
            exp_embds.append(emb)

        first_row = topic_and_res.iloc[0]
        # append the expansion embeddings to the original query embeddings
        newemb = torch.cat([
            first_row.query_embs,
            torch.Tensor(np.stack(exp_embds))])
        # original query tokens keep weight 1; each expansion embedding is weighted
        # by beta times its score (the score is 1 when no weighting is selected)
        weights = torch.cat([
            torch.ones(len(first_row.query_embs)),
            self.beta * torch.Tensor(scores)])

        rtr = pd.DataFrame([
            [first_row.qid,
             first_row.docno,
             first_row.query,
             newemb,
             toks,
             weights]],
            columns=["qid", "docno", "query", "query_embs", "query_toks", "query_weights"])
        return rtr

#         ["qid","query",'docno','query_toks','query_embs']
    def transform(self, topics_and_docs):
        # some validation of the input
        required = ["qid", "query", "docid","docno", "query_embs"]
        for col in required:
            assert col in topics_and_docs.columns
        #restore the docid column if missing
        if "docid" not in topics_and_docs:
#             topics_and_docs["docid"] = topics_and_docs.docno.astype("int").values
            topics_and_docs["docid"] = topics_and_docs.docid.astype("int").values
        rtr = []
        for qid, res in topics_and_docs.groupby("qid"):
            new_query_df = self.transform_query(res)     
            if self.return_docs:
                new_query_df = res[["qid", "docno", "docid","doc_embs"]].merge(new_query_df, on=["qid"])

                new_query_df = new_query_df.rename(columns={'docno_x':'docno'})
            rtr.append(new_query_df)
        return pd.concat(rtr)

The experiment is run as follows:

pipeE2E_psg = (pytcolbert.query_encoder() >> BM25 >> pt.text.sliding(prepend_title=False)
               >> doc_encoder(pytcolbert) >> scorer(pytcolbert))
pipePRF_rerank = pipeE2E_psg >> ColBERTPRF_docencoded(k=24, exp_terms=10, idf_weight=True, beta=1, fb_docs=3, return_docs=True)
bm25_prf_rerank = (pipePRF_rerank >> scorer(pytcolbert) >> pt.text.max_passage()) % 1000
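
For reference, the final pipeline could then be evaluated with pt.Experiment in the usual way (a sketch; the dataset below is just a placeholder for whichever test collection is in use):

dataset = pt.get_dataset("irds:msmarco-passage/trec-dl-2019/judged")  # placeholder collection
pt.Experiment(
    [bm25_prf_rerank],
    dataset.get_topics(),
    dataset.get_qrels(),
    eval_metrics=["map", "ndcg_cut_10"],
    names=["BM25 >> ColBERT-PRF rerank"],
)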