HK3-Lab-Team / pytrousse

PyTrousse collects into one toolbox a set of data wrangling procedures tailored for composing reproducible analytics pipelines.
Apache License 2.0
0 stars · 1 fork · source link

Protocol Based Trousse Feature Operation Pipeline w/ Validation #92

Open leriomaggio opened 3 years ago

leriomaggio commented 3 years ago

Hiya @alessiamarcolini @lorenz-gorini

As promised, here is the design sketch I had in mind for a Protocol-based feature-operation pipeline, together with a simple Trousse pipeline implementation with validation.

Note: In this design, the Trousse pipeline object is not a FeatureOperation subclass. There would be nothing wrong with making it one, but it is not needed.

Note 2: PandasDataset and TensorDataset are two dummy (placeholder) types; the validation checks should be replaced accordingly with checks against the real dataset types.

Please find the full gist below; it is also available here.

from typing import Tuple, List, Optional, Union, NewType
from typing import Protocol, runtime_checkable
from abc import ABC, abstractmethod
import pandas as pd
try:
    import torch as t
    from torch import Tensor as Tensor
except ImportError: # fallback to Numpy, if torch not available - this is just as a POP!
    import numpy as t
    from numpy import ndarray as Tensor

# Placeholder dataset types for this sketch. NewType is a static-typing
# marker only: calling PandasDataset(x) / TensorDataset(x) does not convert
# or validate x. TODO: replace with the real Trousse dataset types.
PandasDataset = NewType("PandasDataset", pd.DataFrame)
# Tensor is torch.Tensor when torch imports, otherwise numpy.ndarray
# (see the ImportError fallback in the imports above).
TensorDataset = NewType("TensorDataset", Tensor)
# Any dataset a FeatureOperation may receive or return.
Dataset = Union[PandasDataset, TensorDataset]

# Protocols
# =========
@runtime_checkable
class DataFrameProtocol(Protocol):
    """Structural interface for operations able to process pandas datasets.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    DataFrameProtocol)`` holds for any object exposing an ``apply_pandas``
    attribute, whatever its base classes (the signature is not checked at
    runtime).
    """

    def apply_pandas(self, ds: PandasDataset) -> PandasDataset:
        """Transform ``ds`` and return the resulting pandas dataset."""

@runtime_checkable
class TensorProtocol(Protocol):
    """Structural interface for operations able to process tensor datasets.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    TensorProtocol)`` holds for any object exposing an ``apply_torch``
    attribute, whatever its base classes (the signature is not checked at
    runtime).
    """

    def apply_torch(self, ds: TensorDataset) -> TensorDataset:
        """Transform ``ds`` and return the resulting tensor dataset."""

# ---------------------------------------------------------------

# Feature Operation ABC
# =====================
class FeatureOperation(ABC):
    """Abstract base for one step of a Trousse feature pipeline.

    Subclasses implement :meth:`apply`; calling an instance delegates to it.
    Concrete operations opt into ``DataFrameProtocol`` and/or
    ``TensorProtocol`` by also defining ``apply_pandas`` / ``apply_torch``.
    """

    def __call__(self, ds: Dataset) -> Dataset:
        """Run this operation on ``ds`` and return the transformed dataset."""
        # Placeholder: shared pre-processing around ``apply`` would go here.
        transformed = self.apply(ds)
        # Placeholder: shared post-processing would go here.
        return transformed

    @abstractmethod
    def apply(self, ds: Dataset) -> Dataset:
        """Transform ``ds``; concrete operations must override this."""

    def __str__(self) -> str:
        # The concrete class name doubles as the operation's display name.
        return type(self).__name__

# ---------------------------------------------------------------
# Concrete Feature Operation Implementation w/ Protocol Adherence
# ===============================================================
class DataFrameIdentityOperation(FeatureOperation):
    """No-op operation that structurally satisfies ``DataFrameProtocol``.

    It prints a marker line to stdout and returns the very same dataset
    object it received.
    """

    def apply(self, ds: Dataset) -> Dataset:
        # Route the generic entry point to the pandas-specific implementation.
        return self.apply_pandas(ds)

    def apply_pandas(self, ds: PandasDataset) -> PandasDataset:
        """Print ``Pandas Dataset`` and hand ``ds`` back unchanged."""
        print("Pandas Dataset")
        return ds

class TensorIdentityOperation(FeatureOperation):
    """No-op operation that structurally satisfies ``TensorProtocol``.

    It prints a marker line to stdout and returns the very same dataset
    object it received.
    """

    def apply(self, ds: Dataset) -> Dataset:
        # Route the generic entry point to the tensor-specific implementation.
        return self.apply_torch(ds)

    def apply_torch(self, ds: TensorDataset) -> TensorDataset:
        """Print ``Tensor Dataset`` and hand ``ds`` back unchanged."""
        print("Tensor Dataset")
        return ds

