pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

Keep columns in `.rolling()` #18084

Open MariusMerkleQC opened 1 month ago

MariusMerkleQC commented 1 month ago

Description

Problem

When using pl.DataFrame.rolling(), it is only possible to compute aggregated values, but sometimes I would just like to keep a certain column as-is.

Example

Imagine that I have some fake data in df_fake. I would like to compute a pl.DataFrame that keeps the columns "timestamp" and "key" as they are, but replaces "value" with the cumulative mean up to that point in time. This is not possible with .rolling() alone because there is no operation that simply keeps an element. Using .last(), as shown below, gives wrong results if there are duplicate values in the index_column, because rows with the same timestamp share the same window. The only way I have managed to work around this is by horizontally concatenating part of the original df_fake to the rolled data frame, which doesn't look nice at all.

import polars as pl
from datetime import datetime

df_fake = pl.DataFrame(
    data=[
        (datetime(2023, 1, 1, 0, 0, 0), 1, "a"),
        (datetime(2024, 1, 1, 0, 0, 0), 2, "b"),
        (datetime(2024, 1, 1, 0, 0, 0), 3, "c"),
    ],
    schema={"timestamp": pl.Datetime, "value": pl.Int32, "key": pl.Utf8},
    orient="row",
)

df_desired = pl.DataFrame(
    data=[
        (datetime(2023, 1, 1, 0, 0, 0), 1.0, "a"),
        (datetime(2024, 1, 1, 0, 0, 0), 2.0, "b"),
        (datetime(2024, 1, 1, 0, 0, 0), 2.0, "c"),
    ],
    schema={"timestamp": pl.Datetime, "cumulative_mean": pl.Float64, "key": pl.Utf8},
    orient="row",
)

df_rolled = df_fake.rolling(index_column="timestamp", period="5y").agg(
    pl.col("value").mean().alias("cumulative_mean"),
    pl.col("key").last().alias("key"),
)

df_workaround = pl.concat(
    items=[
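        # .rolling().agg() returns exactly one row per input row, in the original
        # order, so the "key" column can be re-attached horizontally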
        df_fake.rolling(index_column="timestamp", period="5y").agg(
            pl.col("value").mean().alias("cumulative_mean"),
        ),
        df_fake.select(pl.col("key")),
    ],
    how="horizontal",
)

Suggestion

What about introducing an optional argument keep_cols: list[str] that passes the listed columns through .rolling() unchanged, so that they don't get lost in the aggregation? A sketch of the proposed usage is shown below.
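A hypothetical sketch of how this could look (keep_cols does not exist in polars today; it is the parameter proposed here):

df_desired = df_fake.rolling(
    index_column="timestamp",
    period="5y",
    keep_cols=["key"],  # proposed parameter, not part of the current API
).agg(
    pl.col("value").mean().alias("cumulative_mean"),
)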

cmdlineluser commented 1 month ago

There is Expr.rolling, which evaluates the aggregation over each row's window without collapsing the frame:

df_fake.with_columns(
    pl.col("value").mean().rolling(index_column="timestamp", period="5y")
      .alias("cumulative_mean")
)

# shape: (3, 4)
# ┌─────────────────────┬───────┬─────┬─────────────────┐
# │ timestamp           ┆ value ┆ key ┆ cumulative_mean │
# │ ---                 ┆ ---   ┆ --- ┆ ---             │
# │ datetime[μs]        ┆ i32   ┆ str ┆ f64             │
# ╞═════════════════════╪═══════╪═════╪═════════════════╡
# │ 2023-01-01 00:00:00 ┆ 1     ┆ a   ┆ 1.0             │
# │ 2024-01-01 00:00:00 ┆ 2     ┆ b   ┆ 2.0             │
# │ 2024-01-01 00:00:00 ┆ 3     ┆ c   ┆ 2.0             │
# └─────────────────────┴───────┴─────┴─────────────────┘

And there are dedicated rolling aggregations such as Expr.rolling_mean_by:

df_fake.with_columns(
    pl.col("value").rolling_mean_by("timestamp", window_size="5y")
      .alias("cumulative_mean")
)

# shape: (3, 4)
# ┌─────────────────────┬───────┬─────┬─────────────────┐
# │ timestamp           ┆ value ┆ key ┆ cumulative_mean │
# │ ---                 ┆ ---   ┆ --- ┆ ---             │
# │ datetime[μs]        ┆ i32   ┆ str ┆ f64             │
# ╞═════════════════════╪═══════╪═════╪═════════════════╡
# │ 2023-01-01 00:00:00 ┆ 1     ┆ a   ┆ 1.0             │
# │ 2024-01-01 00:00:00 ┆ 2     ┆ b   ┆ 2.0             │
# │ 2024-01-01 00:00:00 ┆ 3     ┆ c   ┆ 2.0             │
# └─────────────────────┴───────┴─────┴─────────────────┘
MariusMerkleQC commented 1 month ago

Yes, but I'm shying away from using these as they are considered unstable...

(screenshot of the documentation warning that this functionality is considered unstable)