pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

panic calling `collect_schema` on lazy group_by + map_batches #17327

Open wence- opened 5 months ago

wence- commented 5 months ago


Reproducible example

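```python
import polars as pl

df = pl.LazyFrame({"a": [1, 1, 1], "b": [2, 3, 4]})

# No return_dtype is given, so the output dtype of map_batches is not known up front.
q = df.group_by("a").agg(pl.col("b").map_batches(lambda s: s))

print(q.collect_schema())
```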


Issue description

When `map_batches` is called in a grouped context without a `return_dtype`, the dtype of its result is inferred by looking at the first value produced. Until that happens, the dtype is therefore `Unknown`. When we call `collect_schema()`, however, the grouped column's dtype is `List(Unknown)`, and on the Python side `parse_into_dtype` is called via the `List` dtype constructor. By default, `parse_into_dtype` raises `TypeError` for `Unknown` inputs.

On the Rust conversion side, we call `unwrap` on this error result and therefore get a (difficult to catch) `PanicException` from pyo3:

```
thread '<unnamed>' panicked at py-polars/src/conversion/mod.rs:241:39:
called `Result::unwrap()` on an `Err` value: PyErr { type: <class 'TypeError'>, value: TypeError("cannot parse input of type 'Unknown' into Polars data type: Unknown"), traceback: Some(<traceback object at 0x7209e5244640>) }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
```

Expected behavior

No panic; instead, the `TypeError` should be re-raised.

Alternatively, `collect_schema` could return a schema with `List(Unknown)` as the dtype for the column.

Installed versions

```
--------Version info---------
Polars: 1.0.0-rc.2
Index type: UInt32
Platform: Linux-6.5.0-41-generic-x86_64-with-glibc2.35
Python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]

----Optional dependencies----
adbc_driver_manager: 1.0.0
cloudpickle: 3.0.0
connectorx: 0.3.3
deltalake: 0.18.1
fastexcel: 0.10.4
fsspec: 2024.6.0
gevent: 24.2.1
great_tables: 0.9.0
hvplot: 0.10.0
matplotlib: 3.9.0
nest_asyncio: 1.6.0
numpy: 1.26.4
openpyxl: 3.1.4
pandas: 2.2.2
pyarrow: 16.1.0
pydantic: 2.7.4
pyiceberg:
sqlalchemy: 2.0.30
torch: 2.3.1.post300
xlsx2csv: 0.8.2
xlsxwriter: 3.2.0
```
wence- commented 5 months ago

One could support the `List(Unknown)` alternative on the Python side with something like:

```diff
diff --git a/py-polars/polars/datatypes/_parse.py b/py-polars/polars/datatypes/_parse.py
index 55345909c..fcbdd9376 100644
--- a/py-polars/polars/datatypes/_parse.py
+++ b/py-polars/polars/datatypes/_parse.py
@@ -37,7 +37,7 @@ else:  # pragma: no cover
     UnionType = UnionTypeOld

-def parse_into_dtype(input: Any) -> PolarsDataType:
+def parse_into_dtype(input: Any, *, include_unknown: bool = False) -> PolarsDataType:
     """
     Parse an input into a Polars data type.

@@ -46,7 +46,7 @@ def parse_into_dtype(input: Any) -> PolarsDataType:
     TypeError
         If the input cannot be parsed into a Polars data type.
     """
-    if is_polars_dtype(input):
+    if is_polars_dtype(input, include_unknown=include_unknown):
         return input
     elif isinstance(input, ForwardRef):
         return _parse_forward_ref_into_dtype(input)
diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py
index 08aeb53c5..68cc7be18 100644
--- a/py-polars/polars/datatypes/classes.py
+++ b/py-polars/polars/datatypes/classes.py
@@ -604,7 +604,7 @@ class List(NestedType):
     inner: PolarsDataType

     def __init__(self, inner: PolarsDataType | PythonDataType):
-        self.inner = polars.datatypes.parse_into_dtype(inner)
+        self.inner = polars.datatypes.parse_into_dtype(inner, include_unknown=True)

     def __eq__(self, other: PolarsDataType) -> bool:  # type: ignore[override]
         # This equality check allows comparison of type classes and type instances.
@@ -675,7 +675,7 @@ class Array(NestedType):
             msg = "Array constructor is missing the required argument `shape`"
             raise TypeError(msg)

-        inner_parsed = polars.datatypes.parse_into_dtype(inner)
+        inner_parsed = polars.datatypes.parse_into_dtype(inner, include_unknown=True)
         inner_shape = inner_parsed.shape if isinstance(inner_parsed, Array) else ()

         if isinstance(shape, int):
@@ -754,7 +754,7 @@ class Field:

     def __init__(self, name: str, dtype: PolarsDataType):
         self.name = name
-        self.dtype = polars.datatypes.parse_into_dtype(dtype)
+        self.dtype = polars.datatypes.parse_into_dtype(dtype, include_unknown=True)

     def __eq__(self, other: Field) -> bool:  # type: ignore[override]
         return (self.name == other.name) & (self.dtype == other.dtype)
```
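The design of that patch can be illustrated with a self-contained toy model in plain Python. `DataType`, `Int64`, `Unknown`, `List`, `is_dtype` and `parse_into_dtype` below are hypothetical stand-ins, not the real Polars classes or functions: the keyword-only `include_unknown` flag defaults to `False`, so top-level parsing stays strict, while the nested-type constructor opts in so that `List(Unknown)` becomes representable. The error message is modeled on the one in the panic output above.

```python
from typing import Any


class DataType:
    """Toy stand-in for a dtype base class (hypothetical, not Polars' own)."""

    def __repr__(self) -> str:
        return type(self).__name__


class Int64(DataType):
    pass


class Unknown(DataType):
    pass


def is_dtype(obj: Any, *, include_unknown: bool = False) -> bool:
    # Unknown only counts as a valid dtype when the caller explicitly asks for it.
    if not isinstance(obj, DataType):
        return False
    return include_unknown or not isinstance(obj, Unknown)


def parse_into_dtype(obj: Any, *, include_unknown: bool = False) -> DataType:
    # Strict by default: Unknown is rejected unless include_unknown=True.
    if is_dtype(obj, include_unknown=include_unknown):
        return obj
    msg = f"cannot parse input of type {type(obj).__name__!r} into Polars data type: {obj!r}"
    raise TypeError(msg)


class List(DataType):
    def __init__(self, inner: Any) -> None:
        # The nested constructor opts in, mirroring the patched List.__init__.
        self.inner = parse_into_dtype(inner, include_unknown=True)

    def __repr__(self) -> str:
        return f"List({self.inner!r})"
```

Because `include_unknown` defaults to `False`, existing callers of `parse_into_dtype` keep raising `TypeError` for `Unknown`; in the patch, only the `List`, `Array` and `Field` constructors change behavior.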