unionai-oss / pandera

A light-weight, flexible, and expressive statistical data testing library
https://www.union.ai/pandera

str_length function not working in pa.Field for PySpark #1311

Open karutyunov opened 1 year ago

karutyunov commented 1 year ago

Describe the bug

When trying to use the str_length check in pa.Field to validate the length of a string column, we get a NotImplementedError every time. I tried passing the arguments in different ways, both as a dict of keyword arguments (as in the code sample below) and in the form str_length(1, 2); both options give the same error.
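For clarity, the two forms are roughly the following (the second line is presumably the explicit check-constructor equivalent of "str_length(1, 2)", not the literal code that was run):

import pandera.pyspark as pa

# keyword-dict form, as used in the full code sample below
pa.Field(str_length={"min_value": 1, "max_value": 2})

# explicit check constructor (presumably what "str_length(1, 2)" refers to)
pa.Check.str_length(1, 2)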


Code Sample, a copy-pastable example

import json

import pandera.pyspark as pa
import pyspark.sql.types as T

from decimal import Decimal
from pyspark.sql import SparkSession
from pandera.pyspark import DataFrameModel

spark = SparkSession.builder.getOrCreate()

class PanderaSchema(DataFrameModel):
    id: T.IntegerType() = pa.Field(gt=4)
    product_name: T.StringType() = pa.Field(str_length={"min_value": 1, "max_value": 2}, coerce=True)  # this is the check that raises NotImplementedError
    price: T.DecimalType(20, 5) = pa.Field()
    description: T.ArrayType(T.StringType()) = pa.Field()
    meta: T.MapType(T.StringType(), T.StringType()) = pa.Field()

data = [
    (5, "Bread", Decimal(44.4), ["description of product"], {"product_category": "dairy"}),
    (15, "Butter", Decimal(99.0), ["more details here"], {"product_category": "bakery"}),
]

spark_schema = T.StructType(
    [
        T.StructField("id", T.IntegerType(), False),
        T.StructField("product_name", T.StringType(), False),
        T.StructField("price", T.DecimalType(20, 5), False),
        T.StructField("description", T.ArrayType(T.StringType(), True), False),
        T.StructField(
            "meta", T.MapType(T.StringType(), T.StringType(), True), False
        ),
    ],
)
df = spark.createDataFrame(data, spark_schema)

df_out = PanderaSchema.validate(check_obj=df)

df_out_errors = df_out.pandera.errors
print(json.dumps(dict(df_out_errors), indent=4))

Expected behavior

We expect the str_length check either to pass or to fail with a validation error; instead, the check itself errors out with a NotImplementedError:

/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/pyspark/pandas/__init__.py:48: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.
  "'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to "
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/08/14 18:37:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
{
    "DATA": {
        "CHECK_ERROR": [
            {
                "schema": "PanderaSchema",
                "column": "product_name",
                "check": "str_length(1, 2)",
                "error": "Error while executing check function: NotImplementedError()\nTraceback (most recent call last):\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/pandera/backends/pyspark/components.py\", line 135, in run_checks\n    check_obj, schema, check, check_index, *check_args\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/pandera/backends/pyspark/base.py\", line 85, in run_check\n    check_result = check(check_obj, *args)\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/pandera/api/checks.py\", line 229, in __call__\n    return backend(check_obj, column)\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/pandera/backends/pyspark/checks.py\", line 110, in __call__\n    check_obj, key, self.check._check_kwargs\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/multimethod/__init__.py\", line 407, in __call__\n    return self[sig](*args, **kwargs)\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/pandera/backends/pyspark/checks.py\", line 79, in apply\n    return self.check._check_fn(check_obj_and_col_name, **kwargs)\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/multimethod/__init__.py\", line 371, in __call__\n    return func(*args, **kwargs)\n  File \"/Users/ka/change-devices-propensity/venv/lib/python3.7/site-packages/pandera/backends/base/builtin_checks.py\", line 92, in str_length\n    raise NotImplementedError\nNotImplementedError\n"
            }
        ]
    }
}


Additional context

As part of my testing I also tried the in_range check, since it uses the same argument-passing syntax; it works flawlessly.
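For illustration only (the class name and bounds here are made up, not the actual schema I ran), the same dict-style syntax with in_range validates without raising:

import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel

class InRangeExample(DataFrameModel):
    # identical argument-passing style to str_length, but this check is implemented
    id: T.IntegerType() = pa.Field(in_range={"min_value": 1, "max_value": 20})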

marrov commented 4 months ago

I've put together a quick workaround that solves this by registering str_length as a builtin check for the PySpark backend:

import pyspark.sql.types as T

from typing import Optional
from pyspark.sql import functions as F
from pandera.api.extensions import register_builtin_check
from pandera.backends.pyspark.utils import convert_to_list
from pandera.api.pyspark.types import PysparkDataframeColumnObject
from pandera.backends.pyspark.decorators import register_input_datatypes

@register_builtin_check(error="str_length({min_value}, {max_value})")
@register_input_datatypes(acceptable_datatypes=convert_to_list(T.StringType))
def str_length(
    data: PysparkDataframeColumnObject,
    min_value: Optional[int] = None,
    max_value: Optional[int] = None,
) -> bool:
    """Ensure that the length of strings in a column is within a specified range."""
    if min_value is None and max_value is None:
        raise ValueError("Must provide at least one of 'min_value' and 'max_value'")
    # build the length condition, only adding the bounds that were provided
    str_len = F.length(F.col(data.column_name))
    cond = F.lit(True)
    if min_value is not None:
        cond = cond & (str_len >= min_value)
    if max_value is not None:
        cond = cond & (str_len <= max_value)

    # the check passes only if no row violates the condition
    return data.dataframe.filter(~cond).limit(1).count() == 0

This should be added to the PySpark backend's builtin checks in pandera.
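For anyone who wants to use this before it lands upstream, a minimal usage sketch (assuming the registration code above has been executed first and behaves as described; the schema name here is my own and the dataframe is the one from the original report):

import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel

# the @register_builtin_check / @register_input_datatypes block above must have
# run already, e.g. by importing the module that defines it

class ProductSchema(DataFrameModel):
    product_name: T.StringType() = pa.Field(str_length={"min_value": 1, "max_value": 2})

df_out = ProductSchema.validate(check_obj=df)  # df built as in the original example
# the check now executes: failures show up as ordinary validation errors
# instead of a CHECK_ERROR wrapping a NotImplementedError
print(dict(df_out.pandera.errors))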