pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

`group_by(explode())` produces different output than `explode().group_by()` #15984

Open lightningboltemoji opened 6 months ago

lightningboltemoji commented 6 months ago

Reproducible example

import polars as pl

df = pl.DataFrame({"a":[1], "b":[[2,3,4]], "c":[5]})
print(f"{df}\n")
print(f"{df.group_by('a', pl.col('b').explode()).sum()}\n")
print(f"{df.explode('b').group_by('a', 'b').sum()}")

Log output

% POLARS_VERBOSE=1 python3 report.py
shape: (1, 3)
┌─────┬───────────┬─────┐
│ a   ┆ b         ┆ c   │
│ --- ┆ ---       ┆ --- │
│ i64 ┆ list[i64] ┆ i64 │
╞═════╪═══════════╪═════╡
│ 1   ┆ [2, 3, 4] ┆ 5   │
└─────┴───────────┴─────┘

keys/aggregates are not partitionable: running default HASH AGGREGATION
shape: (1, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 2   ┆ 5   │
└─────┴─────┴─────┘

DATAFRAME < 1000 rows: running default HASH AGGREGATION
shape: (3, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 3   ┆ 5   │
│ 1   ┆ 4   ┆ 5   │
│ 1   ┆ 2   ┆ 5   │
└─────┴─────┴─────┘

Issue description

When using `explode` inside a `group_by`, I expected the same result as exploding first and then grouping. Instead, it looks like only the first element of `b` (2) was kept and the other values (3, 4) were dropped.

I noticed that my output shrank while refactoring and traced it back to this. I have no particular reason to use one form over the other. Sorry in advance if I'm thinking about this wrong.

Expected behavior

Both forms should produce the bottom DataFrame in the log output (three rows, one per element of `b`).

Installed versions

```
--------Version info---------
Polars: 0.20.23
Index type: UInt32
Platform: macOS-14.4.1-arm64-arm-64bit
Python: 3.12.3 (main, Apr 9 2024, 08:09:14) [Clang 15.0.0 (clang-1500.3.9.4)]

----Optional dependencies----
adbc_driver_manager:
cloudpickle:
connectorx:
deltalake:
fastexcel:
fsspec:
gevent:
hvplot:
matplotlib:
nest_asyncio:
numpy:
openpyxl:
pandas:
pyarrow:
pydantic:
pyiceberg:
pyxlsb:
sqlalchemy:
xlsx2csv:
xlsxwriter:
```

cmdlineluser commented 6 months ago

Should this raise a ShapeError?

If the explode is the only group key, it raises:

df.group_by(pl.col("b").explode()).all()
# ShapeError: series used as keys should have the same length as the DataFrame
deanm0000 commented 6 months ago

> Should this raise a ShapeError?

I think so.