unionai-oss / pandera

A light-weight, flexible, and expressive statistical data testing library
https://www.union.ai/pandera
MIT License
3.29k stars 307 forks source link

I can't test check functions decorated with @extensions.register_check_method() #1274

Open pvizan-artefact opened 1 year ago

pvizan-artefact commented 1 year ago

Question about pandera

I am trying to write tests for custom checks I have created with pandera. As shown in the code below, I define them as functions decorated with @extensions.register_check_method(). However, when I import these functions from an external file and call them in a test, the tests fail because the decorated names are None (NoneType) rather than callables. How can I test my custom check functions?

Code Sample, a copy-pastable example

Note: If you'd still like to submit a question, please read this guide detailing how to provide the necessary information for us to reproduce your question.

# Imports
import pytest
import pandas as pd
from pandera import extensions

# Custom checks
@extensions.register_check_method()
def no_duplicates(pandas_col):
    """Mark every value of a column that has not appeared earlier in it.

    The data object is accepted as a regular (positional-or-keyword)
    parameter: a registered check receives the data as its first positional
    argument, so a keyword-only ``pandas_col`` could not be called that way.

    Args:
        pandas_col: pandas Series holding the column values to check.
    Returns:
        A boolean pandas Series that is True for the first occurrence of
        each value and False for every repeated occurrence.
    """
    return ~pandas_col.duplicated()

# Tests
@pytest.mark.parametrize(
    "data, expected_result",
    [
        # Every value is unique.
        (
            pd.Series([1, 2, 3, 4, 5]),
            True,
        ),
        # The value 3 appears twice.
        (
            pd.Series([1, 2, 3, 3, 4, 5]),
            False,
        ),
    ],
)
def test_no_duplicates(data, expected_result):
    """Call ``no_duplicates`` directly and compare its aggregated result.

    Args:
        data: pandas Series passed to ``no_duplicates``.
        expected_result: True when every element is expected to pass the
            check, False when at least one element is expected to fail.
    """
    # This uses the module-level name ``no_duplicates``, i.e. whatever object
    # the ``register_check_method`` decorator left bound to that name.
    result = no_duplicates(data)
    # ``.all()`` collapses the per-element result into a single truth value.
    assert result.all() == expected_result

Expected behaviour

$ pytest quality_checks.py 
PASSED quality_checks.py::test_no_duplicates[data0-True]
PASSED quality_checks.py::test_no_duplicates[data1-False]

Actual behaviour

$ pytest quality_checks.py 
==================================================== test session starts ====================================================
platform linux -- Python 3.9.16, pytest-7.4.0, pluggy-1.2.0
rootdir: /home/jupyter/.../unit_tests
plugins: typeguard-4.0.0
collected 2 items                                                                                                           

quality_checks.py FF                                                                                                  [100%]

========================================================= FAILURES ==========================================================
______________________________________________ test_no_duplicates[data0-True] _______________________________________________

data = 0    1
1    2
2    3
3    4
4    5
dtype: int64, expected_result = True

    @pytest.mark.parametrize(
        "data, expected_result",
        [
            (
                pd.Series([1, 2, 3, 4, 5]),
                True,
            ),
            (
                pd.Series([1, 2, 3, 3, 4, 5]),
                False,
            ),
        ],
    )
    def test_no_duplicates(data, expected_result):
        """Test.

        Args:
            data: data
            expected_result: true or false
        """
>       result = no_duplicates(data)
E       TypeError: 'NoneType' object is not callable

quality_checks.py:38: TypeError
______________________________________________ test_no_duplicates[data1-False] ______________________________________________

data = 0    1
1    2
2    3
3    3
4    4
5    5
dtype: int64, expected_result = False

    @pytest.mark.parametrize(
        "data, expected_result",
        [
            (
                pd.Series([1, 2, 3, 4, 5]),
                True,
            ),
            (
                pd.Series([1, 2, 3, 3, 4, 5]),
                False,
            ),
        ],
    )
    def test_no_duplicates(data, expected_result):
        """Test.

        Args:
            data: data
            expected_result: true or false
        """
>       result = no_duplicates(data)
E       TypeError: 'NoneType' object is not callable

quality_checks.py:38: TypeError
================================================== short test summary info ==================================================
FAILED quality_checks.py::test_no_duplicates[data0-True] - TypeError: 'NoneType' object is not callable
FAILED quality_checks.py::test_no_duplicates[data1-False] - TypeError: 'NoneType' object is not callable
===================================================== 2 failed in 1.09s =====================================================

Additional context

The objects I am running the checks on are pa.Column, part of a pa.DataFrameSchema object.

pvizan-artefact commented 1 year ago

I found a workaround, which consists of defining the check logic as a plain function and registering a separate wrapper around it as the pandera check, as follows.

