terrier-org / pyterrier

A Python framework for performing information retrieval experiments, building on http://terrier.org/
https://pyterrier.readthedocs.io/
Mozilla Public License 2.0
412 stars 65 forks source link

TransformerBase, custom transform function error in sparse retriever (bm25) #433

Closed bakingeol closed 6 months ago

bakingeol commented 6 months ago

Describe the bug A clear and concise description of what the bug is.

To Reproduce Steps to reproduce the behavior:

  1. Which index
  2. Which retrieval
  3. What pipeline
  4. What was the dataframe output
  5. See error

Expected behavior A clear and concise description of what you expected to happen.

Documentation and Issues

Screenshots If applicable, add screenshots to help explain your problem.

Additional context Add any other context about the problem here.

Hi, thank you for great repo!

When I try to input a custom rewritten query, I encounter the error below. It is really weird to me, because when I run my dense retrieval code it works.

Here are the error message, my dataset format, and a simple code sample (the best I could do). I have also added the dense retriever code.

please help.... this is important to me.....

pytrec-eval==0.5 pytrec-eval-terrier==0.5.6 pyserini==0.22.1 pyterrier-sentence-transformers==0.2.2

error message

Traceback (most recent call last):
  File "experiment_sparse_retr.py", line 129, in <module>
    main(args)
  File "experiment_sparse_retr.py", line 104, in main
    result=pt.Experiment(
  File "/home/baekig/.conda/envs/dense_train2/lib/python3.8/site-packages/pyterrier/pipelines.py", line 472, in Experiment
    time, evalMeasuresDict = _run_and_evaluate(
  File "/home/baekig/.conda/envs/dense_train2/lib/python3.8/site-packages/pyterrier/pipelines.py", line 193, in _run_and_evaluate
    res = system.transform(topics)
  File "/home/baekig/practice/utilss.py", line 12, in transform
    import pdb;pdb.set_trace()
  File "/home/baekig/.conda/envs/dense_train2/lib/python3.8/site-packages/pyterrier/ops.py", line 335, in transform
    topics = m.transform(topics)
  File "/home/baekig/.conda/envs/dense_train2/lib/python3.8/site-packages/pyterrier/apply_base.py", line 217, in transform
    outputRes["query"] = outputRes.apply(fn, axis=1)
  File "/home/baekig/.conda/envs/dense_train2/lib/python3.8/site-packages/pandas/core/frame.py", line 3941, in __setitem__
    self._set_item_frame_value(key, value)
  File "/home/baekig/.conda/envs/dense_train2/lib/python3.8/site-packages/pandas/core/frame.py", line 4071, in _set_item_frame_value
    raise ValueError("Columns must be same length as key")
ValueError: Columns must be same length as key

dataset format (trec-covid_concat.csv) image

error code.

from pyterrier.measures import *
import pyterrier as pt
from pyterrier.transformer import TransformerBase
import argparse

from utilss import RewriteQuery, cleanStrDF
from datasets import load_dataset

# Initialise PyTerrier only if it has not been started yet in this process.
# NOTE(review): per the PyTerrier docs `mem` is the JVM heap size in MB
# (about 32 GB here) - confirm for the installed version.
if not pt.started():
    pt.init(mem=32000)
import pandas as pd
def main(args):
    """Run a BM25 retrieval experiment and print the evaluation table.

    Reads the following attributes from ``args``:

    * ``index_path`` - location handed to ``pt.IndexFactory.of``.
    * ``dataset``    - ir_datasets identifier, loaded as ``irds:<dataset>``.
    * ``cot``        - when true, retrieve with the rewritten ``cot`` query
      built from the CSV at ``data_path`` instead of the original query.
    * ``data_path``  - CSV with ``query`` and ``cot`` columns (only read
      when ``cot`` is set).
    * ``exp_name``   - label used for the system in the result table.
    """
    index = pt.IndexFactory.of(args.index_path)
    dataset = pt.get_dataset(f'irds:{args.dataset}')
    topics = dataset.get_topics()
    qrels = dataset.get_qrels()

    _bm25 = pt.BatchRetrieve(index, wmodel='BM25', verbose=True, metadata=["docno"])
    bm25_clean = pt.apply.query(cleanStrDF) >> _bm25

    # Defaults for a run without --cot. Previously these names were only
    # bound inside the `if args.cot:` branch, so a plain run crashed with
    # a NameError when pt.Experiment was called.
    bm25 = bm25_clean
    new_topic = topics

    if args.cot:
        _topics = pd.read_csv(args.data_path)
        # The original query is repeated three times in front of the
        # rewritten text before it is used for retrieval.
        _topics.loc[:, 'cot'] = (
            _topics.loc[:, 'query'] + '. '
            + _topics.loc[:, 'query'] + '. '
            + _topics.loc[:, 'query'] + '. '
            + _topics.loc[:, 'cot']
        )
        # NOTE(review): axis=1 concat aligns on the row index, not on qid -
        # this assumes the CSV rows are in the same order as the dataset
        # topics; confirm.
        new_topic = pd.concat([topics, _topics['cot']], axis=1)
        bm25 = RewriteQuery(bm25_clean, 'cot')

    result = pt.Experiment(
        [bm25],
        new_topic,
        qrels,
        eval_metrics=["ndcg", "ndcg_cut_10", "recip_rank", "recall_10", "recall_50", "recall_1000"],
        names=[f'{args.exp_name}']
    )
    print(result)

if __name__ == "__main__":
    # Command-line entry point for the sparse (BM25) experiment.
    parser = argparse.ArgumentParser()

    # String-valued options and their defaults.
    for option, fallback in (
        ("--exp_name", 'sample'),
        ("--data_path", ''),
        ("--index_path", ''),
        ("--dataset", ''),
    ):
        parser.add_argument(option, default=fallback)

    # Boolean switches; each is False unless given on the command line.
    for switch in ("--baseline", "--q2d_base", "--cot"):
        parser.add_argument(switch, action='store_true')

    parser.add_argument("--query_num", default=3, type=int)  # run

    args = parser.parse_args()
    main(args)

command - sparse retriever

python experiment_sparse_retr.py --dataset beir/trec-covid --data_path trec-covid_concat.csv --index_path sparse_index/beir_trec-covid/data.properties --exp_name cot --cot

utilss.py code

import re
from pyterrier.transformer import TransformerBase
class RewriteQuery(TransformerBase):
    """Run an inner pipeline on an alternative query column.

    The column named ``rewritename`` is copied over ``query`` before the
    topics are handed to ``inner_pipe``; the caller's frame is not modified.
    """

    def __init__(self, inner_pipe, rewritename):
        """Store the wrapped pipeline and the name of the rewritten-query column."""
        self.inner_pipe = inner_pipe
        self.rewritename = rewritename

    def transform(self, topics):
        """Return ``inner_pipe.transform`` applied to topics whose ``query``
        column has been replaced by the ``rewritename`` column.

        Raises ``KeyError`` (from the column lookup) if ``topics`` has no
        column called ``rewritename``.
        """
        # Work on a copy so the frame passed in by the caller keeps its
        # original "query" column.
        topicsNew = topics.copy()
        topicsNew["query"] = topicsNew[self.rewritename]
        # The leftover pdb.set_trace() breakpoint was removed here: it
        # stopped every run at an interactive debugger prompt.
        return self.inner_pipe.transform(topicsNew)
# Single-character substitutions used by cleanStr: '?' is deleted and the
# listed accented letters are folded to plain ASCII letters.
_CLEAN_STR_TABLE = str.maketrans(
    {
        '?': None,
        "á": 'a',
        "é": 'e',
        "ö": 'o',
        "Č": 'C',
        "ć": 'c',
        "ó": 'o',
        "ă": 'a',
        "ä": 'a',
        "ü": 'u',
        "ā": 'a',
        "í": 'i',
        "ÿ": 'y',
    }
)


def cleanStr(text):
    """Return ``text`` reduced to ASCII letters, digits and single spaces.

    Steps, in order:

    1. the literal two-character sequence backslash + ``W`` becomes a space;
    2. ``?`` is removed and a fixed set of accented letters is replaced by
       the unaccented letter;
    3. every remaining run of characters outside ``[0-9a-zA-Z]`` is
       collapsed into one space.

    Leading/trailing spaces produced by step 3 are kept.
    """
    # str.replace is a literal replacement, not a regex: this only touches
    # the exact text "\W". It is written with an escaped backslash because
    # '\W' is an invalid escape sequence and triggers a warning.
    text = text.replace('\\W', ' ')
    # One pass instead of a chain of single-character replace() calls; the
    # result is the same because no replacement letter is itself a key.
    text = text.translate(_CLEAN_STR_TABLE)
    return re.sub(r'[^0-9a-zA-Z]+', ' ', text)

def cleanStrDF(q):
    """Return the cleaned query string for one topics row.

    ``q`` is any mapping-like row with a ``"query"`` entry (the script
    passes this function to ``pt.apply.query``). The query text is
    normalised as follows:

    1. the literal two-character sequence backslash + ``W`` becomes a space;
    2. ``?`` is removed;
    3. a fixed set of accented letters is replaced by the unaccented letter;
    4. every remaining run of characters outside ``[0-9a-zA-Z]`` is
       collapsed into one space.

    The cleaned *string* is returned. Returning the row itself (as this
    function used to) makes the row-wise apply produce a DataFrame, and
    assigning that to the ``query`` column fails with
    ``ValueError: Columns must be same length as key``.
    """
    text = q["query"]
    # Literal replacement (str.replace is not a regex); the backslash is
    # escaped because '\W' is an invalid escape sequence.
    text = text.replace('\\W', ' ')
    text = text.replace('?', '')
    for accented, plain in (
        ("á", 'a'),
        ("é", 'e'),
        ("ö", 'o'),
        ("Č", 'C'),
        ("ć", 'c'),
        ("ó", 'o'),
        ("ă", 'a'),
        ("ä", 'a'),
        ("ü", 'u'),
        ("ā", 'a'),
        ("í", 'i'),
        ("ÿ", 'y'),
    ):
        text = text.replace(accented, plain)
    text = re.sub(r'[^0-9a-zA-Z]+', ' ', text)
    # Return the cleaned text, not the input row.
    return text

dense retrieval code (it works!!)

from pyterrier.measures import *
import pyterrier as pt
import argparse
import pandas as pd
# Initialise PyTerrier only if it has not been started yet in this process.
# NOTE(review): per the PyTerrier docs `mem` is the JVM heap size in MB
# (about 32 GB here) - confirm for the installed version.
if not pt.started():
    pt.init(mem=32000)
from pyterrier_sentence_transformers import (
    SentenceTransformersRetriever,
    SentenceTransformerConfig
)
from utilss import RewriteQuery, cleanStrDF
import springs as sp
from pathlib import Path
from typing import Optional

def swap_columns(df):
    """Rename the ``query_id`` column of ``df`` to ``qid`` in place.

    All other column labels are kept unchanged and in the same order.
    Nothing is returned; the frame passed in is modified.
    """
    df.columns = [
        'qid' if label == 'query_id' else label
        for label in df.columns.tolist()
    ]

def main(args):
    """Run a dense-retrieval experiment and print the evaluation table.

    Reads the following attributes from ``args``:

    * ``dataset``   - ir_datasets identifier, loaded as ``irds:<dataset>``.
    * ``exp_name``  - model name/path given to the retriever; also used to
      derive the index directory and, without ``--cot``, the system label.
    * ``cot``       - when true, retrieve with the rewritten ``cot`` query
      built from the CSV at ``data_path``.
    * ``data_path`` - CSV with ``query_id``/``qid``, ``query`` and ``cot``
      columns (only read when ``cot`` is set).
    """
    DATASET = args.dataset
    NEU_MODEL_NAME = args.exp_name
    dataset = pt.get_dataset(f'irds:{DATASET}')

    _df = dataset.get_topics()
    queries = dataset.get_qrels()

    # Defaults for a run without --cot. Previously `df` and `name` were
    # only bound inside the `if args.cot:` branches, so a plain run crashed
    # with a NameError when pt.Experiment was called.
    df = _df
    name = NEU_MODEL_NAME

    if args.cot:
        _topics = pd.read_csv(args.data_path)

        swap_columns(_topics)
        _topics['qid'] = _topics['qid'].apply(str)
        # The original query is repeated three times in front of the
        # rewritten text before it is used for retrieval.
        _topics.loc[:, 'cot'] = (
            _topics.loc[:, 'query'] + '. '
            + _topics.loc[:, 'query'] + '. '
            + _topics.loc[:, 'query'] + '. '
            + _topics.loc[:, 'cot']
        )

        # NOTE(review): axis=1 concat aligns on the row index, not on qid -
        # this assumes the CSV rows are in the same order as the dataset
        # topics; confirm.
        df = pd.concat([_df, _topics.loc[:, 'cot']], axis=1)

    # NOTE(review): this config subclass is not referenced anywhere else in
    # the visible code; kept as-is in case defining it matters - confirm.
    class SentenceTransformerConfigWithDefaults(SentenceTransformerConfig):
        model_name_or_path: Optional[str] = None
        index_path: Optional[str] = None

    index_root = Path('my_path') / DATASET.replace('/', '_')
    neu_index_path = index_root / NEU_MODEL_NAME.replace('/', '_')

    neu_retr = SentenceTransformersRetriever(
        model_name_or_path=NEU_MODEL_NAME,
        index_path=str(neu_index_path)
    )
    if args.cot:
        neu_retr = RewriteQuery(neu_retr, 'cot')
        name = 'cot'

    eval_metrics = ["ndcg", "ndcg_cut_10", "recip_rank", "recall_10", "recall_50", "recall_1000"]

    exp = pt.Experiment(
        [neu_retr],
        df,
        queries,
        names=[f'{name}'],
        eval_metrics=eval_metrics
    )
    # The result table used to be computed and then dropped; print it, as
    # the sparse-retrieval script does.
    print(exp)

if __name__ == '__main__':
    # Command-line entry point for the dense-retrieval experiment.
    parser = argparse.ArgumentParser()

    # String-valued options and their defaults.
    for option, fallback in (
        ('--exp_name', ''),
        ('--data_path', ''),
        ('--dataset', 'msmarco-passage/dev'),
    ):
        parser.add_argument(option, default=fallback)

    # Boolean switch: False unless given on the command line.
    parser.add_argument('--cot', action='store_true')
    parser.add_argument('--query_num', default=3, type=int)

    args = parser.parse_args()
    main(args)

command - dense retriever

python experiment_dense_retr.py --cot --exp_name my_name/my_model_name --dataset beir/trec-covid --data_path trec-covid_concat.csv

thank you for reading

bakingeol commented 6 months ago

sorry, it's my mistake..

cmacdonald commented 6 months ago

Hi @bakingeol - for posterity, what was the solution?

bakingeol commented 6 months ago

Hi @cmacdonald - thank you for your response. It was just a typo: the function "cleanStrDF" has to return 'text' (it returned 'q' instead). haha...😂

cmacdonald commented 6 months ago

I'll see if we can catch this and provide a more explanatory error message.