AnswerDotAI / rerankers

A lightweight, low-dependency, unified API to use all common reranking and cross-encoder models.
Apache License 2.0

Using Reranker in a multithreaded process raises an "Already borrowed" RuntimeError #42

Open sam-bercovici opened 1 week ago

sam-bercovici commented 1 week ago

I am using colbert.

see: https://github.com/huggingface/tokenizers/issues/537
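As the linked tokenizers issue explains, the error comes from sharing a single fast (Rust-backed) tokenizer instance across threads: concurrent `encode` calls conflict over its internal mutable state. A common workaround, sketched here with a hypothetical stand-in class rather than a real tokenizer, is to keep one instance per thread via `threading.local`:

```python
import threading

class FakeTokenizer:
    """Hypothetical stand-in for a tokenizer that is not thread-safe."""
    def encode(self, text):
        return text.split()

_local = threading.local()

def get_tokenizer():
    # Lazily create one tokenizer per thread, so no two threads
    # ever touch the same instance concurrently.
    if not hasattr(_local, "tok"):
        _local.tok = FakeTokenizer()
    return _local.tok

results = []

def worker(text):
    results.append(get_tokenizer().encode(text))

threads = [
    threading.Thread(target=worker, args=(f"hello world {i}",))
    for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

With a real `AutoTokenizer` the same pattern applies: construct (or deep-copy) the tokenizer inside each thread instead of sharing one object.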

I suggest allowing `tokenizer_kwargs` and `model_kwargs` to be passed to the Reranker factory class, which would then forward them to the underlying ranker.

Below is an example of how to modify `ColBERTRanker`'s `__init__`.

I marked the modifications with `## change`:

    def __init__(
        self,
        model_name: str,
        batch_size: int = 32,
        dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[str, torch.device]] = None,
        verbose: int = 1,
        query_token: str = "[unused0]",
        document_token: str = "[unused1]",
        **kwargs, ## change
    ):
        self.verbose = verbose
        self.device = get_device(device, self.verbose)
        self.dtype = get_dtype(dtype, self.device, self.verbose)
        self.batch_size = batch_size
        vprint(
            f"Loading model {model_name}, this might take a while...",
            self.verbose,
        )
        tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) ## change
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) ## change
        model_kwargs = kwargs.get("model_kwargs", {}) ## change
        self.model = (
            ColBERTModel.from_pretrained(model_name, **model_kwargs) ## change
            .to(self.device)
            .to(self.dtype)
        )
        self.model.eval()
        self.query_max_length = 32  # Lower bound
        self.doc_max_length = (
            self.model.config.max_position_embeddings - 2
        )  # Upper bound
        self.query_token_id: int = self.tokenizer.convert_tokens_to_ids(query_token)  # type: ignore
        self.document_token_id: int = self.tokenizer.convert_tokens_to_ids(
            document_token
        )  # type: ignore
        self.normalize = True
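The forwarding pattern above can be illustrated in isolation. This is a toy sketch with dummy classes, not the rerankers API: a factory accepts `**kwargs`, pulls out the nested `tokenizer_kwargs`/`model_kwargs` dicts, and unpacks each into the matching constructor.

```python
class DummyTokenizer:
    """Stand-in for AutoTokenizer.from_pretrained (hypothetical)."""
    def __init__(self, name, **kw):
        self.name, self.kw = name, kw

class DummyModel:
    """Stand-in for ColBERTModel.from_pretrained (hypothetical)."""
    def __init__(self, name, **kw):
        self.name, self.kw = name, kw

class ToyRanker:
    def __init__(self, model_name, **kwargs):
        # Each component gets its own kwargs dict, defaulting to empty,
        # so callers can configure one without touching the other.
        tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
        model_kwargs = kwargs.get("model_kwargs", {})
        self.tokenizer = DummyTokenizer(model_name, **tokenizer_kwargs)
        self.model = DummyModel(model_name, **model_kwargs)

ranker = ToyRanker(
    "my-model",
    tokenizer_kwargs={"use_fast": False},
    model_kwargs={"trust_remote_code": True},
)
```

Using `kwargs.get(...)` rather than positional parameters keeps the factory signature stable while letting arbitrary loader options flow through.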
bclavie commented 4 days ago

Thanks for flagging! Would you be willing to submit your proposed changes as a PR? I'm happy with this logic being added to handle various kwargs situations!

sam-bercovici commented 3 days ago

> Thanks for flagging! Would you be willing to submit your proposed changes as a PR? I'm happy with this logic being added to handle various kwargs situations!

Sure. I will try to find a couple of hours to do so in the next week or so.