unionai-oss / pandera

A light-weight, flexible, and expressive statistical data testing library
https://www.union.ai/pandera
MIT License

Extend SchemaModel so that it's a valid type in a pydantic model #453

Closed tfwillems closed 2 years ago

tfwillems commented 3 years ago

Is your feature request related to a problem? Please describe.

I'd like to include pandera SchemaModels as a type within a pydantic model. pandera's SchemaModel best encapsulates a valid "type" for pandas data frames (as far as I'm aware), but it's currently missing a single classmethod that would make this possible. Ideally, we'd be able to do the following:

import pandera as pa
import pandas as pd
from pydantic import BaseModel
from typing import Any

class SimpleDF(pa.SchemaModel):
    sample: pa.typing.Series[str] = pa.Field(allow_duplicates=False)
    val_1: pa.typing.Series[str]
    val_2: pa.typing.Series[int]

# Trying to create this pydantic model currently fails, as SimpleDF/SchemaModel do not provide
# the __get_validators__ function required by pydantic
class SimplePydanticModel(BaseModel):
    x: int
    y: SimpleDF

# Ideally should work without error
valid_df = pd.DataFrame({"sample": ["hello", "world"], "val_1": ["a", "b"], "val_2":[1, 2]})
model = SimplePydanticModel(x=1, y=valid_df)

# Ideally should fail as invalid_df has duplicate values in sample
invalid_df = pd.DataFrame({"sample": ["hello", "hello"], "val_1": ["a", "b"], "val_2":[1, 2]})
invalid_model = SimplePydanticModel(x=1, y=invalid_df)

Describe the solution you'd like

Extend pandera's SchemaModel so that it's a valid type within pydantic. Nearly all of the heavy lifting has already been done, as the validate function essentially performs all validation. pydantic naturally supports custom types, as long as they implement the __get_validators__ classmethod (see the pydantic documentation on custom data types). We could simply extend SchemaModel by adding this classmethod or some variant thereof:

@classmethod
def __get_validators__(cls):
    yield cls._pydantic_validate

@classmethod
def _pydantic_validate(cls, df: Any) -> pd.DataFrame:
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Expected a pandas data frame")

    try:
        return cls.validate(df)
    except pa.errors.SchemaError as e:
        raise ValueError(str(e))
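
For readers unfamiliar with the protocol, here is a toy, stdlib-only sketch of how a framework might consume a `__get_validators__` classmethod. It is not pydantic's code: `run_validators` and `UniqueNames` are hypothetical names made up for illustration, and a list of unique strings stands in for a schema-checked data frame.

```python
from typing import Any, Callable, Iterator, List


def run_validators(tp: type, value: Any) -> Any:
    """Hypothetical helper: apply each validator the type yields, in order."""
    for validator in tp.__get_validators__():
        # Each validator returns the (possibly converted) value or raises.
        value = validator(value)
    return value


class UniqueNames:
    """Toy custom type: a list of unique strings."""

    @classmethod
    def __get_validators__(cls) -> Iterator[Callable[[Any], Any]]:
        # Validators run in the order they are yielded.
        yield cls._check_type
        yield cls._check_unique

    @classmethod
    def _check_type(cls, v: Any) -> List[str]:
        if not isinstance(v, list) or not all(isinstance(x, str) for x in v):
            raise ValueError("Expected a list of strings")
        return v

    @classmethod
    def _check_unique(cls, v: List[str]) -> List[str]:
        if len(set(v)) != len(v):
            raise ValueError("Values must be unique")
        return v
```

Raising `ValueError` from a validator mirrors the proposal above, where a pandera `SchemaError` is re-raised as a `ValueError` so that the surrounding framework can report it as a validation failure.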

Describe alternatives you've considered

This may be outside the scope of the project, but I envision others may be interested. pydantic supports a wide range of types, but is currently missing one that naturally extends to data frames. I think adding this feature would mean that pandera naturally fills this void. Of course, one option would be for me to implement my own subclass of SchemaModel that provides this classmethod, but if others might find it useful it'd be great to add it to the original parent class.

jeffzi commented 3 years ago

SchemaModel was inspired by pydantic. I think it only makes sense to be compatible with it. Pinging the big boss @cosmicBboy to make sure that's ok with him :)

That said, we need to use the same annotation as the pandera.check_types decorator, e.g. pandera.typing.DataFrame[SimpleDF]. That indicates that we expect a DataFrame "typed" by SimpleDF, whereas the SimpleDF annotation indicates that we expect a (sub)type of the model SimpleDF.

Example of the syntax:

import pandera as pa
from pandera.typing import DataFrame, Series

class InputSchema(pa.SchemaModel):
    year: Series[int]

class OutputSchema(InputSchema):
    revenue: Series[float]

@pa.check_types
def transform(df: DataFrame[InputSchema]) -> DataFrame[OutputSchema]:
    return df.assign(revenue=100.0)
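
To make the difference between the two annotation styles concrete, here is a small stdlib-only sketch of how a generic alias keeps its schema argument available for inspection. The `DataFrame` and `SimpleDF` classes below are toy stand-ins, not pandera's, and `describe` is a hypothetical helper.

```python
from typing import Any, Generic, TypeVar, get_args, get_origin

T = TypeVar("T")


class DataFrame(Generic[T]):
    """Toy stand-in for pandera.typing.DataFrame (not the real class)."""


class SimpleDF:
    """Toy stand-in for a SchemaModel subclass."""


def describe(annotation: Any) -> str:
    """Hypothetical helper: say what a field annotation asks for."""
    if get_origin(annotation) is DataFrame:
        # DataFrame[SimpleDF]: a data frame, with the schema as type argument.
        (schema,) = get_args(annotation)
        return f"data frame validated against {schema.__name__}"
    # A bare class: an instance of that class itself.
    return f"instance of {annotation.__name__}"
```

With these definitions, `describe(DataFrame[SimpleDF])` reports a data frame checked against `SimpleDF`, while `describe(SimpleDF)` reports an instance of the model class.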

We'd need to add __get_validators__ to pandera.typing.DataFrame, with a validator method of signature def validate(cls, v, field: ModelField) (see the pydantic documentation on generic classes as types).

Would you be up for submitting a pull request? In any case, thanks for your recent feedback. Keep it coming :+1:

cosmicBboy commented 3 years ago

sounds good to me, thanks @tfwillems !

tfwillems commented 3 years ago

@jeffzi @cosmicBboy I started working on a simple prototype for this that would replace DataFrame in pandera.typing but unfortunately ran into an issue.

Currently, pydantic supports Generic classes as types, but if type argument(s) are supplied, it seems that they must also implement __get_validators__(). So this would require that if we wanted to use the framework @jeffzi suggested, we would still need to add a __get_validators__() to the SchemaModel class directly as it would be the generic argument for DataFrame.

Here's how I envisioned the prototype could potentially function:

import inspect
import pandas as pd
import pandera as pa
import pydantic

from typing import Generic, TypeVar

Schema = TypeVar("Schema", bound="SchemaModel")  # type: ignore

class DataFrame(pd.DataFrame, Generic[Schema]):
    """
    Representation of pandas.DataFrame, only used for type annotation.

    *new in 0.5.0*
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v, field: pydantic.fields.ModelField):
        if not isinstance(v, pd.DataFrame):
            raise ValueError(f"Expected an instance of pandas.DataFrame but found an instance of '{type(v)}'")

        # Generic parameter was not provided so don't validate schema
        if field.sub_fields is None:
            return v

        # Verify that the generic argument is a subclass of SchemaModel
        generic_type = field.sub_fields[0].outer_type_
        if not inspect.isclass(generic_type) or not issubclass(generic_type, pa.SchemaModel):
            raise TypeError(
                f"Invalid generic argument for {cls.__name__}."
                f" Expected a subclass of SchemaModel but found '{generic_type}'"
            )

        try:
            return generic_type.validate(v)
        except pa.errors.SchemaError as e:
            raise ValueError(str(e))

Unfortunately, when I test it as below, pydantic raises an error at class definition time that ultimately suggests you need to provide a __get_validators__ function in SchemaModel:


from pydantic import BaseModel
df_1 = pd.DataFrame({"x": [1,2,3]})

class MySchemaModel(pa.SchemaModel):
    x: pa.typing.Series[int] = pa.Field()

class MyModel(BaseModel):
    x: DataFrame[MySchemaModel]

Traceback (most recent call last):
  File "/cluster/home/willems/MethodsDev/pandera-extensions/pandera/pandera/testing123.py", line 74, in <module>
    class MyModel(BaseModel):
  File "pydantic/main.py", line 293, in pydantic.main.ModelMetaclass.__new__
  File "pydantic/fields.py", line 410, in pydantic.fields.ModelField.infer
  File "pydantic/fields.py", line 342, in pydantic.fields.ModelField.__init__
  File "pydantic/fields.py", line 450, in pydantic.fields.ModelField.prepare
  File "pydantic/fields.py", line 622, in pydantic.fields.ModelField._type_analysis
  File "pydantic/fields.py", line 648, in pydantic.fields.ModelField._create_sub_type
  File "pydantic/fields.py", line 342, in pydantic.fields.ModelField.__init__
  File "pydantic/fields.py", line 456, in pydantic.fields.ModelField.prepare
  File "pydantic/fields.py", line 670, in pydantic.fields.ModelField.populate_validators
  File "pydantic/validators.py", line 715, in find_validators
RuntimeError: no validator found for <class '__main__.MySchemaModel'>, see `arbitrary_types_allowed` in Config

@jeffzi To verify this wasn't solely due to an obvious implementation issue on my part, I reproduced this using the generic class example in the pydantic docs involving AgedType and QualityType.

from pydantic import BaseModel, ValidationError
from pydantic.fields import ModelField
from typing import TypeVar, Generic

AgedType = TypeVar('AgedType')
QualityType = TypeVar('QualityType')

# This is not a pydantic model, it's an arbitrary generic class
class TastingModel(Generic[AgedType, QualityType]):
    def __init__(self, name: str, aged: AgedType, quality: QualityType):
        self.name = name
        self.aged = aged
        self.quality = quality

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    # You don't need to add the "ModelField", but it will help your
    # editor give you completion and catch errors
    def validate(cls, v, field: ModelField):
        if not isinstance(v, cls):
            # The value is not even a TastingModel
            raise TypeError('Invalid value')
        if not field.sub_fields:
            # Generic parameters were not provided so we don't try to validate
            # them and just return the value as is
            return v
        aged_f = field.sub_fields[0]
        quality_f = field.sub_fields[1]
        errors = []
        # Here we don't need the validated value, but we want the errors
        valid_value, error = aged_f.validate(v.aged, {}, loc='aged')
        if error:
            errors.append(error)
        # Here we don't need the validated value, but we want the errors
        valid_value, error = quality_f.validate(v.quality, {}, loc='quality')
        if error:
            errors.append(error)
        if errors:
            raise ValidationError(errors, cls)
        # Validation passed without errors, return the same instance received
        return v

# My code starts here
class OkModel(BaseModel):
    x: TastingModel[int, float]

class CustomAgedType:
    pass

class FailModel(BaseModel):
    x: TastingModel[CustomAgedType, float]

Traceback (most recent call last):
  File "/cluster/home/willems/MethodsDev/pandera-extensions/pandera/pandera/testing123.py", line 137, in <module>
    class FailModel(BaseModel):
  File "pydantic/main.py", line 293, in pydantic.main.ModelMetaclass.__new__
  File "pydantic/fields.py", line 410, in pydantic.fields.ModelField.infer
  File "pydantic/fields.py", line 342, in pydantic.fields.ModelField.__init__
  File "pydantic/fields.py", line 450, in pydantic.fields.ModelField.prepare
  File "pydantic/fields.py", line 622, in pydantic.fields.ModelField._type_analysis
  File "pydantic/fields.py", line 648, in pydantic.fields.ModelField._create_sub_type
  File "pydantic/fields.py", line 342, in pydantic.fields.ModelField.__init__
  File "pydantic/fields.py", line 456, in pydantic.fields.ModelField.prepare
  File "pydantic/fields.py", line 670, in pydantic.fields.ModelField.populate_validators
  File "pydantic/validators.py", line 715, in find_validators
RuntimeError: no validator found for <class '__main__.CustomAgedType'>, see `arbitrary_types_allowed` in Config

This leaves us with a few options:

1) File an issue with pydantic asking whether they could relax the requirement that a generic class's type arguments provide a __get_validators__() function, then use the prototype above. This may be a non-starter for them if it breaks all sorts of logic within the code base.
2) Abandon this feature.
3) Add a __get_validators__() to SchemaModel. In my mind, this makes the most sense. If you envision a SchemaModel as a generalization of pydantic's BaseModel to data frames, then it makes sense for this class (and its subclasses) to directly provide __get_validators__().

Sorry for the long-winded post, but I would love to hear your thoughts!

jeffzi commented 3 years ago

Pydantic wants to "recursively" validate attributes, which makes sense in most cases.
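
As a rough illustration of that recursive behaviour, the sketch below mimics a validator lookup that descends into generic type arguments and fails when any of them is not validatable. It is a toy model, not pydantic's implementation: `find_validators` here is a hypothetical re-creation that only borrows the name and error message seen in the tracebacks above, and all classes are stand-ins.

```python
from typing import Any, Callable, Generic, List, TypeVar, get_args, get_origin

# Toy registry of types the "framework" can validate out of the box.
BUILTIN_VALIDATORS = {int: int, float: float, str: str}


def find_validators(tp: Any) -> List[Callable[[Any], Any]]:
    """Toy lookup: every type argument must be validatable, then the type itself."""
    origin = get_origin(tp) or tp
    for arg in get_args(tp):
        find_validators(arg)  # raises if a type argument has no validator
    if hasattr(origin, "__get_validators__"):
        return list(origin.__get_validators__())
    if origin in BUILTIN_VALIDATORS:
        return [BUILTIN_VALIDATORS[origin]]
    raise RuntimeError(f"no validator found for {origin!r}")


S = TypeVar("S")


class DataFrame(Generic[S]):
    """Toy generic container that knows how to validate itself."""

    @classmethod
    def __get_validators__(cls):
        yield lambda v: v


class PlainSchema:
    """No __get_validators__: usable only where nothing inspects it."""


class ValidatableSchema:
    """Providing __get_validators__ makes it acceptable as a type argument."""

    @classmethod
    def __get_validators__(cls):
        yield lambda v: v
```

In this toy model, `find_validators(DataFrame[ValidatableSchema])` succeeds, while `find_validators(DataFrame[PlainSchema])` raises `RuntimeError`, which is the shape of the failure reported earlier in the thread.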

I agree that option 3 is the way to go. In our case, we could attempt a SchemaModel.to_schema() to ensure the model definition can be compiled to a DataFrameSchema.

Just a detail, but I think DataFrame.validate could be renamed to _validate_pydantic() to avoid confusion for pandera users. We'd also need to add a test for that feature.

Feel free to submit a PR, that would give you proper credit for the work you've done so far :rocket:. You can assign the review to me. The contributing section can help you run the test suite/linters locally.

cosmicBboy commented 3 years ago

Hey @tfwillems just wanted to ping you about this feature, which I think would be very useful!

Let me know if you're still interested in making a PR and if there's anything we can do to help out!

cosmicBboy commented 3 years ago

Hey @tfwillems, friendly ping on this issue! Let me know if you're still interested in making a PR for this. If you'd rather pass the baton, let @jeffzi or me know: you can point us to a dev branch/fork with your most up-to-date prototype and perhaps one of us can take it to the finish line! 🏁

tfwillems commented 3 years ago

@cosmicBboy Sorry I've made so little progress here. I don't think I have the bandwidth at the moment to make a full-fledged PR, but here is all I think is really necessary:

from typing import Any

import pandas as pd
import pandera as pa
from pandera.errors import SchemaInitError

class ExtendedSchemaModel(pa.SchemaModel):
    """Simple extension to pandera's SchemaModel that enables its use as a pydantic type"""

    @classmethod
    def __get_validators__(cls):
        yield cls._pydantic_validate

    @classmethod
    def _pydantic_validate(cls, df: Any) -> pd.DataFrame:
        """Verify that the input is a pandas data frame that meets all schema requirements"""
        if not isinstance(df, pd.DataFrame):
            raise ValueError("Expected a pandas data frame")

        try:
            cls.to_schema()
        except SchemaInitError as e:
            raise ValueError(
                f"Cannot use {cls.__name__} as a pydantic type as its SchemaModel cannot be converted to a DataFrameSchema.\n"
                f"Please revisit the model to address the following errors:\n{str(e)}"
            )

        try:
            return cls.validate(df)
        except pa.errors.SchemaError as e:
            raise ValueError(str(e))

Obviously, the additional class is unnecessary; these functions would just be added to pa.SchemaModel's definition.

jeffzi commented 2 years ago

Hi @tfwillems, no worries :+1: ! Thanks for sharing your work in progress, I'll take it from there.

jeffzi commented 2 years ago

This feature has been released in 0.8.0.