def func_no_duplicates(pandas_col: pd.Series) -> pd.Series:
    """Mark every value of a column that has not appeared earlier in it.

    This is the plain, undecorated implementation, so it can be imported
    and called directly from unit tests.

    Args:
        pandas_col: pandas Series holding the column values to check.
    Returns:
        A boolean pandas Series that is True for the first occurrence of
        each value and False for every repeated occurrence.
    """
    # ``duplicated()`` flags repeats (keeping the first occurrence unflagged);
    # inverting it yields the "is not a duplicate" mask.
    return ~pandas_col.duplicated()

...

@extensions.register_check_method()
def no_duplicates(pandas_col):
    """Registered check that delegates to ``func_no_duplicates``.

    Args:
        pandas_col: column data to check for repeated values.
    Returns:
        The result of ``func_no_duplicates`` for ``pandas_col``.
    """
    uniqueness_mask = func_no_duplicates(pandas_col=pandas_col)
    return uniqueness_mask

...

@pytest.mark.parametrize(
    "data, expected_result",
    [
        (pd.Series([1, 2, 3, 4, 5]), True),
        (pd.Series([1, 2, 3, 3, 4, 5]), False),
    ],
)
def test_no_duplicates(data, expected_result):
    """Exercise the undecorated ``func_no_duplicates`` helper directly.

    Args:
        data: pandas Series passed to the helper.
        expected_result: True when every element should pass, False when
            at least one element should fail.
    """
    mask = func_no_duplicates(data)
    every_value_passes = mask.all()
    assert every_value_passes == expected_result

This seems a bit more cumbersome than I would like. Alternatively, I also found that you can create a pandera Column object within the test and catch the SchemaError that is raised when the data does not pass the check:

@extensions.register_check_method()
def no_duplicates(pandas_col):
    """Mark every value of a column that has not appeared earlier in it.

    Args:
        pandas_col: pandas Series holding the column values to check.
    Returns:
        A boolean pandas Series that is True for the first occurrence of
        each value and False for every repeated occurrence.
    """
    is_repeat = pandas_col.duplicated()
    return ~is_repeat

...

@pytest.mark.parametrize(
    # The argnames must match the test function's parameters exactly;
    # pytest rejects the test at collection time otherwise.
    "data, name_col, expected_result",
    [
        (
            pd.Series([1, 2, 3, 4, 5], name="sample_col"),
            "sample_col",
            True,
        ),
        (
            pd.Series([1, 2, 3, 3, 4, 5], name="sample_col"),
            "sample_col",
            False,
        ),
    ],
)
def test_no_duplicates(data, name_col, expected_result):
    """Validate a one-column frame through the registered check.

    Args:
        data: named pandas Series used as the only column of the frame.
        name_col: name of the column the ``pa.Column`` validates.
        expected_result: True when validation is expected to succeed,
            False when a ``SchemaError`` is expected.
    """
    column = pa.Column(
        dtype=float,
        coerce=True,
        checks=Check.no_duplicates(),
        name=name_col,
    )
    try:
        column.validate(pd.DataFrame(data))
    except pa.errors.SchemaError as exc:
        logging.error(exc.failure_cases)
        # Validation failed: only acceptable for the "has duplicates" case.
        assert not expected_result
    else:
        # Validation passed: only acceptable for the "all unique" case.
        assert expected_result

However, I am not sure the approach above is correct, since it exercises schema validation as a whole rather than the behaviour of the individual check.

JulianFerry commented 8 months ago

This is the solution I used (used a function with an input "statistic" to show how to set that up as well) :

# checks.py

import pandas as pd
from pandera import extensions

@extensions.register_check_method(statistics=["length"], supported_types=pd.DataFrame)
def length_equals(df: pd.DataFrame, *, length: int):
    """Return whether ``df`` has exactly ``length`` rows.

    Args:
        df: DataFrame whose row count is compared.
        length: required number of rows; declared as a check statistic in
            the decorator and passed by keyword.
    Returns:
        True when ``len(df)`` equals ``length``, False otherwise.
    """
    row_count = len(df)
    return row_count == length
# test_checks.py

import pytest
import pandera as pa
import pandas as pd
from checks import *  # required to register checks

@pytest.mark.parametrize(
    "df, length",
    [
        (pd.DataFrame({"foo": [1, 2, 3]}), 3),
        (pd.DataFrame({"foo": [1, 2], "bar": [3, 4]}), 2),
    ],
)
def test_length_equals(df: pd.DataFrame, length: int):
    """Validation succeeds when the frame has exactly ``length`` rows."""
    row_count_check = pa.Check.length_equals(length=length)
    schema = pa.DataFrameSchema(checks=row_count_check)
    # No exception from ``validate`` means the check passed.
    schema.validate(df)

@pytest.mark.parametrize(
    "df, length",
    [
        (pd.DataFrame({"foo": [1, 2, 3]}), 2),
        (pd.DataFrame({"foo": [1, 2], "bar": [3, 4]}), 3),
    ],
)
def test_length_not_equals(df: pd.DataFrame, length: int):
    """Validation raises ``SchemaError`` when the row count differs."""
    row_count_check = pa.Check.length_equals(length=length)
    schema = pa.DataFrameSchema(checks=row_count_check)
    with pytest.raises(pa.errors.SchemaError):
        schema.validate(df)