unionai-oss / pandera

A light-weight, flexible, and expressive statistical data testing library
https://www.union.ai/pandera
MIT License

Parametrized type annotations are broken for polars DataFrameModels #1580

Closed: r-bar closed this issue 3 weeks ago

r-bar commented 1 month ago

Description

Pandera DataFrameModels do not support parameterized types for polars, while DataFrameSchemas do.

Example

Here is an example of a working DataFrameSchema and several variations of broken DataFrameModels.

from typing import Annotated

import pandera.polars as pa
import polars as pl

from pandera.typing import Series
from pandera.errors import SchemaInitError

df = pl.DataFrame({
    "id": [1, 2, 3],
    "lists": [["a"], ["a", "b"], ["a", "b", "c"]],
})

# works!
schema = pa.DataFrameSchema(
    columns={
        "id": pa.Column(int),
        "lists": pa.Column(list[str]),
    }
)
schema.validate(df)
print("DataFrameSchema validation passed")

class Lists(pa.DataFrameModel):
    """Most basic, expected form given the working schema above."""
    id: int
    lists: list[str]

try:
    Lists.validate(df)
except SchemaInitError as e:
    print("\nLists validation failed")
    print(e)
else:
    print("\nLists validation passed")

class ListsSeries(pa.DataFrameModel):
    """Using Series as a wrapper around basic data types, like the id column here,
    does not work, even though examples of this appear in the pandera documentation.
    https://pandera.readthedocs.io/en/latest/dataframe_models.html#dtype-aliases
    """
    id: Series[int]
    lists: Series[list[str]]

try:
    ListsSeries.validate(df)
except SchemaInitError as e:
    print("\nListsSeries validation failed")
    print(e)
else:
    print("\nListsSeries validation passed")

class AlternateListsSeries(pa.DataFrameModel):
    """Demonstrating using Series as a type wrapper around only lists to avoid
    the initialization error on id."""
    id: int
    lists: Series[list[str]]

try:
    AlternateListsSeries.validate(df)
except SchemaInitError as e:
    print("\nAlternateListsSeries validation failed")
    print(e)
else:
    print("\nAlternateListsSeries validation passed")

class ListsAnnotated(pa.DataFrameModel):
    """Parameterized form using Annotated as suggested at
    https://pandera.readthedocs.io/en/latest/polars.html#nested-types
    """
    id: int
    lists: Series[Annotated[list, str]]

try:
    ListsAnnotated.validate(df)
except TypeError as e:
    print("\nListsAnnotated validation failed")
    print(e)
else:
    print("\nListsAnnotated validation passed")

class ListsAnnotatedStr(pa.DataFrameModel):
    """Alternate parameterized form using Annotated as seen in the examples here:
    https://pandera.readthedocs.io/en/latest/dataframe_models.html#annotated
    """
    id: int
    lists: Series[Annotated[list, "str"]]

try:
    ListsAnnotatedStr.validate(df)
except TypeError as e:
    print("\nListsAnnotatedStr validation failed")
    print(e)
else:
    print("\nListsAnnotatedStr validation passed")

When run with the following python / library versions:

the above script produces:

DataFrameSchema validation passed

Lists validation failed
Invalid annotation 'lists: list[str]'

ListsSeries validation failed
Invalid annotation 'id: pandera.typing.pandas.Series[int]'

AlternateListsSeries validation failed
Invalid annotation 'lists: pandera.typing.pandas.Series[list[str]]'

ListsAnnotated validation failed
Annotation 'Annotated' requires all positional arguments ['args', 'kwargs'].

ListsAnnotatedStr validation failed
Annotation 'Annotated' requires all positional arguments ['args', 'kwargs'].
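To see why the model annotations fail in different ways, it can help to look at how Python itself decomposes them. The snippet below uses only the standard `typing` introspection helpers `get_origin` and `get_args` (Python 3.9+); it assumes, without confirmation from this thread, that pandera's model machinery reads annotations in a broadly similar way, and it is not a reproduction of pandera's internals.

```python
from typing import Annotated, get_args, get_origin

# Column annotations from the repro script. Series[...] is omitted because
# it only wraps these same inner types.
annotations = {
    "list[str]": list[str],
    "Annotated[list, str]": Annotated[list, str],
    'Annotated[list, "str"]': Annotated[list, "str"],
}

for label, ann in annotations.items():
    # get_origin gives the unsubscripted generic (list or Annotated);
    # get_args gives the subscript arguments.
    print(f"{label:24} origin={get_origin(ann)!r} args={get_args(ann)!r}")
```

Seen this way, `list[str]` is a parameterized builtin whose inner type sits in `get_args`, while in both `Annotated` forms the element after `list` is free-form metadata (a class in one case, a string in the other), so the two `Annotated` variants only differ in the type of that metadata.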

Expected behavior

I would expect any column types that are valid to pass to DataFrameSchema's constructor to also be valid as annotations for DataFrameModel.
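One way to picture this expected behavior is a resolver that walks a parameterized annotation recursively, so that `lists: list[str]` on a model resolves to the same nested list-of-strings dtype that `pa.Column(list[str])` produces in a schema. The sketch below is hypothetical: `resolve_dtype`, the `SCALARS` table, and the `List(...)`/`String` names are invented for illustration and only return a descriptive string; they are not pandera or polars APIs.

```python
from typing import get_args, get_origin

# Hypothetical mapping from Python scalar types to dtype names (illustration only).
SCALARS = {int: "Int64", str: "String", float: "Float64", bool: "Boolean"}


def resolve_dtype(annotation):
    """Hypothetically resolve a possibly nested annotation to a dtype name."""
    if get_origin(annotation) is list:
        # list[X]: recurse into the single inner type argument.
        (inner,) = get_args(annotation)
        return f"List({resolve_dtype(inner)})"
    if annotation in SCALARS:
        return SCALARS[annotation]
    # Anything else, including a bare unparameterized `list`, is rejected.
    raise TypeError(f"Invalid annotation {annotation!r}")


print(resolve_dtype(int))              # Int64
print(resolve_dtype(list[str]))        # List(String)
print(resolve_dtype(list[list[int]]))  # List(List(Int64))
```

Recursing on `get_args` is what lets arbitrarily nested types such as `list[list[int]]` resolve with no extra special cases, which is the kind of parity between schema and model annotations the report asks for.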


cosmicBboy commented 4 weeks ago

Thanks for reporting this, @r-bar. FYI, Series[Type] annotations are currently not supported in the polars API; see https://github.com/unionai-oss/pandera/pull/1588 and the ongoing discussion here: https://github.com/unionai-oss/pandera/issues/1594.

Looking into this. I'm planning on supporting:

class Lists(pa.DataFrameModel):
    """Most basic, expected form given the working schema above."""
    id: int
    lists: list[str]