unionai-oss / pandera

A light-weight, flexible, and expressive statistical data testing library
https://www.union.ai/pandera
MIT License

Custom DTypes With Polars #1587

Closed Filimoa closed 1 month ago

Filimoa commented 1 month ago

Describe the bug

I'm not sure if this is a bug, intentional or just missing documentation.

Code Sample, a copy-pastable example

from pandera import dtypes
from pandera.engines import numpy_engine
from pandera.typing import Series
from pandas.api.types import infer_dtype
import pandas as pd
import pandera as pa_core

@numpy_engine.Engine.register_dtype
@dtypes.immutable
class LiteralFloat(numpy_engine.Float64):
    def coerce(self, series: pd.Series) -> pd.Series:
        """If it comes across a string, strip commas and coerce it to a float. If that fails, return NaN."""
        if "string" in infer_dtype(series):
            # strip thousands separators from string values only
            series = series.apply(
                lambda x: x.replace(",", "") if isinstance(x, str) else x
            )

        return pd.to_numeric(series, errors="coerce")

class Schema(pa_core.SchemaModel):
    state: str
    price: Series[LiteralFloat]

    class Config:
        strict = "filter"
        coerce = True

df = pd.DataFrame(
    {
        "state": ["FL", "FL", "FL", "CA", "CA", "CA"],
        "price": ["8,000", "12.0", "10.0", "16.0", "20.0", "18.0"],
    }
)

Schema.validate(df)

With the pandas API this was possible: you could write custom dtypes that perform some basic data cleaning. For example, we had a YesNoBool dtype that coerces "yes"/"no" strings to booleans. This was handy since we deal with hundreds of these columns, and it's a pain to write transformation logic for each one.

The documentation is pretty vague on this (I'm not sure whether it's an anti-pattern), but this was my best attempt at porting the code to polars:

import pandera.polars as pa
import polars as pl
from pandera.engines import polars_engine
from pandera import dtypes
from pandera.typing import Series

@polars_engine.Engine.register_dtype
@dtypes.immutable
class LiteralFloat(pl.Float64):
    def coerce(self, series):
        """If it comes across a string, remove commas and coerce it to a float. If that fails, return null."""
        series = series.str.replace_all(",", "", literal=True).cast(pl.Float64, strict=False)
        return series

class Schema(pa.DataFrameModel):
    state: str
    price: Series[LiteralFloat] = pa.Field(coerce=True)

    class Config:
        strict = "filter"
        coerce = True

dl = pl.from_pandas(df)
Schema.validate(dl)
# SchemaInitError: Invalid annotation 'price: pandera.typing.pandas.Series[__main__.LiteralFloat]'

Is this intentional?


Screenshots

None

Additional context

I'll be glad to open a PR to update the docs if this is just a docs issue.

cosmicBboy commented 1 month ago

Looks like a bug in pandera.typing.Series... I think you can try just the bare type and it should work:

class Schema(pa.DataFrameModel):
    state: str
    price: LiteralFloat = pa.Field(coerce=True)

The correct implementation of the custom dtype is also:

from pandera.api.polars.types import PolarsData

@polars_engine.Engine.register_dtype
@dtypes.immutable
class LiteralFloat(polars_engine.Float64):  # 👈  inherit from polars_engine.Float64, not the polars dtype
    def coerce(self, polars_data: PolarsData) -> pl.LazyFrame:  # 👈 note the input and output signature
        """If it comes across a string, remove commas and coerce it to a float. If that fails, return null."""
        return polars_data.lazyframe.with_columns(  # 👈 must return a lazyframe
            pl.col(polars_data.key)
            .str.replace_all(",", "", literal=True)  # replace every comma, not just the first
            .cast(pl.Float64, strict=False)
        )

See the polars engine DataType implementation for details on the signatures of these methods: https://github.com/unionai-oss/pandera/blob/main/pandera/engines/polars_engine.py#L91

I'll look into fixing the SchemaInitError: Invalid annotation 'price: pandera.typing.pandas.Series[__main__.LiteralFloat]' issue. If you can, it would be great if the polars docs were updated with an example of a custom data type: https://github.com/unionai-oss/pandera/blob/main/docs/source/polars.md

cosmicBboy commented 1 month ago

I'll look into fixing the SchemaInitError: Invalid annotation 'price: pandera.typing.pandas.Series[__main__.LiteralFloat]' issue

So the Series[TYPE] syntax is only supported in the pandas DataFrameModel and will eventually be deprecated in that API. Going forward, new backends (in this case polars) will support the more concise bare-type annotation. I'll add a more informative error message here.

Filimoa commented 1 month ago

That worked, I'll open a PR shortly!