unionai-oss / pandera

A light-weight, flexible, and expressive statistical data testing library
https://www.union.ai/pandera
MIT License

[BUG]: use `LazyFrame.collect_schema().names()` over `LazyFrame.columns` #1744

Closed: jjfantini closed this issue 3 months ago

jjfantini commented 3 months ago

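Describe the bug

Polars 1.1.0 emits a `PerformanceWarning` whenever `LazyFrame.columns` is accessed. Determining the column names of a LazyFrame requires resolving its schema, which is a potentially expensive operation. pandera's Polars backend reads `check_obj.columns` during validation, so validating a LazyFrame triggers the warning. The recommended replacement is `LazyFrame.collect_schema().names()`, which returns the same column names without the warning.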


Code Sample, a copy-pastable example

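The fix should be applied in pandera's Polars backend modules (`pandera/backends/polars/container.py`, `components.py`, and `base.py`). For example, `collect_column_info` in `container.py` reads `check_obj.columns` in three places: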

    def collect_column_info(self, check_obj: pl.LazyFrame, schema):
        """Collect column metadata for the dataframe."""
        column_names: List[Any] = []
        absent_column_names: List[Any] = []
        regex_match_patterns: List[Any] = []

        for col_name, col_schema in schema.columns.items():
            if (
                not col_schema.regex
                and col_name not in check_obj.columns  # warns on a LazyFrame
                and col_schema.required
            ):
                absent_column_names.append(col_name)

            if col_schema.regex:
                try:
                    column_names.extend(
                        col_schema.get_backend(check_obj).get_regex_columns(
                            col_schema, check_obj
                        )
                    )
                    regex_match_patterns.append(col_schema.selector)
                except SchemaError:
                    pass
            elif col_name in check_obj.columns:  # warns on a LazyFrame
                column_names.append(col_name)

        # drop adjacent duplicated column names
        destuttered_column_names = [*check_obj.columns]  # also warns on a LazyFrame

        return ColumnInfo(
            sorted_column_names=dict.fromkeys(column_names),
            expanded_column_names=frozenset(column_names),
            destuttered_column_names=destuttered_column_names,
            absent_column_names=absent_column_names,
            regex_match_patterns=regex_match_patterns,
        )

jjfantini commented 3 months ago

This makes pytest tests fail when a pandera `DataFrameModel` is used to validate a LazyFrame and the test suite turns warnings into errors:

etf_data_samples = [<LazyFrame at 0x30C0D9BB0>, <LazyFrame at 0x30C27BD40>, <LazyFrame at 0x30C27BBC0>, <LazyFrame at 0x30C27B380>]

    @pytest.mark.asyncio()
    async def test_aget_asset_class_filter(etf_data_samples):
        """
        Test the aget_asset_class_filter function with various ETF data samples.

        This function also inherently tests the `normalize_asset_class` function.
        """
        for sample in etf_data_samples:
>           etf_data = ETFCategoryData(sample)

tests/unittests/portfolio/analytics/user_table/test_helpers.py:115:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
menv/lib/python3.12/site-packages/pandera/api/dataframe/model.py:138: in __new__
    DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs)
menv/lib/python3.12/site-packages/pandera/api/dataframe/model.py:289: in validate
    cls.to_schema().validate(
menv/lib/python3.12/site-packages/pandera/api/polars/container.py:58: in validate
    output = self.get_backend(check_obj).validate(
menv/lib/python3.12/site-packages/pandera/backends/polars/container.py:46: in validate
    column_info = self.collect_column_info(check_obj, schema)
menv/lib/python3.12/site-packages/pandera/backends/polars/container.py:214: in collect_column_info
    and col_name not in check_obj.columns
menv/lib/python3.12/site-packages/polars/lazyframe/frame.py:448: in columns
    issue_warning(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

message = 'Determining the column names of a LazyFrame requires resolving its schema, which is a potentially expensive operation. Use `LazyFrame.collect_schema().names()` to get the column names without this warning.'
category = <class 'polars.exceptions.PerformanceWarning'>, kwargs = {}

    def issue_warning(message: str, category: type[Warning], **kwargs: Any) -> None:
        """
        Issue a warning.

        Parameters
        ----------
        message
            The message associated with the warning.
        category
            The warning category.
        **kwargs
            Additional arguments for `warnings.warn`. Note that the `stacklevel` is
            determined automatically.
        """
>       warnings.warn(
            message=message, category=category, stacklevel=find_stacklevel(), **kwargs
        )
E       polars.exceptions.PerformanceWarning: Determining the column names of a LazyFrame requires resolving its schema, which is a potentially expensive operation. Use `LazyFrame.collect_schema().names()` to get the column names without this warning.

menv/lib/python3.12/site-packages/polars/_utils/various.py:444: PerformanceWarning
------------------- generated xml file: /Users/jjfantini/github/humblFINANCE-org/humblDATA/reports/pytest.xml -------------------
==================================================== short test summary info ====================================================
FAILED tests/unittests/portfolio/analytics/user_table/test_helpers.py::test_aget_asset_class_filter[lazyframe] - polars.exceptions.PerformanceWarning: Determining the column names of a LazyFrame requires resolving its schema, which is a potentially expensive operation. Use `LazyFrame.collect_schema().names()` to get the column names without this warning.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
============================================ 1 failed, 16 passed, 3 skipped in 6.23s ============================================

For anyone wondering: as a workaround, add `filterwarnings = ["ignore::polars.exceptions.PerformanceWarning"]` under `[tool.pytest.ini_options]` in your pyproject.toml.

cosmicBboy commented 3 months ago

thanks for reporting this @jjfantini, #1746 should resolve this

jjfantini commented 3 months ago

Okay, great, thanks! Will wait until the next release :)