unionai-oss / pandera

A light-weight, flexible, and expressive statistical data testing library
https://www.union.ai/pandera
MIT License

Pydantic compatibility issue #1677

Open riziles opened 3 weeks ago

riziles commented 3 weeks ago

I believe that the latest versions of Pydantic and Pandera are not fully compatible.

This relates to https://github.com/unionai-oss/pandera/issues/1395, which was closed but which I think should still be open.

This code throws an error:

import pandas as pd
import pandera as pa
from pandera.typing import DataFrame, Series
import pydantic

class SimpleSchema(pa.DataFrameModel):
    str_col: Series[str] = pa.Field(unique=True)

class PydanticModel(pydantic.BaseModel):
    x: int
    df: DataFrame[SimpleSchema]

print(PydanticModel.model_json_schema())

Error message:

Exception has occurred: PydanticInvalidForJsonSchema
Cannot generate a JsonSchema for core_schema.PlainValidatorFunctionSchema ({'type': 'no-info', 'function': functools.partial(<bound method DataFrame.pydantic_validate of <class 'pandera.typing.pandas.DataFrame'>>, schema_model=SimpleSchema)})

For further information visit https://errors.pydantic.dev/2.7/u/invalid-for-json-schema
  File "C:\LocalTemp\Repos\RA\RiskCalcs\scratch.py", line 18, in <module>
    print(PydanticModel.model_json_schema())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic.errors.PydanticInvalidForJsonSchema: Cannot generate a JsonSchema for core_schema.PlainValidatorFunctionSchema ({'type': 'no-info', 'function': functools.partial(<bound method DataFrame.pydantic_validate of <class 'pandera.typing.pandas.DataFrame'>>, schema_model=SimpleSchema)})

For further information visit https://errors.pydantic.dev/2.7/u/invalid-for-json-schema

I have tried various config options to get around this error to no avail.
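
For context, the error says that pydantic cannot derive a JSON schema from a plain validator function, because such a function tells it nothing about the shape of the data. A minimal sketch of pydantic's general hook for this situation follows, using only pydantic and no pandera. `Records` is a hypothetical stand-in type, not a pandera class, and whether the same hook can be applied to pandera's `DataFrame` is untested here.

```python
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema


class Records:
    """Hypothetical stand-in for a type validated by a plain function."""

    def __init__(self, rows: list) -> None:
        self.rows = rows

    @classmethod
    def _validate(cls, value: Any) -> "Records":
        # Accept an existing Records or wrap any iterable of rows.
        return value if isinstance(value, cls) else cls(list(value))

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        # Same kind of core schema as the one named in the error message.
        return core_schema.no_info_plain_validator_function(cls._validate)

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        # Return a hand-written JSON schema instead of delegating to the
        # handler, which cannot describe a plain validator function.
        return {"type": "array", "items": {"type": "object"}}


class Model(BaseModel):
    x: int
    records: Records


json_schema = Model.model_json_schema()
print(json_schema["properties"]["records"])
```

The hand-written schema only affects JSON schema generation; validation still goes through `_validate`.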

riziles commented 3 weeks ago

Here is my really hacky workaround (no idea whether it is correct):

import pandas as pd
import pandera as pa
from pandera.typing import DataFrame as _DataFrame, Series

from pydantic_core import core_schema, CoreSchema
from pydantic import GetCoreSchemaHandler, BaseModel
from typing import TypeVar, Generic, Any

T = TypeVar("T")  

class DataFrame(_DataFrame, Generic[T]):

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:

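        # Instantiating the parametrized alias sets __orig_class__ on the new
        # object; its first type argument is the DataFrameModel subclass.
        schema = source_type().__orig_class__.__args__[0].to_schema()

        # Map pandera column dtypes (by string name) to pydantic core schemas.
        # A dtype missing from this map raises KeyError below.
        type_map = {
            "str": core_schema.str_schema(),
            "int64": core_schema.int_schema(),
            "float64": core_schema.float_schema(),
            "bool": core_schema.bool_schema(),
            "datetime64[ns]": core_schema.datetime_schema(),
        }

        # Describe the frame as a list of row dicts, one typed key per column.
        return core_schema.list_schema(
            core_schema.typed_dict_schema(
                {
                    name: core_schema.typed_dict_field(type_map[str(column.dtype)])
                    for name, column in schema.columns.items()
                },
            )
        )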

class SimpleSchema(pa.DataFrameModel):
    str_col: Series[str]

class PydanticModel(BaseModel):
    x: int
    df: DataFrame[SimpleSchema]
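
The expression `source_type().__orig_class__.__args__[0]` in the workaround relies on the fact that calling a parametrized generic alias creates an instance and records the alias on it. A small standard-library sketch of that mechanism, and of `typing.get_args` as a possible way to read the same argument without creating an instance, is below. `Box` and `SomeSchema` are hypothetical stand-ins for the `DataFrame` subclass and the schema model; this has not been checked against pandera itself.

```python
from typing import Generic, TypeVar, get_args

T = TypeVar("T")


class Box(Generic[T]):
    """Hypothetical stand-in for the generic DataFrame subclass."""


class SomeSchema:
    """Hypothetical stand-in for a DataFrameModel subclass."""


alias = Box[SomeSchema]

# Calling the alias builds a Box and stores the alias as __orig_class__.
via_instance = alias().__orig_class__.__args__[0]

# get_args reads the same type argument directly from the alias.
via_get_args = get_args(alias)[0]

print(via_instance is via_get_args)
```

Reading the argument from the alias avoids constructing an object just to inspect its type parameters.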