Querent-ai / querent-python-research

Querent
https://querent.xyz
Other
2 stars 0 forks source link

Scaling out file level parallelism #251

Open Ansh5461 opened 7 months ago

saraswatpuneet commented 7 months ago
import asyncio
from asyncio import Queue, Semaphore
from typing import AsyncGenerator, List, Optional
from cachetools import LRUCache, cachedmethod
from querent.channel.channel_interface import ChannelCommandInterface
from querent.collectors.collector_base import Collector
from querent.common.types.collected_bytes import CollectedBytes
from querent.common.types.ingested_tokens import IngestedTokens
from querent.config.ingestor.ingestor_config import IngestorBackend
from querent.ingestors.base_ingestor import BaseIngestor
from querent.ingestors.factory_map import ingestor_factory_map
from querent.logging.logger import setup_logger

class IngestorFactoryManager:
    """
    Manages the factories for various types of ingestors, directing collected files
    to the appropriate ingestor based on file type. Supports asynchronous ingestion
    of files from multiple collectors with file-level parallelism.

    Collected chunks for a file are expected to be buffered in ``file_caches`` as a
    list keyed by file ID (presumably filled by ``ingest_collector_async``, which is
    not part of this snippet -- NOTE(review): confirm) and are drained by
    ``ingest_file_async``.
    """

    def __init__(
        self,
        collectors: Optional[List["Collector"]] = None,
        processors: Optional[List["AsyncProcessor"]] = None,
        cache_size: int = 100,
        result_queue: Optional[Queue] = None,
        tokens_feader: Optional["ChannelCommandInterface"] = None,
        max_concurrent_files: int = 10,
    ):
        """
        Args:
            collectors: Collectors whose output is ingested by ``ingest_all_async``.
            processors: Processors stored on the manager. NOTE(review): no method
                visible here uses them; confirm they should reach the ingestors.
            cache_size: Maximum number of entries kept in ``file_caches`` (an LRU
                cache, so the least-recently-used entries are evicted beyond it).
            result_queue: Optional queue receiving every ``IngestedTokens`` and, at
                the end of ``ingest_all_async``, a ``None`` sentinel.
            tokens_feader: Optional channel that also receives every
                ``IngestedTokens``.
            max_concurrent_files: Maximum number of files processed at once.

        Raises:
            ValueError: If ``max_concurrent_files`` is less than 1 (a zero-slot
                semaphore would block every file forever).
        """
        if max_concurrent_files < 1:
            raise ValueError(
                f"max_concurrent_files must be >= 1, got {max_concurrent_files}"
            )
        self.collectors = collectors if collectors else []
        self.processors = processors if processors else []
        self.ingestor_factories = ingestor_factory_map
        # NOTE(review): this LRU cache doubles as the buffer of pending chunks
        # (ingest_file_async pops from it); eviction beyond cache_size silently
        # drops a file's chunks, so size it for the expected number of open files.
        self.file_caches = LRUCache(maxsize=cache_size)
        self.result_queue = result_queue
        self.tokens_feader = tokens_feader
        self.concurrent_file_semaphore = Semaphore(max_concurrent_files)
        self.logger = setup_logger(__name__, "IngestorFactoryManager")

    async def get_factory(
        self, file_extension: Optional[str]
    ) -> Optional["BaseIngestor"]:
        """
        Retrieve the appropriate ingestor factory for the given file extension.

        Args:
            file_extension (Optional[str]): The file extension to find a factory
                for; the lookup is case-insensitive.

        Returns:
            Optional[BaseIngestor]: The ingestor factory if found, None otherwise
            (including when the extension is missing or empty).
        """
        # A chunk without an extension must not crash with AttributeError on None.
        if not file_extension:
            return None
        # NOTE(review): ingest_file_async uses the returned object directly as an
        # ingestor (calls .ingest on it); confirm ingestor_factory_map values are
        # ingestors rather than factories that still need creating.
        return self.ingestor_factories.get(file_extension.lower())

    # No @cachedmethod here: caching an async method caches the coroutine object,
    # so a repeat call returned an already-awaited coroutine (RuntimeError), and it
    # shared file_caches with the chunk buffer. The pop below already makes repeat
    # calls for the same file a no-op.
    async def ingest_file_async(self, file_id: str):
        """
        Asynchronously ingest a file given its ID, utilizing the appropriate ingestor
        based on file extension. Limits the concurrency to prevent resource exhaustion.

        The file's buffered chunks are removed from ``file_caches`` before waiting for
        a concurrency slot, so the LRU cache cannot evict them while this call is
        queued behind other files. Unknown or already-ingested IDs are a no-op.

        Args:
            file_id (str): The unique identifier of the file to be ingested.
        """
        # Take ownership of the chunks immediately; nothing to do if none are buffered.
        collected_bytes_list = self.file_caches.pop(file_id, [])
        if not collected_bytes_list:
            return

        async with self.concurrent_file_semaphore:
            # All chunks of a file are assumed to share the first chunk's extension.
            file_extension = collected_bytes_list[0].extension
            ingestor = await self.get_factory(file_extension)
            if not ingestor:
                self.logger.warning(f"Unsupported file extension: {file_extension}")
                return

            async def chunk_generator() -> AsyncGenerator["CollectedBytes", None]:
                for chunk in collected_bytes_list:
                    yield chunk

            await self.process_file(ingestor, chunk_generator(), file_id)

    async def process_file(
        self,
        ingestor: "BaseIngestor",
        chunk_generator: AsyncGenerator["CollectedBytes", None],
        file_id: str,
    ):
        """
        Run one file's chunks through ``ingestor.ingest`` and forward every result.

        Any exception raised while ingesting or forwarding is logged and swallowed,
        so a failing file does not abort the caller.

        Args:
            ingestor: Object whose ``ingest`` method consumes the chunk stream and
                asynchronously yields token batches.
            chunk_generator: Async generator of the file's collected chunks.
            file_id: Identifier of the file, used in the error log.
        """
        try:
            async for chunk_tokens in ingestor.ingest(chunk_generator):
                await self.handle_ingested_tokens(chunk_tokens)
        except Exception as e:
            self.logger.error(f"Error processing file {file_id}: {e}")

    async def handle_ingested_tokens(self, chunk_tokens: "IngestedTokens"):
        """
        Forward one batch of tokens to whichever sinks are configured.

        Args:
            chunk_tokens: The token batch to forward to ``result_queue`` (awaited
                put) and/or ``tokens_feader`` (synchronous ``send_tokens_in_rust``).
        """
        if self.result_queue is not None:
            await self.result_queue.put(chunk_tokens)
        if self.tokens_feader is not None:
            self.tokens_feader.send_tokens_in_rust(chunk_tokens)

    async def ingest_all_async(self):
        """
        Asynchronously ingest files from all registered collectors, with controlled
        concurrency for file processing.

        The end-of-ingestion sentinel is always delivered. If any collector task
        fails, or this coroutine is cancelled, the remaining collector tasks are
        cancelled and awaited first, so nothing is enqueued after the sentinel. The
        original exception is then re-raised.

        NOTE(review): relies on ``self.ingest_collector_async(collector)``, which is
        not defined in this snippet.
        """
        tasks = [
            asyncio.create_task(self.ingest_collector_async(collector))
            for collector in self.collectors
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            # Stop any collector still running so it cannot enqueue after the
            # sentinel; gathering with return_exceptions retrieves every outcome
            # without masking the exception already propagating from above.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            await self.finalize_ingestion()

    async def finalize_ingestion(self):
        """Put a ``None`` sentinel on the result queue (if any) to signal the end."""
        if self.result_queue is not None:
            await self.result_queue.put(None)  # Signal the end of ingestion

# NOTE: ingest_collector_async, which ingest_all_async calls, is not defined in this
# snippet. It and any other additional methods should follow the error-handling,
# asynchronous-operation, and resource-management patterns shown above.