wesselhuising / pandantic

Enriches the Pydantic BaseModel class by adding the ability to validate dataframes using the schema and custom validators of the same BaseModel class.
https://pandantic-rtd.readthedocs.io
MIT License
29 stars 3 forks

Support for Polars Dataframes #17

Open ldacey opened 1 year ago

ldacey commented 1 year ago

pandantic seemed like such a nice and simple implementation that I decided to edit your model to work with Polars DataFrames, and I figured I would share the results.

I only recently began using Polars, so there may be more efficient ways to do this, but here are the changes I had to make to your model:

  1. Polars has no index, so I replaced it with with_row_count() to get the row number for error reporting
  2. Chunking can be handled by iter_slices(), where n_rows is the total row count divided by n_jobs, rounded up (n_jobs falls back to the CPU count when it is negative)
  3. Instead of to_dict(), iter_rows(named=True) passes each row to the validator as a dict
  4. filter() excludes the error rows when errors is set to "filter"

import logging
import math

import polars as pl
from multiprocess import Process, Queue, cpu_count
from pydantic import BaseModel

class PolarsModel(BaseModel):

    @classmethod
    def parse_df(
        cls,
        dataframe: pl.DataFrame,
        errors: str = "raise",
        context: dict[str, object] | None = None,
        n_jobs: int = 1,
        verbose: bool = True,
    ) -> pl.DataFrame:

        errors_index = []
        dataframe = dataframe.clone().with_row_count()

        logging.info(f"Validating {dataframe.height} rows")
        logging.debug(f"Amount of available cores: {cpu_count()}")

        if n_jobs != 1:
            if n_jobs < 0:
                n_jobs = cpu_count()

            chunk_size = math.ceil(len(dataframe) / n_jobs)
            chunks = list(dataframe.iter_slices(n_rows=chunk_size))
            total_chunks = len(chunks)

            logging.info(f"Split the dataframe into {total_chunks} chunks to process {chunk_size} rows per chunk.")

            processes = []
            q = Queue()

            for chunk in chunks:
                p = Process(target=cls._validate_row, args=(chunk, q, context, verbose), daemon=True)
                p.start()
                processes.append(p)

            num_stops = 0
            while num_stops < total_chunks:
                index = q.get()
                if index is None:
                    num_stops += 1
                else:
                    errors_index.append(index)

            for p in processes:
                p.join()

        else:
            for row in dataframe.iter_rows(named=True):
                try:
                    cls.model_validate(obj=row, context=context)
                except Exception as exc:
                    if verbose:
                        logging.info(f"Validation error found at row {row['row_nr']}\n{exc}")
                    errors_index.append(row["row_nr"])

        logging.info(f"# invalid rows: {len(errors_index)}")

        if len(errors_index) > 0 and errors == "raise":
            raise ValueError(f"{len(errors_index)} validation errors found in dataframe.")

        if len(errors_index) > 0 and errors == "filter":
            return dataframe.filter(~pl.col("row_nr").is_in(errors_index)).drop(columns=["row_nr"])

        return dataframe.drop(columns=["row_nr"])

    @classmethod
    def _validate_row(cls, chunk: pl.DataFrame, q: Queue, context=None, verbose=True) -> None:
        for row in chunk.iter_rows(named=True):
            try:
                cls.model_validate(obj=row, context=context)
            except Exception as exc:
                if verbose:
                    logging.info(f"Validation error found at row {row['row_nr']}\n{exc}")
                q.put(row["row_nr"])
        q.put(None)
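
A small aside on the chunk arithmetic in the code above, as a minimal sketch in plain Python. `chunk_bounds` is a hypothetical helper (it is not part of pandantic or Polars); it only mirrors `chunk_size = math.ceil(len(dataframe) / n_jobs)` followed by fixed-size slicing. The empty-input guard is an addition of this sketch, not something the code above does.

```python
import math


def chunk_bounds(n_rows: int, n_jobs: int) -> list[tuple[int, int]]:
    """Return (start, stop) row offsets for consecutive fixed-size chunks.

    Hypothetical helper: same ceil-division as the parse_df code above.
    """
    if n_rows == 0:
        return []  # guard added for this sketch: nothing to split
    chunk_size = math.ceil(n_rows / n_jobs)
    return [
        (start, min(start + chunk_size, n_rows))
        for start in range(0, n_rows, chunk_size)
    ]


print(chunk_bounds(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
print(chunk_bounds(9, 4))   # [(0, 3), (3, 6), (6, 9)]
```

Rounding up can produce fewer chunks than `n_jobs` (three chunks for 9 rows and 4 jobs), which is why the collection loop counts stop signals against `total_chunks` rather than `n_jobs`.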

I tested this on a dataframe which I duplicated a bunch of times until the row count was > 1 million rows to check if n_jobs was functioning correctly.

With n_jobs = 1:

image

With n_jobs = 4 (twice as fast):

image

Example validation error if verbose=True:

image

Example with errors="filter", the resulting dataframe has the expected rows:

image

wesselhuising commented 1 year ago

Hi @ldacey,

Thank you so much for trying this out. This is on the "private" roadmap, right after creating benchmark tests, as I would like a baseline for Pandantic's performance versus other packages such as Pandera. I think it would be nice to have benchmark tests built around an example dataframe that covers the basic cases you would expect from a regular dataframe (size, columns, and different types). That would also help with improving the performance of methods such as parsing Polars DataFrames (by improving the logic), and with comparing Pandas and Polars when Pydantic is used to validate a DataFrame.

I would actually take a different approach and create private methods within the current pandantic.BaseModel class. parse_df would then contain logic that determines the DataFrame type and calls the matching private method.
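
To make the idea concrete, here is a minimal sketch of that kind of dispatch, assuming the library can be identified from the top-level package of the object's class. Every name in it (`DispatchingModel`, `_parse_pandas`, `_parse_polars`, `FakePolarsFrame`) is hypothetical and the method bodies are placeholders; it is not pandantic's actual code.

```python
from typing import Any


class DispatchingModel:
    """Hypothetical stand-in for pandantic.BaseModel (no pydantic needed here)."""

    @classmethod
    def parse_df(cls, dataframe: Any) -> str:
        # The top-level package of the object's class identifies the library,
        # so neither pandas nor polars has to be imported to tell them apart.
        library = type(dataframe).__module__.split(".")[0]
        handlers = {"pandas": cls._parse_pandas, "polars": cls._parse_polars}
        if library not in handlers:
            raise TypeError(f"Unsupported dataframe type: {type(dataframe).__name__}")
        return handlers[library](dataframe)

    @classmethod
    def _parse_pandas(cls, dataframe: Any) -> str:
        return "validated with the pandas code path"  # placeholder body

    @classmethod
    def _parse_polars(cls, dataframe: Any) -> str:
        return "validated with the polars code path"  # placeholder body


class FakePolarsFrame:
    """Stand-in so the sketch runs without polars installed."""

    __module__ = "polars.dataframe.frame"


print(DispatchingModel.parse_df(FakePolarsFrame()))
```

Checking the module name instead of using `isinstance` is one possible design choice: it keeps both dataframe libraries optional at import time.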

Would you like to open a PR? I am open to contributions, obviously. If you don't feel like it, that is fine too. I was planning to add support for Polars DataFrames in the future anyway.

ldacey commented 1 year ago

Cool, yeah I did not know if you had any interest in other Dataframe libraries or what approach you would take. It seems like a majority of the code would be the same, with some differences in how chunking / row iteration gets done.

I was playing around earlier just to see what was possible.

try:
    test = ExampleModel.parse_df(df, errors="raise", n_jobs=4, verbose=False)
except ValidationException as e:
    print(e)
    print(e.errors)     

[2023-09-08T10:30:35.330+0800] {models.py:83} INFO - Validating 768 rows
49 validation errors found in dataframe.
[[{'type': 'greater_than_equal', 'loc': 'Employee CSAT', 'msg': 'Input should be greater than or equal to 3', 'input': 1, 'ctx': {'ge': 3}, 'row': 5}], [{'type': 'greater_than_equal', 'loc': 'Employee CSAT', 'msg': 'Input should be greater than or equal to 3', 'input': 1, 'ctx': {'ge': 3}, 'row': 10}],

For example, I can technically read those errors into a Dataframe and then save it as a file or email an alert with the details etc. (This was mostly just for fun and I am not suggesting it should be added to your library)

    new = pl.from_dicts(e.flatten_errors())
    print(new.filter(pl.col("row").is_duplicated()))
image
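
The "save it as a file" idea can also be sketched with only the standard library. This swaps `csv` in for `pl.from_dicts`, purely so the example runs without Polars. `errors_to_csv` is a hypothetical helper, and the two error groups are the ones shown in the log output above.

```python
import csv
import io
import json

FIELDS = ["row", "loc", "input", "type", "ctx", "msg"]


def errors_to_csv(error_groups: list[list[dict]]) -> str:
    """Flatten grouped error dicts (one group per failed row) into CSV text."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=FIELDS, lineterminator="\n")
    writer.writeheader()
    for group in error_groups:
        for error in group:
            record = {name: error.get(name) for name in FIELDS}
            # "ctx" is a nested dict, so store it as a JSON string.
            record["ctx"] = json.dumps(record["ctx"])
            writer.writerow(record)
    return buffer.getvalue()


error_groups = [
    [{"type": "greater_than_equal", "loc": "Employee CSAT",
      "msg": "Input should be greater than or equal to 3",
      "input": 1, "ctx": {"ge": 3}, "row": 5}],
    [{"type": "greater_than_equal", "loc": "Employee CSAT",
      "msg": "Input should be greater than or equal to 3",
      "input": 1, "ctx": {"ge": 3}, "row": 10}],
]

print(errors_to_csv(error_groups))
```

The returned string can be written to disk or attached to an alert email; reading it back with `csv.DictReader` and `json.loads` recovers the original fields.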

Here is the latest version I was using:

import logging
import math
from typing import Any, Literal

import polars as pl
from multiprocess import Process, Queue, cpu_count
from pydantic import BaseModel, ValidationError
from pydantic_core import ErrorDetails

class ValidationException(Exception):
    """Exception raised when validation fails and returns a list of all errors"""

    def __init__(self, errors):
        super().__init__(f"{len(errors)} validation errors found in dataframe.")
        self.errors = errors

    def flatten_errors(self):
        """Flatten the list of error groups into a single list of errors"""
        return [
            {
                "row": error["row"],
                "loc": error["loc"],
                "input": error["input"],
                "type": error["type"],
                "ctx": error.get("ctx"),  # "ctx" is absent for error types without context
                "msg": error["msg"],
            }
            for error_group in self.errors
            for error in error_group
        ]

def convert_errors(e: ValidationError, row_index: int) -> list[ErrorDetails]:
    """Removes the url field, adds the row field, and converts the loc field to a string
    in the ErrorDetails object

    Args:
        e: The ValidationError object
        row_index: The row number of the DataFrame that failed validation
    """
    new_errors: list[ErrorDetails] = []

    for error in e.errors():
        del error["url"]

        error["row"] = row_index

        if isinstance(error["loc"], (tuple, list)):
            error["loc"] = ".".join(map(str, error["loc"]))
        else:
            error["loc"] = str(error["loc"])

        new_errors.append(error)

    return new_errors

class PolarsBaseModel(BaseModel):
    """Base class for all Polars schema pydantic models"""

    @classmethod
    def parse_df(
        cls,
        dataframe: pl.DataFrame,
        errors: Literal["raise", "filter"] = "raise",
        context: dict[str, object] | None = None,
        n_jobs: int = 1,
        verbose: bool = True,
    ) -> pl.DataFrame:
        """Validate a DataFrame using the schema defined in the Pydantic model.

        Each row is converted to a dictionary before validation. If n_jobs != 1,
        the DataFrame is split into chunks and each chunk is validated in a
        separate process. Any errors are collected for logging or filtering.

        Args:
            dataframe (pl.DataFrame): The DataFrame to validate.
            errors (Literal["raise", "filter"], optional): How to handle validation
             errors. Defaults to "raise".
            context (dict[str, object] | None, optional): The context to use for
             validation. Defaults to None.
            n_jobs (int, optional): The number of processes to use for validation.
             A negative value uses all available cores. Defaults to 1.
            verbose (bool, optional): Whether to log validation errors.
             Defaults to True.

        Returns:
            pl.DataFrame: The input rows, minus invalid rows when errors="filter".

        Raises:
            ValidationException: If any row is invalid and errors="raise".
        """
        dataframe = dataframe.clone().with_row_count()
        errors_index = []
        error_details = []

        logging.info(f"Validating {dataframe.height} rows")
        logging.debug(f"Amount of available cores: {cpu_count()}")

        if n_jobs != 1:
            errors_index = cls._validate_multicore(
                dataframe, n_jobs, context, error_details
            )
        else:
            errors_index = cls._validate_singlecore(dataframe, context, error_details)

        if len(errors_index) > 0:
            if errors == "raise":
                if verbose:
                    logging.info(error_details)

                raise ValidationException(error_details)

            elif errors == "filter":
                if verbose:
                    logging.info(error_details)

                return dataframe.filter(~pl.col("row_nr").is_in(errors_index)).drop(
                    columns=["row_nr"]
                )

        return dataframe.drop(columns=["row_nr"])

    @classmethod
    def _validate_singlecore(
        cls,
        dataframe: pl.DataFrame,
        context: dict[str, Any] | None,
        error_details: list,
    ):
        """Validates each row of a DataFrame in dictionary format in a single process.

        Args:
            dataframe: DataFrame to validate
            context: Context to pass to the model_validate method
            error_details: List to store the error details
        """
        errors_index = []

        for row in dataframe.iter_rows(named=True):
            try:
                cls.model_validate(obj=row, context=context)

            except ValidationError as exc:
                exception = convert_errors(exc, row["row_nr"])
                error_details.append(exception)
                errors_index.append(row["row_nr"])

        return errors_index

    @classmethod
    def _validate_multicore(
        cls,
        dataframe: pl.DataFrame,
        n_jobs: int,
        context: dict[str, Any] | None,
        error_details: list,
    ):
        """Split the dataframe into chunks and validate each chunk in a separate process
        where each chunk is total rows / n_jobs in size. Each chunk is validated in a
        separate process. Any errors are appended to a list and returned.

        Args:
            dataframe: DataFrame to validate
            n_jobs: Number of processes to use
            context: Context to pass to the model_validate method
            error_details: List to store the error details
        """
        errors_index = []

        if n_jobs < 0:
            n_jobs = cpu_count()

        chunk_size = math.ceil(len(dataframe) / n_jobs)
        chunks = list(dataframe.iter_slices(n_rows=chunk_size))
        total_chunks = len(chunks)

        processes = []
        q = Queue()

        for chunk in chunks:
            p = Process(
                target=cls._validate_row,
                args=(chunk, q, context),
                daemon=True,
            )
            p.start()
            processes.append(p)

        num_stops = 0
        while num_stops < total_chunks:
            exceptions = q.get()
            if exceptions is None:
                num_stops += 1
            else:
                errors_index.append(exceptions["row_nr"])
                error_details.append(exceptions["error_detail"])

        for p in processes:
            p.join()

        return errors_index

    @classmethod
    def _validate_row(
        cls,
        chunk: pl.DataFrame,
        q: Queue,
        context: dict[str, Any] | None,
    ) -> None:
        """Validates each row of a DataFrame chunk in dictionary format. Intended to
        run as the target of a separate process.

        Args:
            chunk: DataFrame chunk to validate
            q: Queue that receives one dict per failed row (row number and error
                details), followed by None once the chunk is finished
            context: Context to pass to the model_validate method
        """
        for row in chunk.iter_rows(named=True):
            try:
                cls.model_validate(obj=row, context=context)

            except ValidationError as exc:
                exception = convert_errors(exc, row["row_nr"])

                q.put({"row_nr": row["row_nr"], "error_detail": exception})

        q.put(None)
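
The collection loop in `_validate_multicore` relies on each worker sending `None` as a stop signal once its chunk is done. Below is a minimal sketch of that pattern in isolation. It uses threads from the standard library instead of `multiprocess.Process`, purely so it runs anywhere; the counting logic is the same. `worker` and `collect_failures` are hypothetical names, and "odd number" stands in for "row failed validation".

```python
import queue
import threading


def worker(chunk: list[int], q: queue.Queue) -> None:
    """Report each odd value as a failure, then signal completion."""
    for value in chunk:
        if value % 2:  # hypothetical stand-in for a failed validation
            q.put(value)
    q.put(None)  # sentinel: this worker has finished its chunk


def collect_failures(chunks: list[list[int]]) -> list[int]:
    q: queue.Queue = queue.Queue()
    threads = [threading.Thread(target=worker, args=(chunk, q)) for chunk in chunks]
    for t in threads:
        t.start()

    failures = []
    num_stops = 0
    # Keep reading until every worker has sent its sentinel.
    while num_stops < len(chunks):
        item = q.get()
        if item is None:
            num_stops += 1
        else:
            failures.append(item)

    for t in threads:
        t.join()
    # Items arrive in whatever order the workers produce them, so sort.
    return sorted(failures)


print(collect_failures([[1, 2, 3], [4, 5], [6]]))  # [1, 3, 5]
```

Because arrival order depends on scheduling, anything that needs row order (such as a report) should sort by row number after collection, as the sketch does.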

xaviernogueira commented 1 month ago

@ldacey, we are working on a refactor that will use the dependency injection design pattern to eventually handle as many combinations as possible of table libraries (Spark/pandas/Polars/Dask/etc.) and "schema" libraries (pydantic/attrs/dataclass).

It seems like you put some real thought into this and are a Polars user. We are starting with pandas because it has the widest user base, but if you are interested in this new round of effort, we would love for you to join (perhaps for our next call) and help with the Polars implementation, or at least offer a perspective on what Polars users might like.
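
For readers unfamiliar with the pattern, here is a rough sketch of what dependency injection could look like for this problem, assuming the table library and the schema check are both passed in rather than hard-coded. All names (`TableAdapter`, `ListOfDictsAdapter`, `validate_table`, `require_positive_score`) are hypothetical and do not describe the actual refactor.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Protocol

Row = dict[str, Any]


class TableAdapter(Protocol):
    """What the validation loop needs from any table library."""

    def iter_rows(self) -> Iterator[tuple[int, Row]]: ...

    def drop_rows(self, bad: set[int]) -> "TableAdapter": ...


@dataclass
class ListOfDictsAdapter:
    """Toy adapter over a list of dicts; a Polars adapter would wrap pl.DataFrame."""

    rows: list[Row]

    def iter_rows(self) -> Iterator[tuple[int, Row]]:
        return iter(enumerate(self.rows))

    def drop_rows(self, bad: set[int]) -> "ListOfDictsAdapter":
        return ListOfDictsAdapter([r for i, r in enumerate(self.rows) if i not in bad])


def require_positive_score(row: Row) -> None:
    """Toy schema check; a pydantic model's model_validate could be injected instead."""
    if row["score"] <= 0:
        raise ValueError("score must be positive")


def validate_table(table: TableAdapter, validate_row: Callable[[Row], None]) -> TableAdapter:
    """Library-agnostic loop: both collaborators are injected by the caller."""
    bad: set[int] = set()
    for index, row in table.iter_rows():
        try:
            validate_row(row)
        except ValueError:
            bad.add(index)
    return table.drop_rows(bad)


table = ListOfDictsAdapter([{"score": 3}, {"score": -1}, {"score": 5}])
print(validate_table(table, require_positive_score).rows)  # [{'score': 3}, {'score': 5}]
```

The loop in `validate_table` never mentions a specific table or schema library, which is the property that would let new combinations be added by writing one adapter at a time.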