# ---------------------------------------------------------------
# Trousse Pipeline
# ================

class TrousseValidationError(RuntimeError):
    """Raised by ``Trousse`` when its operations cannot process a dataset."""

class Trousse:
    """Ordered pipeline of feature operations with up-front protocol validation.

    Before running, the pipeline checks that every operation structurally
    implements the protocol matching the input dataset's type
    (``DataFrameProtocol`` for ``pd.DataFrame``, ``TensorProtocol`` for
    ``Tensor``). Operations are then applied in insertion order, each one
    receiving the output of the previous one.
    """

    def __init__(self, *ops: FeatureOperation) -> None:
        # Copy the varargs tuple into a list so add_module can extend it later.
        self._operations = list(ops)

    def add_module(self, op: FeatureOperation) -> None:
        """Append ``op`` to the end of the pipeline."""
        self._operations.append(op)

    @property
    def operations(self):
        """The pipeline's operations, in execution order.

        This is the internal list itself (not a copy), so mutating it
        changes the pipeline.
        """
        return self._operations

    def __call__(self, ds: Dataset) -> Dataset:
        """Validate the pipeline against ``ds``, then run every operation on it.

        Raises:
            TrousseValidationError: if ``ds`` has an unsupported type, or if
                any operation does not implement the protocol for ``ds``.
        """
        valid, errors = self.is_valid(ds, with_errors=True)
        if not valid:
            raise TrousseValidationError(f"Invalid Ops: {errors}")
        for op in self._operations:
            ds = op(ds)
        return ds

    def is_valid(self, ds: Dataset, with_errors: bool = False) -> Tuple[bool, Optional[List[str]]]:
        """Check that every operation supports the type of ``ds``.

        Args:
            ds: Dataset the pipeline would run on. ``pd.DataFrame`` instances
                require ``DataFrameProtocol``; ``Tensor`` instances require
                ``TensorProtocol``; any other type is unsupported.
            with_errors: When True, also report what made the check fail.

        Returns:
            ``(valid, errors)``. ``errors`` is None unless ``with_errors`` is
            True; then it lists ``str(op)`` for each non-conforming operation,
            or is ``["Unsupported Dataset"]`` when ``ds`` has an unsupported
            type.
        """
        if isinstance(ds, pd.DataFrame):  # TODO: check the real PandasDataset
            required = DataFrameProtocol
        elif isinstance(ds, Tensor):  # TODO: check the real TensorDataset
            required = TensorProtocol
        else:
            # No protocol applies: the dataset itself is the problem, not an op.
            return (False, ["Unsupported Dataset"]) if with_errors else (False, None)

        if not with_errors:
            # all() short-circuits on the first non-conforming operation.
            return all(isinstance(op, required) for op in self._operations), None

        invalid = [str(op) for op in self._operations if not isinstance(op, required)]
        return not invalid, invalid

    def __str__(self) -> str:
        return f"Trousse: {[str(op) for op in self._operations]}"

if __name__ == "__main__":
    # Demo: report which protocols each identity op satisfies, then run a
    # pandas-only pipeline on a DataFrame (accepted) and on a tensor (rejected).
    from string import punctuation

    di = DataFrameIdentityOperation()
    ti = TensorIdentityOperation()

    # Label each check with the actual protocol class name so the output
    # matches what is being tested.
    for op in (di, ti):
        for protocol in (DataFrameProtocol, TensorProtocol):
            print(f"{op} implements {protocol.__name__}: {isinstance(op, protocol)}")

    for op in (di, ti):
        print(f"{op} is Callable: {callable(op)}")

    df = pd.DataFrame({"Char": list(punctuation), "No": list(range(len(punctuation)))})
    ds_pd = PandasDataset(df)
    ds_to = TensorDataset(t.ones((3, 4)))

    trousse = Trousse(DataFrameIdentityOperation())
    print(f"Pipeline: {trousse}")

    # is_valid returns (valid, errors); only the validity flag is shown here.
    valid_pd, _ = trousse.is_valid(ds_pd)
    print("Trousse is valid w/ Pandas Dataset: ", valid_pd)
    valid_to, _ = trousse.is_valid(ds_to)
    print("Trousse is valid w/ Tensor Dataset: ", valid_to)

    print("Run Pipeline")
    trousse(ds_pd)

    # A pandas-only pipeline must reject a tensor dataset before running any op.
    print("Raise Error: ")
    try:
        trousse(ds_to)
    except TrousseValidationError as e:
        print(f"Exception ==> {e}")