pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

Allow nested / chained aggregations #14361

Open DrMaphuse opened 7 months ago

DrMaphuse commented 7 months ago

Description

Similar to #12051, but I want to emphasize the need for nested aggregations in group_by or over.

Sometimes a result can only be obtained by combining aggregations, for example taking the minimum per group of a maximum that was computed over a finer grouping.

Currently, chaining window functions, or using a window function inside a group_by aggregation, raises:

InvalidOperationError: window expression not allowed in aggregation

As a result, some computations cannot be written as a single expression, even though in principle this seems like it should be possible.

It would be nice to write these calculations without chaining multiple with_columns() or group_by() calls, which can get pretty ugly and also makes it harder to build expressions programmatically.

Example:

import polars as pl

# Sample data
data = {
    "A": ["bar", "bar", "bar", "foo", "foo", "foo", "foo", "foo"],
    "B": ["one", "three", "two", "one", "one", "three", "two", "two"],
    "C": [2, 4, 6, 1, 7, 8, 3, 5],
}

# Create DataFrame
df = pl.DataFrame(data)

print(df)
shape: (8, 3)
┌─────┬───────┬─────┐
│ A   ┆ B     ┆ C   │
│ --- ┆ ---   ┆ --- │
│ str ┆ str   ┆ i64 │
╞═════╪═══════╪═════╡
│ bar ┆ one   ┆ 2   │
│ bar ┆ three ┆ 4   │
│ bar ┆ two   ┆ 6   │
│ foo ┆ one   ┆ 1   │
│ foo ┆ one   ┆ 7   │
│ foo ┆ three ┆ 8   │
│ foo ┆ two   ┆ 3   │
│ foo ┆ two   ┆ 5   │
└─────┴───────┴─────┘
# Does not work:
print(df.with_columns(pl.col("C").max().over("A", "B").min().over("A").alias("E")))

InvalidOperationError: window expression not allowed in aggregation

# Works: compute the inner window first, then the outer one in a second with_columns
print(
    df.with_columns(pl.col("C").max().over("A", "B").alias("_")).with_columns(
        pl.col("_").min().over("A").alias("E")
    )
)

shape: (8, 5)
┌─────┬───────┬─────┬─────┬─────┐
│ A   ┆ B     ┆ C   ┆ _   ┆ E   │
│ --- ┆ ---   ┆ --- ┆ --- ┆ --- │
│ str ┆ str   ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═══════╪═════╪═════╪═════╡
│ bar ┆ one   ┆ 2   ┆ 2   ┆ 2   │
│ bar ┆ three ┆ 4   ┆ 4   ┆ 2   │
│ bar ┆ two   ┆ 6   ┆ 6   ┆ 2   │
│ foo ┆ one   ┆ 1   ┆ 7   ┆ 5   │
│ foo ┆ one   ┆ 7   ┆ 7   ┆ 5   │
│ foo ┆ three ┆ 8   ┆ 8   ┆ 5   │
│ foo ┆ two   ┆ 3   ┆ 5   ┆ 5   │
│ foo ┆ two   ┆ 5   ┆ 5   ┆ 5   │
└─────┴───────┴─────┴─────┴─────┘
evbo commented 1 month ago

Hopefully I'm not conflating issues here, but it appears chained aggregations can also silently fail.

In my (admittedly contrived) example below, I would expect end and end_expected to have the same values, but they don't. The only difference is that end_expected is split across two separate with_columns calls. My suspicion is that the optimizer detects a redundancy and drops the backward fill when everything is evaluated as a single expression:

    import polars as pl

    lf = pl.DataFrame(
        {
            "time": [
                1.01,
                1.02,
                1.03,
                1.04,
                1.04,
                1.65,
                1.69,
                1.71,
                1.77,
                1.88,
                1.96,
                1.97,
                1.98,
            ],
            "condition": [
                None,
                "connect",
                None,
                "player_quit",
                "player_quit",
                None,
                None,
                "connect",
                None,
                "player_quit",
                "connect",
                None,
                "player_quit",
            ],
        }
    ).lazy()

    start = pl.when(condition="player_quit").then("time").backward_fill()

    with pl.Config(tbl_rows=-1):
        print(
            lf.with_columns(
                [
                    start.alias("start"),
                    start.shift(-1).alias("shifted"),
                    start.shift(-1).max().over("time").alias("end"),
                ]
            )
            .with_columns(pl.col("shifted").max().over("time").alias("end_expected"))
            .collect()
        )

Result:

┌──────┬─────────────┬───────┬─────────┬──────┬──────────────┐
│ time ┆ condition   ┆ start ┆ shifted ┆ end  ┆ end_expected │
│ ---  ┆ ---         ┆ ---   ┆ ---     ┆ ---  ┆ ---          │
│ f64  ┆ str         ┆ f64   ┆ f64     ┆ f64  ┆ f64          │
╞══════╪═════════════╪═══════╪═════════╪══════╪══════════════╡
│ 1.01 ┆ null        ┆ 1.04  ┆ 1.04    ┆ null ┆ 1.04         │
│ 1.02 ┆ connect     ┆ 1.04  ┆ 1.04    ┆ null ┆ 1.04         │
│ 1.03 ┆ null        ┆ 1.04  ┆ 1.04    ┆ null ┆ 1.04         │
│ 1.04 ┆ player_quit ┆ 1.04  ┆ 1.04    ┆ 1.04 ┆ 1.88         │
│ 1.04 ┆ player_quit ┆ 1.04  ┆ 1.88    ┆ 1.04 ┆ 1.88         │
│ 1.65 ┆ null        ┆ 1.88  ┆ 1.88    ┆ null ┆ 1.88         │
│ 1.69 ┆ null        ┆ 1.88  ┆ 1.88    ┆ null ┆ 1.88         │
│ 1.71 ┆ connect     ┆ 1.88  ┆ 1.88    ┆ null ┆ 1.88         │
│ 1.77 ┆ null        ┆ 1.88  ┆ 1.88    ┆ null ┆ 1.88         │
│ 1.88 ┆ player_quit ┆ 1.88  ┆ 1.98    ┆ null ┆ 1.98         │
│ 1.96 ┆ connect     ┆ 1.98  ┆ 1.98    ┆ null ┆ 1.98         │
│ 1.97 ┆ null        ┆ 1.98  ┆ 1.98    ┆ null ┆ 1.98         │
│ 1.98 ┆ player_quit ┆ 1.98  ┆ null    ┆ null ┆ null         │
└──────┴─────────────┴───────┴─────────┴──────┴──────────────┘