run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: multi-worker IngestionPipeline does not work with caches. #14572

Open. Falven opened this issue 4 days ago.

Falven commented 4 days ago

Bug Description

When creating an IngestionPipeline with num_workers, the pipeline splits the work across worker processes as follows:

with ProcessPoolExecutor(max_workers=num_workers) as p:
    node_batches = self._node_batcher(
        num_batches=num_workers, nodes=nodes_to_run
    )
    tasks = [
        loop.run_in_executor(
            p,
            partial(
                arun_transformations_wrapper,
                transformations=self.transformations,
                in_place=in_place,
                cache=self.cache if not self.disable_cache else None,
                cache_collection=cache_collection,
            ),
            batch,
        )
        for batch in node_batches
    ]
    result: List[List[BaseNode]] = await asyncio.gather(*tasks)
    nodes = reduce(lambda x, y: x + y, result, [])

The problem is that this partial function, including every argument bound to it, must be picklable so that ProcessPoolExecutor can send it to another process. Some transformations may not be picklable, but for my simple parsing and embedding case it works fine. The problem arises because I am using a Redis IngestionCache: the cache object itself is not picklable, so this raises an exception.

I think it's specifically the _redis_client that is not serializable.

IngestionCache(collection='llama_cache', cache=<llama_index.storage.kvstore.redis.base.RedisKVStore object at 0x304725d10>, nodes_key='nodes')
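To see why the cache breaks the executor, here is a minimal stdlib-only sketch. ProcessPoolExecutor pickles the callable it sends to a worker, and a functools.partial is pickled together with all of its bound arguments. FakeClient is a hypothetical stand-in for the Redis client (it just holds a threading.Lock, which pickle refuses), and ClientHoldingCache / RebuildingCache are hypothetical classes, not LlamaIndex APIs. The second class shows one common workaround, assuming the connection kwargs themselves are picklable: drop the live client in `__getstate__` and reconnect in `__setstate__`.

```python
import pickle
import threading
from functools import partial
from typing import Any, Dict, List, Optional


class FakeClient:
    """Hypothetical stand-in for a network client such as redis.Redis."""

    def __init__(self, **conn_kwargs: Any) -> None:
        self.conn_kwargs = conn_kwargs
        self._lock = threading.Lock()  # locks (like sockets) cannot be pickled


class ClientHoldingCache:
    """Holds a live client, so pickling the cache tries to pickle the client too."""

    def __init__(self, **conn_kwargs: Any) -> None:
        self._client = FakeClient(**conn_kwargs)


class RebuildingCache:
    """Pickles only the connection kwargs and reconnects after unpickling."""

    def __init__(self, **conn_kwargs: Any) -> None:
        self.conn_kwargs = conn_kwargs
        self._client: Optional[FakeClient] = FakeClient(**conn_kwargs)

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["_client"] = None  # never ship the live client to another process
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._client = FakeClient(**self.conn_kwargs)  # reconnect in the worker


def run_batch(batch: List[str], cache: Any = None) -> List[str]:
    """Placeholder for the per-batch work a pipeline worker would do."""
    return [text.upper() for text in batch]


def can_pickle(obj: Any) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, AttributeError, pickle.PicklingError):
        return False


bad = partial(run_batch, cache=ClientHoldingCache(host="localhost"))
good = partial(run_batch, cache=RebuildingCache(host="localhost"))
print(can_pickle(bad))   # False: the bound cache drags the lock along
print(can_pickle(good))  # True
restored = pickle.loads(pickle.dumps(good))
print(restored(["a", "b"]))  # ['A', 'B']
```

The trade-off is that every unpickled copy opens its own connection, which is usually what you want in a separate process anyway.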

I think a potential workaround could be to pass the parameters needed to initialize the cache to each worker process and have each worker re-initialize its own cache object. This would require changing IngestionPipeline to accept a cache_factory rather than the cache object itself.
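A minimal sketch of the cache_factory idea, assuming a hypothetical cache_factory parameter (the IngestionPipeline in 0.10.52 takes a cache instance, not a factory). The factory is a functools.partial over a module-level function, so pickling it records only the function's qualified name plus plain arguments; each worker calls it to build its own cache. DictCache, make_cache, and run_batch_with_factory are illustrative stand-ins, and the sketch round-trips the task through pickle to mimic what ProcessPoolExecutor does when it ships work to another process.

```python
import pickle
from functools import partial
from typing import Callable, Dict, List, Optional


class DictCache:
    """Hypothetical in-memory stand-in for a Redis-backed IngestionCache."""

    def __init__(self, collection: str) -> None:
        self.collection = collection
        self._store: Dict[str, List[str]] = {}

    def get(self, key: str) -> Optional[List[str]]:
        return self._store.get(key)

    def put(self, key: str, value: List[str]) -> None:
        self._store[key] = value


def make_cache(collection: str) -> DictCache:
    # A real factory would open this worker's own Redis connection here.
    return DictCache(collection)


def run_batch_with_factory(
    batch: List[str], cache_factory: Callable[[], DictCache]
) -> List[str]:
    cache = cache_factory()  # built inside the worker; the instance is never pickled
    key = "|".join(batch)
    cached = cache.get(key)
    if cached is not None:
        return cached
    result = [text.upper() for text in batch]  # placeholder for the transformations
    cache.put(key, result)
    return result


# Only a module-level function reference plus plain arguments get pickled.
factory = partial(make_cache, "llama_cache")
task = partial(run_batch_with_factory, cache_factory=factory)

# ProcessPoolExecutor pickles the callable for each worker; simulate that round trip.
worker_task = pickle.loads(pickle.dumps(task))
print(worker_task(["a", "b"]))  # ['A', 'B']
```

Each worker builds an independent cache object. With a shared backend such as Redis the entries still land in the same store; this in-memory stand-in shares nothing between workers.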

Version

0.10.52

Steps to Reproduce

Just create an IngestionPipeline that uses a Redis IngestionCache:

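import multiprocessing
import os

from llama_index.core import Settings
from llama_index.core.ingestion import IngestionCache, IngestionPipeline
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.storage.kvstore.redis import RedisKVStore
from redis import Redis
from redis.asyncio import Redis as AsyncRedis

# vector_store, document_store, documents and redis_kwargs are defined elsewhere.

embed_model = AzureOpenAIEmbedding(
    model=os.environ["AZURE_EMBEDDING_MODEL"],
    deployment_name=os.environ["AZURE_EMBEDDING_DEPLOYMENT"],
    azure_endpoint=os.environ["AZURE_EMBEDDING_ENDPOINT"],
    api_key=os.environ["AZURE_EMBEDDING_API_KEY"],
    api_version=os.environ["AZURE_EMBEDDING_API_VERSION"],
)
# The pipeline below reads the model from Settings, so register it there.
Settings.embed_model = embed_model

ingestion_cache = IngestionCache(
    cache=RedisKVStore(
        redis_client=Redis(**redis_kwargs),
        async_redis_client=AsyncRedis(**redis_kwargs),
    )
)

pipeline = IngestionPipeline(
    vector_store=vector_store,
    docstore=document_store,
    transformations=[Settings.embed_model],
    cache=ingestion_cache,
)

# Run inside an async function (or a notebook cell).
await pipeline.arun(
    documents=documents, num_workers=multiprocessing.cpu_count()
)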

Relevant Logs/Tracebacks

No response

logan-markewich commented 4 days ago

This is why I was hesitant to add multiprocessing to the ingestion pipeline -- a ton of objects are not pickle friendly 😅 I'm not sure what the fix would look like here. Maybe it would work if the transformations and cache were accessed outside of the partial, something like this (feels hacky, but could work):

def get_arun_transformations_wrapper(transformations, cache):
    def arun_transformations_wrapper(...):
        <use the transformations and cache>

    return arun_transformations_wrapper

...

wrapper_fn = get_arun_transformations_wrapper(self.transformations, self.cache)
tasks = [
    loop.run_in_executor(
        p,
        wrapper_fn,
        batch,
    )
    for batch in node_batches
]

Falven commented 4 days ago

I don't think this will work, because functions defined inside other functions (local objects) cannot be pickled. 😕
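The failure follows from how pickle handles functions: it stores a function by its module and qualified name and looks it up again on load, so a function defined inside another function (qualified name `outer.<locals>.inner`) cannot be found and is rejected; the closure's captured variables are not serialized either. A common alternative, sketched below with hypothetical names (get_wrapper, TransformationsRunner), is a module-level class whose instances hold the state as attributes and implement `__call__`. Instances are pickled by value, but every attribute (the transformations and the cache included) must itself be picklable, so this alone does not fix the Redis cache.

```python
import pickle
from typing import Any, List


def get_wrapper(transformations: List[Any], cache: Any):
    # Same shape as the closure proposal: the worker function is a local object.
    def wrapper(batch: List[str]) -> List[str]:
        # A real wrapper would also read/write `cache` here.
        return [str(t) + ":" + item for t in transformations for item in batch]

    return wrapper


class TransformationsRunner:
    """Module-level callable class: picklable as long as its attributes are."""

    def __init__(self, transformations: List[Any], cache: Any = None) -> None:
        self.transformations = transformations
        self.cache = cache

    def __call__(self, batch: List[str]) -> List[str]:
        return [str(t) + ":" + item for t in self.transformations for item in batch]


try:
    pickle.dumps(get_wrapper(["t"], None))
    closure_pickles = True
except (AttributeError, TypeError, pickle.PicklingError) as exc:
    # The exact exception type varies across Python versions
    # (the traceback in this thread shows AttributeError on 3.11).
    closure_pickles = False
    print(exc)

runner = pickle.loads(pickle.dumps(TransformationsRunner(["t"])))
print(closure_pickles, runner(["a", "b"]))  # False ['t:a', 't:b']
```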

I think a potential workaround could be to pass the parameters needed to initialize the cache to each worker process and have each worker re-initialize its own cache object. This would require changing IngestionPipeline to accept a cache_factory rather than the cache instance itself. I'll try this out and see how it goes. I think this would help with the cache situation, but the transformations would be a whole other problem.

logan-markewich commented 4 days ago

I feel like it will work actually... let me try the method above :)

Falven commented 4 days ago

[2024-07-04T17:02:56.074Z] AttributeError: Can't pickle local object 'get_arun_transformations_wrapper.<locals>.arun_transformations_wrapper'
[2024-07-04T17:02:56.075Z] An unexpected error occurred: Can't pickle local object 'get_arun_transformations_wrapper.<locals>.arun_transformations_wrapper'
[2024-07-04T17:02:56.075Z] concurrent.futures.process._RemoteTraceback: 
[2024-07-04T17:02:56.075Z] """
[2024-07-04T17:02:56.075Z] Traceback (most recent call last):
[2024-07-04T17:02:56.075Z]   File "/Users/falven/.pyenv/versions/3.11.9/lib/python3.11/multiprocessing/queues.py", line 244, in _feed
[2024-07-04T17:02:56.075Z]     obj = _ForkingPickler.dumps(obj)
[2024-07-04T17:02:56.075Z]           ^^^^^^^^^^^^^^^^^^^^^^^^^^
[2024-07-04T17:02:56.075Z]   File "/Users/falven/.pyenv/versions/3.11.9/lib/python3.11/multiprocessing/reduction.py", line 51, in dumps
[2024-07-04T17:02:56.075Z]     cls(buf, protocol).dump(obj)
[2024-07-04T17:02:56.075Z] AttributeError: Can't pickle local object 'get_arun_transformations_wrapper.<locals>.arun_transformations_wrapper'
[2024-07-04T17:02:56.075Z] """
[2024-07-04T17:02:56.075Z] 
[2024-07-04T17:02:56.075Z] The above exception was the direct cause of the following exception:
[2024-07-04T17:02:56.075Z] 
[2024-07-04T17:02:56.075Z] Traceback (most recent call last):
[2024-07-04T17:02:56.075Z]   File "/Users/falven/Source/AJG/src/ingestion/ingest/__init__.py", line 116, in ingest
[2024-07-04T17:02:56.075Z]     await asyncio.gather(*ingestion_tasks)
[2024-07-04T17:02:56.075Z]   File "/Users/falven/Source/AJG/src/ingestion/.venv/lib/python3.11/site-packages/llama_index/core/instrumentation/dispatcher.py", line 255, in async_wrapper
[2024-07-04T17:02:56.075Z]     result = await func(*args, **kwargs)
[2024-07-04T17:02:56.075Z]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[2024-07-04T17:02:56.075Z]   File "/Users/falven/Source/AJG/src/ingestion/.venv/lib/python3.11/site-packages/llama_index/core/ingestion/pipeline.py", line 752, in arun
[2024-07-04T17:02:56.075Z]     result: List[List[BaseNode]] = await asyncio.gather(*tasks)
[2024-07-04T17:02:56.075Z]                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[2024-07-04T17:02:56.075Z]   File "/Users/falven/.pyenv/versions/3.11.9/lib/python3.11/multiprocessing/queues.py", line 244, in _feed
[2024-07-04T17:02:56.075Z]     obj = _ForkingPickler.dumps(obj)
[2024-07-04T17:02:56.075Z]           ^^^^^^^^^^^^^^^^^^^^^^^^^^
[2024-07-04T17:02:56.075Z]   File "/Users/falven/.pyenv/versions/3.11.9/lib/python3.11/multiprocessing/reduction.py", line 51, in dumps
[2024-07-04T17:02:56.075Z]     cls(buf, protocol).dump(obj)
[2024-07-04T17:02:56.075Z] AttributeError: Can't pickle local object 'get_arun_transformations_wrapper.<locals>.arun_transformations_wrapper'

Can't pickle local object 😕

logan-markewich commented 4 days ago

lol just hit that too. Hmm

Falven commented 4 days ago

But even the approach I proposed probably wouldn't work well, because of the variety of caches. We would need to pass not only the type of your IngestionCache and its kwargs, but also the type of the IngestionCache's KVStore and its kwargs, and so on all the way down to the Redis client. This would work, but it's very messy.
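To illustrate why this gets messy, here is a sketch of a picklable "spec" that records a class and its kwargs at every level, where a kwarg value may itself be a nested spec (cache, then KV store, then client). Spec, build, and the three Fake* classes are hypothetical stand-ins, not LlamaIndex or redis-py APIs; the point is only that every layer down to the client has to be describable by plain, picklable data.

```python
import pickle
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class Spec:
    """Picklable recipe: a module-level class plus kwargs (values may be Specs)."""

    cls: type
    kwargs: Dict[str, Any] = field(default_factory=dict)


def build(spec: Any) -> Any:
    # Recursively construct nested Specs from the innermost layer outwards.
    if not isinstance(spec, Spec):
        return spec
    return spec.cls(**{name: build(value) for name, value in spec.kwargs.items()})


class FakeRedisClient:
    def __init__(self, host: str, port: int) -> None:
        self.host, self.port = host, port


class FakeKVStore:
    def __init__(self, client: FakeRedisClient) -> None:
        self.client = client


class FakeIngestionCache:
    def __init__(self, cache: FakeKVStore, collection: str) -> None:
        self.cache, self.collection = cache, collection


cache_spec = Spec(
    FakeIngestionCache,
    {
        "collection": "llama_cache",
        "cache": Spec(
            FakeKVStore,
            {"client": Spec(FakeRedisClient, {"host": "localhost", "port": 6379})},
        ),
    },
)

# The spec (not the live objects) is what would cross the process boundary.
cache = build(pickle.loads(pickle.dumps(cache_spec)))
print(type(cache.cache.client).__name__, cache.cache.client.port)  # FakeRedisClient 6379
```

Even this toy version needs every constructor argument at every level to be plain data, which is exactly what a live client object is not.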

logan-markewich commented 4 days ago

I agree. I'm pretty sure there's a hacky way to do this and get around the pickling... messing around with a few things