import asyncio
from asyncio import Queue, Semaphore
from typing import AsyncGenerator, List, Optional
from cachetools import LRUCache, cachedmethod
from querent.channel.channel_interface import ChannelCommandInterface
from querent.collectors.collector_base import Collector
from querent.common.types.collected_bytes import CollectedBytes
from querent.common.types.ingested_tokens import IngestedTokens
from querent.config.ingestor.ingestor_config import IngestorBackend
from querent.ingestors.base_ingestor import BaseIngestor
from querent.ingestors.factory_map import ingestor_factory_map
from querent.logging.logger import setup_logger
class IngestorFactoryManager:
    """
    Manages the factories for various types of ingestors, directing collected
    files to the appropriate ingestor based on file extension. Supports
    asynchronous ingestion of files from multiple collectors with bounded
    file-level parallelism.

    NOTE(review): ``ingest_all_async`` calls ``self.ingest_collector_async(...)``,
    which is not defined in this section of the file. It must exist elsewhere (or
    be added) and is presumably responsible for staging ``CollectedBytes`` lists
    into ``self.file_caches`` keyed by file id — confirm against the full file.
    """

    def __init__(
        self,
        collectors: Optional[List[Collector]] = None,
        # "AsyncProcessor" is quoted deliberately: no import for it is visible
        # in this file, so an unquoted annotation would raise NameError at
        # class-definition time. Quoting defers evaluation entirely.
        processors: Optional[List["AsyncProcessor"]] = None,
        cache_size: int = 100,
        result_queue: Optional[Queue] = None,
        tokens_feader: Optional[ChannelCommandInterface] = None,
        max_concurrent_files: int = 10,
    ):
        """
        Args:
            collectors: Collectors whose files will be ingested. Defaults to none.
            processors: Processors applied downstream of ingestion. Defaults to none.
            cache_size: Maximum number of entries held in the per-file chunk cache.
            result_queue: Optional queue receiving IngestedTokens; a final ``None``
                sentinel is put when ingestion finishes (see ``finalize_ingestion``).
            tokens_feader: Optional channel used to forward tokens to the Rust side.
                (Parameter name kept as-is, typo included, for caller compatibility.)
            max_concurrent_files: Upper bound on files ingested concurrently.
        """
        self.collectors = collectors if collectors else []
        self.processors = processors if processors else []
        self.ingestor_factories = ingestor_factory_map
        # LRU store of collected chunk lists keyed by file id; bounded so a slow
        # consumer cannot grow memory without limit.
        self.file_caches = LRUCache(maxsize=cache_size)
        self.result_queue = result_queue
        self.tokens_feader = tokens_feader
        # Caps concurrent file ingestion to prevent resource exhaustion.
        self.concurrent_file_semaphore = Semaphore(max_concurrent_files)
        self.logger = setup_logger(__name__, "IngestorFactoryManager")

    async def get_factory(self, file_extension: str) -> Optional[BaseIngestor]:
        """
        Retrieve the appropriate ingestor factory for the given file extension.

        Args:
            file_extension (str): The file extension to find a factory for.
                Matching is case-insensitive.

        Returns:
            Optional[BaseIngestor]: The ingestor factory if found, None otherwise.
        """
        return self.ingestor_factories.get(file_extension.lower())

    # BUGFIX: removed ``@cachedmethod(cache=lambda self: self.file_caches)``.
    # cachetools' cachedmethod is not coroutine-aware: it would cache the
    # *coroutine object* returned by calling this method, and awaiting that
    # cached object a second time raises RuntimeError ("cannot reuse already
    # awaited coroutine"). It also used self.file_caches — the same cache this
    # method pops chunk lists from — as the memo store, so memo entries and
    # chunk lists would collide under the same keys. Re-invocation for an
    # already-ingested file_id is already a no-op via the empty-list guard.
    async def ingest_file_async(self, file_id: str):
        """
        Asynchronously ingest a file given its ID, utilizing the appropriate
        ingestor based on file extension. Concurrency is limited by the
        instance-wide semaphore to prevent resource exhaustion.

        Args:
            file_id (str): The unique identifier of the file to be ingested.
                Chunks for the file are expected to have been staged in
                ``self.file_caches`` under this key.
        """
        async with self.concurrent_file_semaphore:
            # pop() both retrieves and removes the staged chunks, making a
            # second call for the same file_id a harmless no-op.
            collected_bytes_list = self.file_caches.pop(file_id, [])
            if not collected_bytes_list:
                return
            file_extension = collected_bytes_list[0].extension
            ingestor = await self.get_factory(file_extension)
            if not ingestor:
                self.logger.warning(f"Unsupported file extension: {file_extension}")
                return

            async def chunk_generator() -> AsyncGenerator[CollectedBytes, None]:
                # Re-expose the buffered chunks as the async stream the
                # ingestor interface expects.
                for chunk in collected_bytes_list:
                    yield chunk

            await self.process_file(ingestor, chunk_generator(), file_id)

    async def process_file(self, ingestor: BaseIngestor, chunk_generator: AsyncGenerator[CollectedBytes, None], file_id: str):
        """
        Drive the ingestor over the file's chunk stream, forwarding each batch
        of tokens. Errors are logged rather than raised so one bad file does
        not abort the overall ingestion run (deliberate best-effort).
        """
        try:
            async for chunk_tokens in ingestor.ingest(chunk_generator):
                await self.handle_ingested_tokens(chunk_tokens)
        except Exception as e:
            self.logger.error(f"Error processing file {file_id}: {e}")

    async def handle_ingested_tokens(self, chunk_tokens: IngestedTokens):
        """
        Fan out one batch of ingested tokens to the configured sinks: the
        asyncio result queue and/or the Rust token channel (both optional).
        """
        if self.result_queue is not None:
            await self.result_queue.put(chunk_tokens)
        if self.tokens_feader is not None:
            # NOTE(review): synchronous call on the channel interface — assumed
            # non-blocking; confirm against ChannelCommandInterface.
            self.tokens_feader.send_tokens_in_rust(chunk_tokens)

    async def ingest_all_async(self):
        """
        Asynchronously ingest files from all registered collectors, with
        controlled concurrency for file processing. Signals completion to the
        result queue when every collector task has finished.
        """
        tasks = [asyncio.create_task(self.ingest_collector_async(collector)) for collector in self.collectors]
        await asyncio.gather(*tasks)
        await self.finalize_ingestion()

    async def finalize_ingestion(self):
        """Put the ``None`` end-of-ingestion sentinel on the result queue, if any."""
        if self.result_queue is not None:
            await self.result_queue.put(None)  # Signal the end of ingestion
# NOTE: ingest_collector_async is referenced by ingest_all_async above but is not
# defined in this section; it (and any other companion methods) should stage each
# collector's CollectedBytes into self.file_caches keyed by file id and follow the
# established patterns of error handling, asynchronous operation, and resource
# management shown above.