NickCrews / mismo

The SQL/Ibis powered sklearn of record linkage
https://nickcrews.github.io/mismo/
GNU Lesser General Public License v3.0

FEAT: Implement Pipelines? #18

Closed: OlivierBinette closed this issue 6 months ago

OlivierBinette commented 6 months ago

Should there be the equivalent of sklearn pipelines?

To be honest, I'm not a huge fan of how sklearn pipelines are implemented. They're hard to use, and you can't have one step refer to previous steps in the pipeline. Here's how I implemented a basic pipeline prototype for a previous project:

We have steps, i.e. functions with a specified set of dependencies (here named `inputs`). There are two ways to execute steps:

  1. As a Callable
  2. Using the `compute` function and a cache dictionary. In this case, step dependencies are computed recursively, using the relevant contents of the cache as inputs.

Note: currently, step names in any dependency graph should be unique. Ideally we'd create UUIDs for steps.

class Step(Protocol):
    name: str
    inputs: list["Step"]

    def __call__(self, *args, **kwargs):
        ...

    def compute(self, cache: dict[str, Any]):
        ...

And we have Pipelines, i.e. series of steps. Steps that are explicitly listed in the pipeline are fed to callback functions, e.g. for persisting results or logging progress:

from abc import ABC, abstractmethod

class Pipeline(ABC):
    name: str
    description: str
    version: str

    def __init__(self):
        self.callbacks = []

    @property
    @abstractmethod
    def steps(self) -> list[Step]:
        ...

    def run(self, **inputs):
        # Seed the cache with the user-supplied inputs; compute() fills in
        # intermediate results as it walks the dependency graph.
        cache = inputs
        for step in self.steps:
            output = step.compute(cache)
            for callback in self.callbacks:
                callback(output)
        return cache

Example

Here are some steps:


@step()
def x() -> int:  # This is a non-implemented step, i.e. "x" needs to be explicitly given as an input.
    ...

@step(inputs=[x])
def plus_one(x: int) -> int:
    return x+1

@step(inputs=[plus_one])
def times_two(y: int) -> int:
    return 2*y

You can use step functions directly:

times_two(1)  # Calls the function directly, bypassing the dependency graph
# 2

Or you can compute steps using the dependency graph:

times_two.compute({"x": 1})  # Computes 2*(x+1)
# 4
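
To tie this together, a Pipeline subclass could wire up the steps above. This is a hypothetical sketch (the class name, metadata values, and print callback are invented for illustration):

class PlusTimesPipeline(Pipeline):
    name = "plus-times"
    description = "Toy pipeline chaining plus_one and times_two"
    version = "0.1.0"

    @property
    def steps(self) -> list[Step]:
        return [plus_one, times_two]

pipeline = PlusTimesPipeline()
pipeline.callbacks.append(lambda output: print("step output:", output))
pipeline.run(x=1)
# step output: 2
# step output: 4
# Returns the cache: {"x": 1, "plus_one": 2, "times_two": 4}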

Full implementation:

from functools import update_wrapper
from typing import Any, Protocol

class Step(Protocol):
    name: str
    inputs: list["Step"]

    def __call__(self, *args, **kwargs):
        ...

    def compute(self, cache: dict[str, Any]):
        ...

def step(inputs: list[Step] | None = None):
    """Decorator for functions that are steps in a pipeline.

    Parameters
    ----------
    inputs : list[Step]
        List of steps that are inputs to this step. The decorated function should expect the output of these steps as arguments.

    Returns
    -------
    Step
        Decorated function, a callable Step instance.
    """

    def decorator(func) -> Step:
        return _Step(func, inputs)

    return decorator

def stepmethod(inputs: list[Step] | None = None):
    """Decorator for instance methods that are steps in a pipeline.

    This should be used instead of `step` when working with instance methods. `stepmethod` ensures that input steps are bound to the calling instance prior to computation.

    Parameters
    ----------
    inputs : list[Step]
        List of steps that are inputs to this step. The decorated instance method should expect the output of these steps as arguments.

    Returns
    -------
    Step
        Decorated instance method, a callable Step instance.
    """

    def decorator(func) -> Step:
        return _StepMethod(func, inputs)

    return decorator

class _Step(Step):
    def __init__(self, func, inputs: list["Step"]):
        self.name = func.__name__
        self.func = func
        update_wrapper(self, func)
        self.inputs = inputs or []

    def __call__(self, *args, **kwargs):
        # Direct call: bypass the dependency graph entirely.
        return self.func(*args, **kwargs)

    def compute(self, cache: dict[str, Any]):
        # Memoized, recursive evaluation of the dependency graph.
        if self.name in cache:
            return cache[self.name]

        names = [step.name for step in self.inputs]
        for i, name in enumerate(names):
            if name not in cache:
                cache[name] = self.inputs[i].compute(cache)

        args = [cache[step.name] for step in self.inputs]
        cache[self.name] = self.func(*args)

        return cache[self.name]

class _StepMethod(_Step):
    def __set_name__(self, owner, name):
        # Use the attribute name on the owning class as the step name.
        self.name = name

    def __get__(self, obj, cls=None):
        # On attribute access, bind this step's input steps and its own
        # function to the instance, so compute() calls bound methods.
        for step in self.inputs:
            step.__get__(obj, cls)
        self.func = self.func.__get__(obj, cls)
        return self
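
For illustration, here's how stepmethod might be used. This sketch (the class and step names are invented) shows input steps being bound to the instance before computation:

class Scaler:
    def __init__(self, factor: int):
        self.factor = factor

    @stepmethod()
    def base(self) -> int:  # Non-implemented step: "base" must be given as an input.
        ...

    @stepmethod(inputs=[base])
    def scaled(self, base: int) -> int:
        return self.factor * base

Scaler(3).scaled.compute({"base": 2})
# 6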
OlivierBinette commented 6 months ago

Note: this definition of a pipeline is for a series of data transformation steps. This is different from sklearn, where pipelines are used to compose estimators.

Also, a robust implementation of my idea would require checking that the dependency graph is a DAG, perhaps computing a topological ordering, and better managing step names and their uniqueness. It's not necessarily too hard to do, especially using Python's built-in graphlib module, but it's something to consider in terms of the complexity of the solution.
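
For reference, here's a sketch of that validation using the standard library's graphlib; the helper name topological_order and the duplicate-name check are my own additions:

from graphlib import CycleError, TopologicalSorter

def topological_order(steps: list[Step]) -> list[Step]:
    # Collect the transitive dependency graph as {step_name: input_names}.
    by_name: dict[str, Step] = {}
    graph: dict[str, set[str]] = {}

    def visit(step: Step):
        if step.name in by_name:
            if by_name[step.name] is not step:
                raise ValueError(f"Duplicate step name: {step.name}")
            return
        by_name[step.name] = step
        graph[step.name] = {inp.name for inp in step.inputs}
        for inp in step.inputs:
            visit(inp)

    for s in steps:
        visit(s)
    try:
        # static_order() raises CycleError if the graph is not a DAG.
        return [by_name[name] for name in TopologicalSorter(graph).static_order()]
    except CycleError as e:
        raise ValueError("Step dependency graph is not a DAG") from e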

NickCrews commented 6 months ago

This should happen, but I don't think it should be in mismo. Task orchestration is general enough that it should go in its own module. But mismo should be designed in a way that makes it easy to use FROM this orchestration framework.

FYI, I have my own implementation of something similar that wraps pytask, take a look:

```python
from __future__ import annotations

from collections.abc import Mapping
from functools import cached_property
import logging
from pathlib import Path
from typing import Any, Hashable, Iterable, Protocol, TypeVar

import fire
import ibis
from ibis.expr.types import Table
from noatak import io
import pytask

from kobuk import data_folder

logger = logging.getLogger(__name__)


class PInputs(Protocol):
    """The specification for the inputs to a task function.

    This is used to convert input data from a name:value mapping
    to the args and kwargs that the task function expects.
    """

    @property
    def names(self) -> frozenset[str]:
        """The names of all inputs"""
        raise NotImplementedError

    def convert_input(
        self, input: Mapping[str, Any]
    ) -> tuple[tuple[Any], dict[str, Any]]:
        """Convert a k:v mapping to args and kwargs suitable for passing to a function"""
        raise NotImplementedError

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({set(sorted(self.names))})"


class Inputs(PInputs):
    def __init__(self, args: Iterable[str], kwargs: Mapping[str, str]):
        self.args = tuple(args)
        self.kwargs = dict(kwargs)  # TODO make this immutable
        # map of arg_name: input_name
        # arg_name must be the key because one input might be used for multiple args

    @cached_property
    def names(self) -> frozenset[str]:
        return frozenset(self.args) | frozenset(self.kwargs.values())

    def convert_input(
        self, input: Mapping[str, Any]
    ) -> tuple[tuple[Any], dict[str, Any]]:
        args = tuple(input[arg] for arg in self.args)
        kwargs = {arg: input[name] for arg, name in self.kwargs.items()}
        return args, kwargs

    @classmethod
    def make(
        cls,
        inputs: str | Iterable[str] | Mapping[str, str] | None,
    ):
        if inputs is None:
            args = tuple()
            kwargs = {}
            return cls(args, kwargs)
        elif isinstance(inputs, str):
            args = (inputs,)
            kwargs = {}
            return cls(args, kwargs)
        try:
            kwargs = dict(inputs.items())
            args = tuple()
            return cls(args, kwargs)
        except AttributeError:
            args = tuple(inputs)
            kwargs = {}
            return cls(args, kwargs)


class POutputs(Protocol):
    """The specification for the outputs of a task function."""

    @property
    def names(self) -> frozenset[str]:
        """The names of all outputs"""
        return self._names

    def convert_output(self, output: Any) -> dict[str, Any]:
        """Convert the output of a task function to a name:value mapping."""
        raise NotImplementedError

    @classmethod
    def make(cls, names: str | Iterable[str] | Mapping[str, Hashable] | None):
        if names is None:
            return NoneOutputs()
        elif isinstance(names, str):
            return ScalarOutputs(names)
        try:
            kwouts = dict(names.items())
            return MappingOutputs(kwouts)
        except AttributeError:
            return IterableOutputs(names)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({set(sorted(self.names))})"


class NoneOutputs(POutputs):
    _names = frozenset()

    def convert_output(self, output: Any) -> dict[str, Any]:
        return dict()


class ScalarOutputs(POutputs):
    def __init__(self, name: str):
        self._name = name
        self._names = frozenset((name,))

    def convert_output(self, output: Any) -> dict[str, Any]:
        return {self._name: output}


class IterableOutputs(POutputs):
    def __init__(self, names: Iterable[str]):
        self._names = tuple(names)

    def convert_output(self, output: Any) -> dict[str, Any]:
        return dict(zip(self._names, output))


class MappingOutputs(POutputs):
    def __init__(self, names: Mapping[str, str]):
        self._kwargs = dict(names)
        self._names = frozenset(names.keys())

    def convert_output(self, output: Any) -> dict[str, Any]:
        return {name: output[raw_name] for name, raw_name in self._kwargs.items()}


class PLoader(Protocol):
    @property
    def names(self) -> frozenset[str]:
        """The names of the data that can be loaded."""
        raise NotImplementedError

    def path(self, name: str) -> Path:
        raise NotImplementedError

    def load(self, name: str) -> Any:
        raise NotImplementedError


class PSaver(Protocol):
    @property
    def names(self) -> frozenset[str]:
        """The names of the data that can be saved."""
        raise NotImplementedError

    def path(self, name: str) -> Path:
        raise NotImplementedError

    def save(self, name: str, data: Any) -> None:
        raise NotImplementedError


_LoaderOrSaverT = TypeVar("_LoaderOrSaverT", PLoader, PSaver)


class DataStore:
    def __init__(
        self,
        loaders: Iterable[PLoader] = [],
        savers: Iterable[PSaver] = [],
        data: dict[str, Any] | None = None,
    ):
        self._loaders = frozenset(loaders)
        self._savers = frozenset(savers)
        self.datas = dict(data) if data is not None else {}
        try:
            self._loader_map = self.build_map(loaders)
        except ValueError as e:
            raise ValueError("Duplicate loader names") from e
        try:
            self._saver_map = self.build_map(savers)
        except ValueError as e:
            raise ValueError("Duplicate saver names") from e

    def evolve(self, loaders=None, savers=None, data=None):
        return DataStore(
            loaders=loaders or self._loaders,
            savers=savers or self._savers,
            data=data or self.datas,
        )

    def get_loader(self, name: str) -> PLoader:
        try:
            return self._loader_map[name]
        except KeyError as e:
            raise ValueError(f"No loader for {name}") from e

    def get_saver(self, name: str) -> PSaver:
        try:
            return self._saver_map[name]
        except KeyError as e:
            raise ValueError(f"No saver for {name}") from e

    def load(self, name: str) -> Any:
        if name in self.datas:
            return self.datas[name]
        result = self.get_loader(name).load(name)
        self.datas[name] = result
        return result

    def save(self, name: str, data: Any) -> None:
        saver = self.get_saver(name)
        saver.save(name, data)
        self.datas[name] = data

    @staticmethod
    def build_map(ls: Iterable[_LoaderOrSaverT]) -> dict[str, _LoaderOrSaverT]:
        m = {}
        duplicates = set()
        for loader in ls:
            for name in loader.names:
                if name in m:
                    duplicates.add(name)
                else:
                    m[name] = loader
        if duplicates:
            raise ValueError(duplicates)
        return m


class Task:
    def __init__(
        self,
        func,
        *,
        inputs: PInputs,
        outputs: POutputs,
        data_store: DataStore,
        name: str | None = None,
    ):
        self.func = func
        self.inputs = inputs
        self.outputs = outputs
        self.data_store = data_store
        self.name = name or func.__name__

    def load(self) -> tuple[tuple[Any], dict[str, Any]]:
        logger.info(f"Reading {self.inputs}")
        input = {name: self.data_store.load(name) for name in self.inputs.names}
        return self.inputs.convert_input(input)

    def save(self, raw_output: Any):
        output = self.outputs.convert_output(raw_output)
        logger.info(f"Writing {self.outputs}")
        for name, data in output.items():
            self.data_store.save(name, data)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def cli(self, verbose: bool = False):
        level = logging.DEBUG if verbose else logging.INFO
        logging.basicConfig(level=level, force=True)
        self.run()

    def fire(self):
        fire.Fire(self.cli)

    def run(self):
        logger.info(f"Running {self.func.__name__}")
        args, kwargs = self.load()
        result = self.func(*args, **kwargs)
        self.save(result)

    def to_pytask(self):
        """Convert into a pytask task function."""
        input_paths = frozenset(
            self.data_store.get_loader(name).path(name) for name in self.inputs.names
        )
        output_paths = frozenset(
            self.data_store.get_saver(name).path(name) for name in self.outputs.names
        )

        @pytask.mark.depends_on(input_paths)
        @pytask.mark.produces(output_paths)
        def my_task(depends_on, produces):
            self.cli()

        return my_task


class _TableLoaderSaverBase:
    def __init__(self, names: Iterable[str], path_getter):
        self.names = frozenset(names)
        self.path_getter = path_getter

    def path(self, name: str) -> Path:
        return self.path_getter(name)


class TableLoader(_TableLoaderSaverBase):
    def load(self, name: str) -> Table:
        return ibis.read_parquet(self.path(name))


class TableSaver(_TableLoaderSaverBase):
    def save(self, name: str, data: Table) -> None:
        io.to_parquet(data, self.path(name))


def make_pytask(func, ins=None, outs=None, *, prefix: str | None = None):
    def path_getter(name):
        if prefix is None:
            return data_folder() / f"{name}.parquet"
        else:
            return data_folder() / prefix / f"{name}.parquet"

    i = Inputs.make(ins)
    o = POutputs.make(outs)
    loader = TableLoader(i.names, path_getter)
    saver = TableSaver(o.names, path_getter)
    data_store = DataStore(loaders=[loader], savers=[saver])
    return Task(func, inputs=i, outputs=o, data_store=data_store).to_pytask()


def run_task(name: str):
    # s flag means don't capture stdout, so
    # we get logging output as tasks run
    session = pytask.main({"k": name, "s": True})
    if session.exit_code != 0:
        raise RuntimeError(f"Task {name} failed", session)
```
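
For context, here's a hypothetical sketch of how one of these wrapped tasks might be declared; the function and table names are invented, and how pytask discovers the resulting task depends on your project layout:

```python
def concat(left: Table, right: Table) -> Table:
    return left.union(right)

# Reads left.parquet and right.parquet from the data folder and
# writes concatenated.parquet; the pytask markers carry the path dependencies.
task_concat = make_pytask(concat, ins=["left", "right"], outs="concatenated")

# Later, e.g. from a driver script:
# run_task("task_concat")
```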