ucbepic / docetl

A system for agentic LLM-powered data processing and ETL
https://docetl.org
MIT License

Cache parsed data to optimize loading and sampling #33

Open shreyashankar opened 1 week ago

shreyashankar commented 1 week ago

See PR #32 (thanks @ahmedtadde for the suggestion!)

Currently, we apply parsing tools each time we load or sample data. We could cache the output of self._apply_parsing_tools() for the entire dataset and use this cached data for both loading and sampling operations.

Benefits:

- Parsing tools run once per dataset instead of on every load or sample call.
- Loading and sampling share the same parsed output.

Considerations:

- Invalidating the cache when the underlying data or parsing config changes.
- Memory footprint of the cached output (or persisting it to disk between runs).
- Whether caching should be optional for the user.
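
A rough sketch of the idea, assuming the existing load() and _apply_parsing_tools() methods; the fingerprint helper and module-level cache here are purely illustrative:

import hashlib
import json

_parsed_cache = {}

def _fingerprint(records):
    # Stable content hash of the raw records; any cheap, stable key would do.
    return hashlib.sha256(json.dumps(records, sort_keys=True, default=str).encode()).hexdigest()

def get_parsed(dataset):
    # Parse once per distinct raw dataset; both loading and sampling can read from this.
    key = _fingerprint(dataset.load())
    if key not in _parsed_cache:
        _parsed_cache[key] = dataset._apply_parsing_tools()
    return _parsed_cache[key]

sample() would then draw from get_parsed(dataset) instead of re-running the parsing tools.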

shreyashankar commented 1 week ago

Check out this file for how we apply parsing.

ahmedtadde commented 1 week ago

Hey @shreyashankar, you can assign this to me. No promises on the turnaround, though; feel free to apply a work-stealing scheduling strategy here if you or someone else decides to take this on before I get a draft PR out.

I sketched out what I believe is a working prototype. But I am also itching to set up a full benchmarking script for profiling (using scalene) to compare the current implementation with a prospective caching implementation. More generally, such a setup would help keep an eye on performance regressions as the Dataset class evolves, since it is a central piece of the machinery.
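
A minimal sketch of such a harness (the file name, dataset path, and import location are placeholders), run under scalene with python -m scalene benchmark_dataset.py:

# benchmark_dataset.py (hypothetical): compare repeated load()/sample() calls on the
# current Dataset against a caching variant. Profile with: python -m scalene benchmark_dataset.py
import time

from docetl.dataset import Dataset  # adjust to the actual module path

def bench(dataset, iters=50, sample_size=10):
    # Time repeated load + sample calls; with caching, parsing should only run once.
    start = time.perf_counter()
    for _ in range(iters):
        dataset.load()
        dataset.sample(sample_size)
    return time.perf_counter() - start

if __name__ == "__main__":
    current = Dataset("file", "data.json")   # existing implementation
    cached = Dataset("file", "data.json")    # swap in the caching prototype here
    print(f"current: {bench(current):.3f}s")
    print(f"cached:  {bench(cached):.3f}s")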

The prototype is below, but there is at least one thing I would change before submitting it: make caching entirely optional for the client/user. Right now the class assumes caching is always wanted.

import os
from typing import List, Dict, Union, Optional, Any, Callable
from concurrent.futures import ThreadPoolExecutor
import json
import csv
from io import StringIO
import random
from collections import OrderedDict
from functools import wraps

# PARSING_TOOLS and ParsingTool are assumed to be provided by docetl; adjust these
# import paths to wherever they actually live in the package.
from docetl.parsing_tools import PARSING_TOOLS
from docetl.schemas import ParsingTool

class DatasetInternalCache:
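    """Size-bounded LRU cache: evicts least-recently-used entries when either the
    item count or the (roughly estimated) total size exceeds its limits."""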
    def __init__(self, max_size: int, max_items: int):
        self.max_size = max_size
        self.max_items = max_items
        self.cache = OrderedDict()
        self.size = 0

    def __setitem__(self, key, value):
        if key in self.cache:
            self.size -= self._get_size(self.cache[key])
            del self.cache[key]
        elif len(self.cache) >= self.max_items:
            _, v = self.cache.popitem(last=False)
            self.size -= self._get_size(v)

        self.cache[key] = value
        self.size += self._get_size(value)

        while self.size > self.max_size:
            _, v = self.cache.popitem(last=False)
            self.size -= self._get_size(v)

    def __getitem__(self, key):
        value = self.cache.pop(key)
        self.cache[key] = value
        return value

    def __contains__(self, key):
        return key in self.cache

    def _get_size(self, item):
        return sum(len(str(i)) for i in item)

def dataset_internal_cache(func):
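    # Memoize a Dataset method in its bounded LRU cache, keyed by method name and arguments.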
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        key = (func.__name__,) + args + tuple(sorted(kwargs.items()))
        if key not in self._cache:
            self._cache[key] = func(self, *args, **kwargs)
        return self._cache[key]
    return wrapper

class Dataset:
    def __init__(
        self,
        type: str,
        path_or_data: Union[str, List[Dict]],
        source: str = "local",
        parsing: Optional[List[Dict[str, str]]] = None,
        user_defined_parsing_tool_map: Optional[Dict[str, ParsingTool]] = None,
        cache_size: int = 1024 * 1024 * 1024,  # 1 GiB
        cache_item_size: int = 1024 * 1024,  # 1 MiB
    ):
        self.type = self._validate_type(type)
        self.source = self._validate_source(source)
        self.path_or_data = self._validate_path_or_data(path_or_data)
        self.parsing = self._validate_parsing(parsing or [])
        self.user_defined_parsing_tool_map = user_defined_parsing_tool_map or {}
        self._cached_data = None
        self.cache_size = cache_size
        self.cache_item_size = cache_item_size
        self.max_cache_items = max(1, self.cache_size // self.cache_item_size)
        self._cache = DatasetInternalCache(self.cache_size, self.max_cache_items)

    def _validate_type(self, type: str) -> str:
        if type not in ["file", "memory"]:
            raise ValueError("Type must be 'file' or 'memory'")
        return type

    def _validate_source(self, source: str) -> str:
        if source != "local":
            raise ValueError("Source must be 'local'")
        return source

    def _validate_path_or_data(
        self, path_or_data: Union[str, List[Dict]]
    ) -> Union[str, List[Dict]]:
        if self.type == "file":
            if not isinstance(path_or_data, str):
                raise ValueError("For type 'file', path_or_data must be a string")
            if not path_or_data.lower().endswith((".json", ".csv")):
                raise ValueError("Path must end with .json or .csv")
        elif not isinstance(path_or_data, list):
            raise ValueError(
                "For type 'memory', path_or_data must be a list of dictionaries"
            )
        return path_or_data

    def _validate_parsing(self, parsing_tools: List[Dict[str, str]]) -> List[Dict[str, str]]:
        for tool in parsing_tools:
            if not all(key in tool for key in ("input_key", "function", "output_key")):
                raise ValueError(
                    "Each parsing tool must have 'input_key', 'function', and 'output_key' keys"
                )
            if not all(isinstance(tool[key], str) for key in ("input_key", "function", "output_key")):
                raise ValueError(
                    "'input_key', 'function', and 'output_key' in parsing tools must be strings"
                )
            if "function_kwargs" in tool and not isinstance(tool["function_kwargs"], dict):
                raise ValueError("'function_kwargs', if present, must be a dictionary")
        return parsing_tools

    def load(self) -> List[Dict]:
        # Cache the raw loaded data on the instance; an lru_cache-decorated method would
        # share one slot across all Dataset instances and keep them alive.
        if self._cached_data is not None:
            return self._cached_data
        if self.type == "memory":
            self._cached_data = self.path_or_data
        else:
            ext = os.path.splitext(self.path_or_data.lower())[1]
            loader = self._json_loader if ext == ".json" else self._csv_loader
            with open(self.path_or_data, "r") as f:
                self._cached_data = loader(f)
        return self._cached_data

    @staticmethod
    def _json_loader(file) -> List[Dict]:
        return json.load(file)

    @staticmethod
    def _csv_loader(file) -> List[Dict]:
        csv_data = StringIO(file.read())
        return list(csv.DictReader(csv_data))

    def _hash_data(self, data: List[Dict]) -> int:
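        # Order-insensitive content hash of the raw records, used as the memoization key for parsing.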
        return hash(tuple(sorted((k, str(v)) for item in data for k, v in item.items())))

    @dataset_internal_cache
    def _apply_parsing_tools(self, data_hash: int) -> List[Dict]:
        data = self.load()
        for tool in self.parsing:
            input_key = tool["input_key"]
            output_key = tool["output_key"]
            function_kwargs = tool.get("function_kwargs", {})

            if tool["function"] in PARSING_TOOLS:
                func = PARSING_TOOLS[tool["function"]]
            elif tool["function"] in self.user_defined_parsing_tool_map:
                func = eval(self.user_defined_parsing_tool_map[tool["function"]].function_code)
            else:
                raise ValueError(f"Parsing tool {tool['function']} not found")

            with ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(self._process_item, item, input_key, output_key, func, **function_kwargs)
                    for item in data
                ]
                # Gather results in submission order so the dataset order stays deterministic.
                data = [item for future in futures for item in future.result()]

        return data

    def _process_item(
        self,
        item: Dict[str, Any],
        input_key: str,
        output_key: str,
        func: Callable,
        **function_kwargs: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        if input_key not in item:
            raise KeyError(f"Input key {input_key} not found in item: {item}")
        result = func(item[input_key], **function_kwargs)
        return [item | {output_key: res} for res in (result if isinstance(result, list) else [result])]

    def get_processed_data(self) -> List[Dict]:
        raw_data = self.load()
        data_hash = self._hash_data(raw_data)
        return self._apply_parsing_tools(data_hash)

    def sample(self, n: int, random_sample: bool = True) -> List[Dict]:
        if n <= 0:
            raise ValueError("Sample size must be positive")

        processed_data = self.get_processed_data()

        if n > len(processed_data):
            raise ValueError(f"Sample size {n} is larger than dataset size {len(processed_data)}")

        if random_sample:
            return random.sample(processed_data, n)
        else:
            return processed_data[:n]

    def clear_cache(self):
        # Reset both the memoized parsing results and the cached raw data.
        self._cache.cache.clear()
        self._cache.size = 0
        self._cached_data = None

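For the optional-caching change mentioned above, one option (the enable_cache attribute name is hypothetical) is to gate the memoization decorator on a flag set in the constructor:

from functools import wraps

def dataset_internal_cache(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Bypass the cache entirely when the user opted out at construction time.
        if not getattr(self, "enable_cache", True):
            return func(self, *args, **kwargs)
        key = (func.__name__,) + args + tuple(sorted(kwargs.items()))
        if key not in self._cache:
            self._cache[key] = func(self, *args, **kwargs)
        return self._cache[key]
    return wrapper

Dataset.__init__ would then accept enable_cache: bool = True and store it on the instance.
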
shreyashankar commented 1 week ago

Thank you for taking this on! At a glance, I think we will want to use a disk cache so the data persists between pipeline runs. For example, we use DiskCache here.
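
A rough sketch of what that could look like with the diskcache package (the cache directory and key scheme are placeholders):

# Rough sketch with diskcache (pip install diskcache); directory and key scheme are placeholders.
from diskcache import Cache

parse_cache = Cache(".docetl_cache/parsed")  # survives across pipeline runs

def get_processed_data_persistent(dataset):
    # Reuse the prototype's content hash plus the parsing config as the cache key.
    key = (dataset._hash_data(dataset.load()), str(dataset.parsing))
    if key not in parse_cache:
        parse_cache[key] = dataset.get_processed_data()
    return parse_cache[key]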

No rush on the timeline; your contributions are much appreciated 😊