Open shreyashankar opened 1 week ago

Currently, we apply parsing tools each time we load or sample data. We could cache the output of self._apply_parsing_tools() for the entire dataset and use this cached data for both loading and sampling operations.

Benefits:
- Parsing tools run once per dataset rather than on every load or sample call.

Considerations:
- Invalidating the cache when the underlying data or parsing configuration changes.
- Whether the cache should live in memory or on disk, and whether it should be optional.

Check out this file for how we apply parsing.
Hey @shreyashankar, you can assign this to me. No promises on the turnaround, though; feel free to apply a work-stealing scheduling strategy here if you or someone else decides to take this before I get a draft PR out.
I sketched out what I believe is a working prototype. But I am also itching to set up a full-on benchmarking script for profiling (using scalene) to compare the current implementation against a prospective new one that does caching. More generally, such a setup will help keep an eye on performance regressions from future changes to the Dataset class, which is a core piece of the machinery.
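A minimal harness along those lines might look like the following. Everything here is a sketch: the file name, dataset path, import path, and iteration count are placeholders, and the Dataset class is the prototype below; the script would be run under scalene, e.g. scalene bench_dataset.py.

# bench_dataset.py -- hypothetical benchmark harness; the dataset path and
# iteration count are placeholders, not project conventions.
import time

from dataset import Dataset  # placeholder import path for the prototype below

def bench(label, fn, iterations=50):
    # Time repeated calls so cached and uncached implementations can be compared.
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    print(f"{label}: {time.perf_counter() - start:.3f}s over {iterations} iterations")

if __name__ == "__main__":
    ds = Dataset(type="file", path_or_data="data/sample.json")
    bench("load", ds.load)
    bench("get_processed_data", ds.get_processed_data)
    bench("sample(10)", lambda: ds.sample(10))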
The prototype is below, but there is at least one thing I would change before submitting it: make caching entirely optional for the client/user. Right now the class assumes caching is always wanted (a sketch of that tweak follows the prototype).
import os
from typing import List, Dict, Union, Optional, Any, Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import csv
from io import StringIO
import random
from collections import OrderedDict
from functools import wraps, lru_cache
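# NOTE: ParsingTool and PARSING_TOOLS are assumed to be importable from the
# project's parsing-tools module; those imports are omitted in this sketch.


# Size- and count-bounded LRU cache backing the memoization decorator below.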
class DatasetInternalCache:
def __init__(self, max_size: int, max_items: int):
self.max_size = max_size
self.max_items = max_items
self.cache = OrderedDict()
self.size = 0
    def __setitem__(self, key, value):
        # Replace an existing entry, or evict the least-recently-used entry
        # when the item-count limit is reached.
        if key in self.cache:
            self.size -= self._get_size(self.cache[key])
            del self.cache[key]
        elif len(self.cache) >= self.max_items:
            _, v = self.cache.popitem(last=False)
            self.size -= self._get_size(v)
        self.cache[key] = value
        self.size += self._get_size(value)
        # Keep evicting LRU entries until the total size budget is respected.
        while self.size > self.max_size:
            _, v = self.cache.popitem(last=False)
            self.size -= self._get_size(v)
    def __getitem__(self, key):
        # Pop and re-insert so the accessed entry becomes most recently used.
        value = self.cache.pop(key)
        self.cache[key] = value
        return value
def __contains__(self, key):
return key in self.cache
    def _get_size(self, item):
        # Rough size estimate: sum of the string lengths of the value's elements.
        return sum(len(str(i)) for i in item)
def dataset_internal_cache(func):
    # Memoize a method's result in the instance's bounded LRU cache, keyed on
    # the method name plus its (hashable) positional and keyword arguments.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        key = (func.__name__,) + args + tuple(sorted(kwargs.items()))
        if key not in self._cache:
            self._cache[key] = func(self, *args, **kwargs)
        return self._cache[key]
    return wrapper
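
# Dataset wraps file- or memory-backed records, validates its configuration,
# and memoizes the output of the parsing pipeline.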
class Dataset:
def __init__(
self,
type: str,
path_or_data: Union[str, List[Dict]],
source: str = "local",
parsing: Optional[List[Dict[str, str]]] = None,
user_defined_parsing_tool_map: Optional[Dict[str, ParsingTool]] = None,
cache_size: int = 1024 * 1024 * 1024, # 1 GiB
cache_item_size: int = 1024 * 1024, # 1 MiB
):
self.type = self._validate_type(type)
self.source = self._validate_source(source)
self.path_or_data = self._validate_path_or_data(path_or_data)
self.parsing = self._validate_parsing(parsing or [])
self.user_defined_parsing_tool_map = user_defined_parsing_tool_map or {}
        self.cache_size = cache_size
        self.cache_item_size = cache_item_size
        # Derive the item budget from the byte budget and the per-item cap.
        self.max_cache_items = max(1, self.cache_size // self.cache_item_size)
        self._cache = DatasetInternalCache(self.cache_size, self.max_cache_items)
def _validate_type(self, type: str) -> str:
if type not in ["file", "memory"]:
raise ValueError("Type must be 'file' or 'memory'")
return type
def _validate_source(self, source: str) -> str:
if source != "local":
raise ValueError("Source must be 'local'")
return source
def _validate_path_or_data(
self, path_or_data: Union[str, List[Dict]]
) -> Union[str, List[Dict]]:
if self.type == "file":
if not isinstance(path_or_data, str):
raise ValueError("For type 'file', path_or_data must be a string")
if not path_or_data.lower().endswith((".json", ".csv")):
raise ValueError("Path must end with .json or .csv")
elif not isinstance(path_or_data, list):
raise ValueError(
"For type 'memory', path_or_data must be a list of dictionaries"
)
return path_or_data
def _validate_parsing(self, parsing_tools: List[Dict[str, str]]) -> List[Dict[str, str]]:
for tool in parsing_tools:
if not all(key in tool for key in ("input_key", "function", "output_key")):
raise ValueError(
"Each parsing tool must have 'input_key', 'function', and 'output_key' keys"
)
if not all(isinstance(tool[key], str) for key in ("input_key", "function", "output_key")):
raise ValueError(
"'input_key', 'function', and 'output_key' in parsing tools must be strings"
)
if "function_kwargs" in tool and not isinstance(tool["function_kwargs"], dict):
raise ValueError("'function_kwargs', if present, must be a dictionary")
return parsing_tools
    # Caveat: lru_cache on an instance method keys on self, and with maxsize=1
    # a load() on a second Dataset instance evicts the first instance's entry.
    @lru_cache(maxsize=1)
def load(self) -> List[Dict]:
if self.type == "memory":
return self.path_or_data
else:
ext = os.path.splitext(self.path_or_data.lower())[1]
loader = self._json_loader if ext == ".json" else self._csv_loader
with open(self.path_or_data, "r") as f:
return loader(f)
@staticmethod
def _json_loader(file) -> List[Dict]:
return json.load(file)
@staticmethod
def _csv_loader(file) -> List[Dict]:
csv_data = StringIO(file.read())
return list(csv.DictReader(csv_data))
    def _hash_data(self, data: List[Dict]) -> int:
        # Order-insensitive content hash of the dataset, used in the cache key.
        return hash(tuple(sorted((k, str(v)) for item in data for k, v in item.items())))

    @dataset_internal_cache
    def _apply_parsing_tools(self, data_hash: int) -> List[Dict]:
        # data_hash is unused in the body; it is part of the memoization key so
        # cached results are tied to the underlying data's content.
        data = self.load()
for tool in self.parsing:
input_key = tool["input_key"]
output_key = tool["output_key"]
function_kwargs = tool.get("function_kwargs", {})
if tool["function"] in PARSING_TOOLS:
func = PARSING_TOOLS[tool["function"]]
elif tool["function"] in self.user_defined_parsing_tool_map:
func = eval(self.user_defined_parsing_tool_map[tool["function"]].function_code)
else:
raise ValueError(f"Parsing tool {tool['function']} not found")
            with ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(
                        self._process_item, item, input_key, output_key, func, **function_kwargs
                    )
                    for item in data
                ]
                # Collect results in submission order rather than completion
                # order, so parsing preserves the dataset's row order.
                data = [item for future in futures for item in future.result()]
return data
def _process_item(
self,
item: Dict[str, Any],
input_key: str,
output_key: str,
func: Callable,
        **function_kwargs: Any,
) -> List[Dict[str, Any]]:
if input_key not in item:
raise KeyError(f"Input key {input_key} not found in item: {item}")
result = func(item[input_key], **function_kwargs)
return [item | {output_key: res} for res in (result if isinstance(result, list) else [result])]
def get_processed_data(self) -> List[Dict]:
raw_data = self.load()
data_hash = self._hash_data(raw_data)
return self._apply_parsing_tools(data_hash)
def sample(self, n: int, random_sample: bool = True) -> List[Dict]:
if n <= 0:
raise ValueError("Sample size must be positive")
processed_data = self.get_processed_data()
if n > len(processed_data):
raise ValueError(f"Sample size {n} is larger than dataset size {len(processed_data)}")
if random_sample:
return random.sample(processed_data, n)
else:
return processed_data[:n]
def clear_cache(self):
self._cache.cache.clear()
self._cache.size = 0
self.load.cache_clear()
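On the optional-caching point above: a minimal sketch would be to accept a flag in Dataset.__init__ (say enable_cache: bool = True, a hypothetical name), store it on the instance, and have the decorator bypass the cache when it is off:

# Hypothetical enable_cache flag; Dataset.__init__ would accept and store
# enable_cache: bool = True as self.enable_cache.
def dataset_internal_cache(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not getattr(self, "enable_cache", True):
            # Caching opted out: call straight through.
            return func(self, *args, **kwargs)
        key = (func.__name__,) + args + tuple(sorted(kwargs.items()))
        if key not in self._cache:
            self._cache[key] = func(self, *args, **kwargs)
        return self._cache[key]
    return wrapper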
thank you for taking this on! at a glance, i think we will want to use a disk cache so data persists between pipeline runs; for example, we use DiskCache here.
no rush on the timeline; your contributions are much appreciated 😊
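A disk-backed variant of the memoizing decorator, sketched with the diskcache library (the cache directory is a placeholder, and diskcache would become a dependency):

# Sketch: persist parsed results across pipeline runs with diskcache.
from functools import wraps

from diskcache import Cache

def dataset_disk_cache(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        key = (func.__name__,) + args + tuple(sorted(kwargs.items()))
        # Open the on-disk cache per call; entries survive process restarts.
        with Cache(getattr(self, "cache_dir", ".dataset_cache")) as cache:
            if key in cache:
                return cache[key]
            result = func(self, *args, **kwargs)
            cache[key] = result
            return result
    return wrapper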
See PR #32 (thanks @ahmedtadde for the suggestion!)