AnswerDotAI / RAGatouille

Easily use and train state-of-the-art late-interaction retrieval methods (ColBERT) in any RAG pipeline. Designed for modularity and ease of use, backed by research.
Apache License 2.0

`RAG.search` is not thread-safe #262

Open chandlj opened 2 weeks ago

From the digging I've done, it appears that the following code is not thread-safe (using LangChain):

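from operator import itemgetter

from langchain_core.runnables import RunnablePassthrough
from ragatouille import RAGPretrainedModel

# index_path, k, prompt and llm are defined elsewhere in the application
RAG = RAGPretrainedModel.from_index(index_path)
retriever = RAG.as_langchain_retriever(k=k)

chain = RunnablePassthrough.assign(
    passages=itemgetter("query") | retriever
) | prompt | llm

# Inside an async function: the batch inputs are processed concurrently
await chain.abatch([{"query": ...}, {"query": ...}, ...])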

This fails with the following error (paths truncated for security; only the relevant frames are shown):

Traceback (most recent call last):
  ...
  File ".../lib/python3.11/site-packages/ragatouille/RAGPretrainedModel.py", line 315, in search
    return self.model.search(
           ^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/ragatouille/models/colbert.py", line 394, in search
    results = self.model_index.search(
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/ragatouille/models/index.py", line 343, in search
    if k > (32 * self.searcher.config.ncells):
            ~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'int' and 'NoneType'
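Why does an `async` batch produce a threading error at all? A likely explanation, assuming LangChain's default behaviour for retrievers that have no native async implementation, is that each synchronous retrieval call is handed to a thread-pool executor, so the concurrent `abatch` inputs run on different OS threads. The standard-library sketch below reproduces that pattern without LangChain; `fake_retrieve` is a hypothetical stand-in for the synchronous `RAG.search` call.

```python
import asyncio
import threading

N_CALLS = 4
# All N_CALLS workers must reach the barrier together; this only succeeds
# if the calls really run concurrently on separate threads.
barrier = threading.Barrier(N_CALLS, timeout=5)


def fake_retrieve(query: str) -> tuple[str, str]:
    """Hypothetical stand-in for a synchronous retriever call such as RAG.search."""
    barrier.wait()  # raises threading.BrokenBarrierError if the calls were serialized
    return query, threading.current_thread().name


async def abatch(queries: list[str]) -> list[tuple[str, str]]:
    loop = asyncio.get_running_loop()
    # Each sync call is offloaded to the loop's default ThreadPoolExecutor,
    # a common way of giving synchronous code an "async" interface.
    tasks = [loop.run_in_executor(None, fake_retrieve, q) for q in queries]
    return await asyncio.gather(*tasks)


results = asyncio.run(abatch([f"query {i}" for i in range(N_CALLS)]))
thread_names = {name for _, name in results}
print(f"{N_CALLS} calls ran on {len(thread_names)} distinct threads")
```

Because every call blocks at the barrier until the others arrive, the batch can only complete if each call occupies its own worker thread, which is exactly the situation in which unsynchronized shared state inside `search` becomes a problem.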

The culprit, I believe, is the lazy searcher initialization in `PLAIDModelIndex.search` (the `ModelIndex` subclass in `ragatouille/models/index.py`):

class PLAIDModelIndex(ModelIndex):
    ...
    def search(self, ...):
        # Race condition: this check-then-load is not atomic when search runs on multiple threads in parallel!
        if self.searcher is None or force_reload:
            self._load_searcher(
                checkpoint,
                collection,
                index_name,
                force_fast,
            )
        assert self.searcher is not None
        ...
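The check-then-act pattern above can be reproduced without RAGatouille. The sketch below assumes, as one plausible reading of the traceback, that the searcher object becomes visible through `self.searcher` before options such as `ncells` have been set; a second thread that checks `self.searcher is None` inside that window skips loading and reads `ncells` while it is still `None`. `LazyIndex`, `FakeSearcher`, and the sleep-based timing are hypothetical, chosen only to widen the race window.

```python
import threading
import time
from types import SimpleNamespace


class FakeSearcher:
    """Hypothetical stand-in for a searcher whose config is filled in after construction."""

    def __init__(self) -> None:
        self.config = SimpleNamespace(ncells=None)


class LazyIndex:
    """Unsafe lazy initialization mirroring the pattern in PLAIDModelIndex.search."""

    def __init__(self) -> None:
        self.searcher = None

    def _load_searcher(self) -> None:
        self.searcher = FakeSearcher()  # visible to other threads immediately...
        time.sleep(0.2)                 # ...while slow loading is still in progress
        self.searcher.config.ncells = 16

    def search(self, k: int) -> bool:
        if self.searcher is None:   # check ...
            self._load_searcher()   # ... then act: not atomic across threads
        return k > 32 * self.searcher.config.ncells


index = LazyIndex()
errors: list[Exception] = []


def worker(delay: float) -> None:
    time.sleep(delay)  # stagger threads so the later ones arrive mid-load
    try:
        index.search(k=10)
    except TypeError as exc:  # 32 * None, as in the reported traceback
        errors.append(exc)


threads = [threading.Thread(target=worker, args=(i * 0.02,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"{len(errors)} of {len(threads)} threads failed: {errors[0] if errors else 'no error'}")
```

The first thread loads normally; the staggered ones see a non-`None` searcher whose `ncells` is not set yet and fail with the same `unsupported operand type(s) for *: 'int' and 'NoneType'` message as in the report.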

My intuition is that this is a fairly common use case. For reference, the official ColBERT implementation's `server.py` initializes the `Searcher` once at startup, before any API calls are served. I think the searcher should likewise be initialized before the first call to `RAG.search` to prevent this race condition, or `RAG.as_langchain_retriever` should provide a batch function.
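Another way to close the window is to guard the lazy load with a lock and expose the searcher only after it is fully configured (double-checked locking). The sketch below applies that pattern to a hypothetical `SafeLazyIndex` with the same hypothetical `FakeSearcher` stand-in; it illustrates the technique and is not a patch to RAGatouille's actual classes.

```python
import threading
import time
from types import SimpleNamespace


class FakeSearcher:
    """Hypothetical stand-in for a searcher whose config is filled in after construction."""

    def __init__(self) -> None:
        self.config = SimpleNamespace(ncells=None)


class SafeLazyIndex:
    """Lazy initialization that is safe to call from many threads."""

    def __init__(self) -> None:
        self.searcher = None
        self.load_count = 0
        self._lock = threading.Lock()

    def _load_searcher(self) -> None:
        searcher = FakeSearcher()
        time.sleep(0.2)              # simulate slow index loading
        searcher.config.ncells = 16  # finish configuring before exposing it
        self.load_count += 1
        self.searcher = searcher     # expose only a fully configured searcher

    def search(self, k: int) -> bool:
        if self.searcher is None:          # fast path: no lock once loaded
            with self._lock:
                if self.searcher is None:  # re-check: another thread may have loaded it
                    self._load_searcher()
        return k > 32 * self.searcher.config.ncells


index = SafeLazyIndex()
errors: list[Exception] = []


def worker(delay: float) -> None:
    time.sleep(delay)
    try:
        index.search(k=10)
    except TypeError as exc:
        errors.append(exc)


threads = [threading.Thread(target=worker, args=(i * 0.02,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"errors={len(errors)}, searcher loaded {index.load_count} time(s)")
```

The eager alternative suggested above amounts to paying the load cost once up front, for example with a single `RAG.search` call before `chain.abatch(...)` fans out. That may sidestep this particular race as a workaround, but on its own it does not make the rest of `search` thread-safe.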