dask / dask-expr

BSD 3-Clause "New" or "Revised" License

DataFrame subclass lost in `groupby.agg` with `split_out` set. #1024

Open TomAugspurger opened 2 months ago

TomAugspurger commented 2 months ago

Describe the issue:

As part of https://github.com/geopandas/dask-geopandas/pull/285, we found that dask-expr loses the type of a pandas DataFrame subclass in `groupby.agg` if (and only if?) the `split_out` parameter is used.

Minimal Complete Verifiable Example:

Given this file:

```python
# file: test.py
import dask
import dask.dataframe.backends
import dask_expr as dx
import pandas as pd
from dask.dataframe.core import get_parallel_type
from dask.dataframe.dispatch import make_meta_dispatch, meta_nonempty

dask.config.set(scheduler="single-threaded")


class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        return MyDataFrame


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame

    @property
    def _constructor_sliced(self):
        return MySeries


class MyIndex(pd.Index):
    ...


class MyDaskSeries(dx.Series):
    _partition_type = MySeries


class MyDaskDataFrame(dx.DataFrame):
    _partition_type = MyDataFrame


class MyDaskIndex(dx.Index):
    _partition_type = MyIndex


# Unclear if any of get_parallel_type and make_meta_dispatch are needed.
# Reproduces with or without them.
@get_parallel_type.register(MyDataFrame)
def get_parallel_type_dataframe(df):
    # Map the pandas subclass to its dask collection, like the Series/Index cases.
    return MyDaskDataFrame


@get_parallel_type.register(MySeries)
def get_parallel_type_series(s):
    return MyDaskSeries


@get_parallel_type.register(MyIndex)
def get_parallel_type_index(ind):
    return MyDaskIndex


@make_meta_dispatch.register(MyDataFrame)
def make_meta_dataframe(df, index=None):
    return df.head(0)


@make_meta_dispatch.register(MySeries)
def make_meta_series(s, index=None):
    return s.head(0)


@make_meta_dispatch.register(MyIndex)
def make_meta_index(ind, index=None):
    return ind[:0]


# Build the non-empty meta with dask's pandas implementation, re-wrapped as the subclass.
@meta_nonempty.register(MyDataFrame)
def make_meta_nonempty_dataframe(x):
    return MyDataFrame(dask.dataframe.backends.meta_nonempty_dataframe(x))


df = dx.from_dict(
    {"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]}, npartitions=4, constructor=MyDataFrame
)
a = df.groupby("a").agg("first")
b = df.groupby("a").agg("first", split_out=2)
print("split-out=None", type(a.compute()))
print("split-out=2   ", type(b.compute()))
```

Running that produces:

```
$ python test.py
split-out=None <class '__main__.MyDataFrame'>
split-out=2    <class 'pandas.core.frame.DataFrame'>
```

I would expect the type there to be `__main__.MyDataFrame` regardless of `split_out`.

Anything else we need to know?:

Environment:

dask               2024.4.1
dask-expr          1.0.11

Edit: I made one addition to the script: a `@meta_nonempty.register(MyDataFrame)` registration. Without it, I noticed that the inputs to `DecomposableGroupbyAggregation.combine` and `DecomposableGroupbyAggregation.aggregate` were regular pandas DataFrames instead of the subclass.

Registering that `meta_nonempty` does keep them as `MyDataFrame` initially. I added print statements to those methods to show `type(inputs[0])` and `type(_concat(inputs))`, and got:

```
combine <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
split-out=2    <class 'pandas.core.frame.DataFrame'>
```

So initially we're OK, but by the time we do the final aggregate we've lost the subclass.
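
A minimal, pandas-only sketch of why that printout points upstream of `aggregate`: `pd.concat` builds its result from the first input's constructor, so concatenating `MyDataFrame` pieces returns a `MyDataFrame`, while concatenating plain pieces returns a plain `DataFrame`. Using `pd.concat` as a stand-in for dask's internal `_concat` is an assumption here; the subclass definitions mirror the reproducer above.

```python
import pandas as pd


class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        return MyDataFrame


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame

    @property
    def _constructor_sliced(self):
        return MySeries


# Pieces that still carry the subclass, like the early combine/aggregate inputs.
sub_pieces = [MyDataFrame({"a": [1], "b": [1]}), MyDataFrame({"a": [2], "b": [3]})]
# The same pieces after being rebuilt as plain pandas objects.
plain_pieces = [pd.DataFrame(p) for p in sub_pieces]

sub_result = pd.concat(sub_pieces)
plain_result = pd.concat(plain_pieces)

print(type(sub_result).__name__)    # MyDataFrame
print(type(plain_result).__name__)  # DataFrame
```

Since concatenation only propagates whatever type its inputs already have, the later `aggregate` calls must be receiving partitions that lost the subclass earlier in the pipeline.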

phofl commented 2 months ago

This is a shuffle issue (and, if I'm not mistaken, also present in the current implementation).

`df.shuffle("a")` will lose your type; that's what we do under the hood if `split_out != 1`. `shuffle_method="tasks"` keeps it, while `"disk"` and `"p2p"` lose it.
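
The tasks vs. disk/p2p difference can be illustrated with a toy, pandas-only model: splitting a partition in memory by row selection keeps the subclass (pandas indexing goes through `_constructor`), whereas a path that serializes the columns and rebuilds the frame with the plain `pd.DataFrame` constructor does not. The functions `split_in_memory` and `split_via_serialization` are hypothetical and are not dask's shuffle code; the dict round trip is only an assumed stand-in for whatever format the real disk/p2p methods use.

```python
import pandas as pd


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame


def split_in_memory(part, column, npartitions):
    """Hypothetical tasks-style split: select rows from the original object."""
    # hash_pandas_object returns a uint64 Series; the modulus picks an output per row.
    target = pd.util.hash_pandas_object(part[column], index=False) % npartitions
    return [part[target.to_numpy() == i] for i in range(npartitions)]


def split_via_serialization(part, column, npartitions):
    """Hypothetical disk/p2p-style split: pieces are rebuilt from raw columns."""
    pieces = split_in_memory(part, column, npartitions)
    # A dict of column lists stands in for the serialized form; the receiving side
    # rebuilds each piece with the plain pandas constructor.
    return [pd.DataFrame(p.to_dict(orient="list"), index=p.index) for p in pieces]


df = MyDataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
tasks_like = split_in_memory(df, "a", 2)
disk_like = split_via_serialization(df, "a", 2)
print({type(p).__name__ for p in tasks_like})  # {'MyDataFrame'}
print({type(p).__name__ for p in disk_like})   # {'DataFrame'}
```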

I can patch that so that your resulting DataFrame has the correct type, but I don't know whether we can guarantee keeping whatever you add to the subclass through shuffles unless you override the shuffle-specific methods.
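
One way the type could be restored after a shuffle, sketched under assumptions: re-wrap each output partition with the type of the collection's `meta`, then copy `_metadata` attributes from `meta` via pandas' `__finalize__`. The helper `restore_partition_type` and the `unit` attribute are hypothetical, and this is not necessarily how such a patch would be written. It also shows the limit mentioned above: only state that pandas knows how to propagate (`_metadata`, `attrs`) comes back, and only as whatever values `meta` happens to hold.

```python
import pandas as pd


class MyDataFrame(pd.DataFrame):
    # Attribute names listed in _metadata are propagated by pandas' __finalize__.
    _metadata = ["unit"]
    unit = None

    @property
    def _constructor(self):
        return MyDataFrame


def restore_partition_type(part, meta):
    """Hypothetical helper: re-wrap a shuffled partition using meta's type."""
    if type(part) is type(meta):
        return part
    # Re-wrap with the subclass, then copy _metadata attributes (and attrs) from meta.
    # Any subclass state not described by _metadata is still lost.
    return type(meta)(part).__finalize__(meta)


meta = MyDataFrame({"a": pd.Series([], dtype="int64")})
meta.unit = "m"

# A partition that came back from a shuffle as a plain pandas DataFrame.
shuffled = pd.DataFrame({"a": [1, 2]})
restored = restore_partition_type(shuffled, meta)
print(type(restored).__name__, restored.unit)  # MyDataFrame m
```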