pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

0.20.6->1.0.0 slowdown of rolling() with complex aggregations #17423

Closed: 1112114641 closed this issue 1 month ago

1112114641 commented 2 months ago


Reproducible example

import yfinance as yf
import polars as pl
from datetime import datetime, timedelta  # timedelta is needed for the elapsed-time print at the end

s = pl.DataFrame(
  yf.download("AAPL ^GSPC ^GDAXI ^N225 ^AXJO ES=F", start="2022-07-06", interval="1h", end="2024-07-04").reset_index()
).with_columns(indx=pl.int_range(pl.len()))

def ols_reg(x: pl.Expr, y: pl.Expr, pred_dist: pl.Expr) -> pl.Expr:
  """Calculate linear regression a * x + b
  for large datasets, normalise x/y to [0,1]
  """
  n = y.count()
  x_sum = x.sum()
  y_sum = y.sum()
  x_s_sq = (x**2).sum()
  xy_sum = (x * y).sum()
  slope = (n * xy_sum - x_sum * y_sum + pl.lit(1e-10)) / (n * x_s_sq - x_sum**2 + pl.lit(1e-10))
  offset = (y_sum - slope * x_sum) / n
  return slope * (x.last() + pred_dist) + offset

def quad_reg(x: pl.Expr, y: pl.Expr, pred_dist: pl.Expr) -> pl.Expr:
  """Calculate quadratic regression a * x^2 + b * x + c
  for large datasets >1k rows, normalise x/y to [0,1]
  """
  n = y.count()
  c_11 = (x**2).sum() - x.mean() ** 2 * n
  c_12 = (x**3).sum() - (x.mean() * (x**2).mean()) * n
  c_22 = (x**4).sum() - ((x**2).mean() ** 2) * n
  c_y1 = (x * y).sum() - (x.mean() * y.mean()) * n
  c_y2 = (x**2 * y).sum() - (x**2).mean() * (y.mean()) * n
  a = (c_y2 * c_11 - c_y1 * c_12 + pl.lit(1e-10)) / (c_22 * c_11 - c_12.mul(c_12) + pl.lit(1e-10))
  b = (c_y1 * c_22 - c_y2 * c_12 + pl.lit(1e-10)) / (c_22 * c_11 - c_12.mul(c_12) + pl.lit(1e-10))
  c = y.mean() - (b * x.mean()) - (a * (x**2).mean())
  return a * (x.last() + pred_dist) ** 2 + b * (x.last() + pred_dist) + c

print(start:= datetime.now())
s.sort("('Datetime', '')").rolling(index_column="('Datetime', '')", period="2h").agg(
  [
    pl.exclude("('Datetime', '')").last(),
    pl.col("indx").pow(2).alias("1"),
    pl.col("indx").pow(3).alias("2"),
    pl.col("('Close', 'ES=F')").pow(2).alias("3"),
    pl.col("('Close', 'ES=F')").pow(3).alias("4"),
    pl.col("indx").sqrt().alias("5"),
    ols_reg(pl.col("indx"), pl.col("('Close', 'ES=F')"), pl.lit(1)).alias("hourly_lin_pred"),
    quad_reg(pl.col("indx"), pl.col("('Close', 'ES=F')"), pl.lit(1)).alias("hourly_quad_pred"),
    pl.when(pl.len() >= 15).then(pl.lit(True)).otherwise(pl.lit(False)).alias("hourly_mask"),
  ]
)
print(end:=datetime.now(), f"{(end-start)/timedelta(seconds=1):.2f}s")
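To sanity-check the closed-form expressions in ols_reg and quad_reg independently of Polars, the same formulas can be evaluated on a single window in plain Python. This is a minimal sketch, not part of the original report: ols_pred and quad_pred are hypothetical names, and it assumes a window without nulls (Polars' count() skips nulls, whereas len() here does not). quad_reg's a and b are the Cramer's-rule solution of the 2x2 centred normal equations, which the sketch reproduces term by term, including the 1e-10 epsilon.

```python
from typing import Sequence


def _mean(v: Sequence[float]) -> float:
    return sum(v) / len(v)


def ols_pred(x, y, pred_dist, eps=1e-10):
    """Plain-Python mirror of ols_reg for one null-free window (same closed form and epsilon)."""
    n = len(y)
    x_sum, y_sum = sum(x), sum(y)
    x_sq_sum = sum(xi * xi for xi in x)
    xy_sum = sum(xi * yi for xi, yi in zip(x, y))
    slope = (n * xy_sum - x_sum * y_sum + eps) / (n * x_sq_sum - x_sum**2 + eps)
    offset = (y_sum - slope * x_sum) / n
    # Predict pred_dist steps past the last x in the window.
    return slope * (x[-1] + pred_dist) + offset


def quad_pred(x, y, pred_dist, eps=1e-10):
    """Plain-Python mirror of quad_reg: centred sums, then Cramer's rule for a and b."""
    n = len(y)
    x2 = [xi * xi for xi in x]
    x_m, x2_m, y_m = _mean(x), _mean(x2), _mean(y)
    c_11 = sum(x2) - x_m**2 * n
    c_12 = sum(xi**3 for xi in x) - x_m * x2_m * n
    c_22 = sum(xi**4 for xi in x) - x2_m**2 * n
    c_y1 = sum(xi * yi for xi, yi in zip(x, y)) - x_m * y_m * n
    c_y2 = sum(x2i * yi for x2i, yi in zip(x2, y)) - x2_m * y_m * n
    det = c_22 * c_11 - c_12 * c_12 + eps
    a = (c_y2 * c_11 - c_y1 * c_12 + eps) / det
    b = (c_y1 * c_22 - c_y2 * c_12 + eps) / det
    c = y_m - b * x_m - a * x2_m
    t = x[-1] + pred_dist
    return a * t**2 + b * t + c
```

On exact data both recover the generating curve (up to the epsilon): points from y = 2x + 1 at x = 0..3 give roughly 9 at x = 4, and points from y = x^2 - x + 3 at x = 0..4 give roughly 23 at x = 5.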

Log output

No response

Issue description

Running the above code a couple of times, I get the following results:


0.20.6/0.20.31:
141 ms ± 27.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (without the exclude line)
142 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) (with the exclude line)

1.0.0:
309 ms ± 57.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (without the exclude line)
535 ms ± 38.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (with the exclude line)

The gap between 0.20.x and 1.0.0 grows with how compute-intensive the aggregations are; possibly this will help narrow down the source of the slowdown.
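The 0.20.x runs above report "10 loops each" for one case and "1 loop each" for the others, because IPython's %timeit chooses the loop count automatically. To compare versions under identical settings, a small harness on top of the standard library's timeit.repeat can pin the number of runs and loops. A minimal sketch; time_call and report are hypothetical helper names, not Polars APIs.

```python
import statistics
import timeit
from typing import Callable, Tuple


def time_call(fn: Callable[[], object], repeat: int = 7, number: int = 1) -> Tuple[float, float]:
    """Return (mean, std) of the per-loop wall time in seconds."""
    # timeit.repeat returns one total per run, each covering `number` calls of fn.
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    per_loop = [total / number for total in totals]
    mean = statistics.mean(per_loop)
    # statistics.stdev needs at least two samples.
    std = statistics.stdev(per_loop) if len(per_loop) > 1 else 0.0
    return mean, std


def report(label: str, fn: Callable[[], object], **kwargs) -> Tuple[float, float]:
    """Print a %timeit-style line so results from different versions line up."""
    mean, std = time_call(fn, **kwargs)
    runs = kwargs.get("repeat", 7)
    loops = kwargs.get("number", 1)
    print(f"{label}: {mean * 1e3:.0f} ms ± {std * 1e3:.1f} ms per loop "
          f"(mean ± std. dev. of {runs} runs, {loops} loop{'s' if loops != 1 else ''} each)")
    return mean, std
```

Run it once per installed Polars version, e.g. report(pl.__version__, lambda: query(s)), where query is a hypothetical function wrapping the rolling().agg() call above.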

Expected behavior

Comparable run times on 0.20.x and 1.0.0.

Installed versions

```
--------Version info---------
Polars: 0.20.6 / 0.20.31 / 1.0.0
Index type: UInt32
Platform: macOS-14.5-arm64-arm-64bit
Python: 3.10.6 (main, Aug 7 2023, 13:38:39) [Clang 14.0.3 (clang-1403.0.22.14.1)]

----Optional dependencies----
adbc_driver_manager:
cloudpickle:
connectorx:
deltalake:
fastexcel:
fsspec:
gevent:
great_tables:
hvplot:
matplotlib:
nest_asyncio: 1.6.0
numpy: 2.0.0
openpyxl:
pandas: 2.2.2
pyarrow: 16.1.0
pydantic:
pyiceberg:
sqlalchemy: 2.0.31
torch:
xlsx2csv: 0.8.2
xlsxwriter:
```
ritchie46 commented 2 months ago

This is not really actionable. Can you create a minimal example that shows your case on a single operation?

1112114641 commented 2 months ago

Smallest toy example I could get to show the slowdown:

from datetime import datetime

import polars as pl

a = pl.datetime_range(
  datetime(2020, 1, 1, 1, 1),
  datetime(2024, 7, 5, 3, 1),
  interval="1s",
  eager=True,
)

banana = (
  pl.DataFrame({"dates": a, "idx": range(len(a)), "vals": range(len(a))})
  .with_columns(
    mask=pl.col("dates").dt.weekday().gt(5).or_(pl.col("dates").dt.hour().gt(20)), vals=pl.col("vals").cast(pl.Float64)
  )
  .with_columns(
    pl.when(pl.col("mask")).then(pl.lit(None)).otherwise(pl.col("idx")).alias("idx"),
    pl.when(pl.col("mask")).then(pl.col("vals")).otherwise(pl.lit(None)).alias("vals"),
  )
)

banana.rolling(index_column="dates", period="2h").agg(pl.exclude("dates").last())
# 4.54 s ± 46.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) - 0.20.6

# 5.28 s ± 73.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) - 1.0.0

Compared with my original dataset and pipeline (~2M rows × 150 columns, with rolling operations over three periods (2h, 2d, 2w) and a difference of ~2.5 minutes), the difference here is very small, but it is reproducible.
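For reference, with its defaults (closed="right", offset=-period), rolling() gives each row i the window (t_i - period, t_i]. The sketch below is a plain-Python illustration of that window definition, not Polars' implementation: it computes the windows with two monotone pointers over a sorted integer time column (e.g. seconds), so it runs in O(n). rolling_windows and rolling_last are hypothetical names. With unique timestamps, as in this toy example, the last row of each window is the current row, so pl.exclude("dates").last() reproduces the input values and the benchmark mainly exercises the windowing machinery rather than any arithmetic.

```python
from typing import Iterator, List, Sequence, Tuple


def rolling_windows(ts: Sequence[int], period: int) -> Iterator[Tuple[int, int]]:
    """For each row i of a sorted time column, yield (start, end) such that rows
    start..end-1 are exactly those with ts[i] - period < ts[j] <= ts[i]."""
    if period <= 0:
        raise ValueError("period must be positive")
    n = len(ts)
    start = end = 0
    for i in range(n):
        # Right edge is closed: take every row whose timestamp is <= ts[i]
        # (this includes later rows that share the timestamp ts[i]).
        while end < n and ts[end] <= ts[i]:
            end += 1
        # Left edge is open: drop rows at or before ts[i] - period.
        while ts[start] <= ts[i] - period:
            start += 1
        yield start, end


def rolling_last(values: Sequence, ts: Sequence[int], period: int) -> List:
    """Last value of each window, mirroring agg(pl.col(...).last()) under the definition above."""
    return [values[end - 1] for _, end in rolling_windows(ts, period)]
```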

ritchie46 commented 2 months ago

@stinodego I am not convinced this is a regression. It might be that we do something more correct now, or it might be due to a rustc update. I want to pin this down to a single operation or commit to confirm.

It would help if someone could bisect this.
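A bisect like this can be automated with git bisect run, which treats exit code 0 as good, 125 as "skip this commit", and any other code from 1 to 127 as bad. The script below is a hypothetical helper (bisect_check.py, classify and time_script are illustrative names): it times a reproducer script in a fresh interpreter and compares the best of a few runs against a threshold. It assumes Polars has already been rebuilt for the checked-out commit by a separate step, since the build command depends on the repository's tooling at that commit. Because the measured time includes interpreter start-up and data construction, the threshold should be calibrated on one known-good and one known-bad build.

```python
"""bisect_check.py: hypothetical `git bisect run` helper that flags slow commits."""
import subprocess
import sys
import time
from typing import List, Optional

GOOD, BAD, SKIP = 0, 1, 125  # exit codes interpreted by `git bisect run`


def classify(elapsed: float, threshold: float) -> int:
    """Map a wall time in seconds to a bisect verdict: fast is good, slow is bad."""
    return GOOD if elapsed < threshold else BAD


def time_script(path: str, runs: int = 3) -> Optional[float]:
    """Run `path` in a fresh interpreter `runs` times; return the best wall time, or None on failure."""
    best = None
    for _ in range(runs):
        start = time.perf_counter()
        result = subprocess.run([sys.executable, path])
        elapsed = time.perf_counter() - start
        if result.returncode != 0:
            return None  # crash or import error; main() maps this to SKIP
        best = elapsed if best is None else min(best, elapsed)
    return best


def main(argv: List[str]) -> int:
    # Usage: python bisect_check.py <reproducer.py> <threshold_seconds>
    path, threshold = argv[1], float(argv[2])
    elapsed = time_script(path)
    if elapsed is None:
        return SKIP
    print(f"{elapsed:.2f}s (threshold {threshold:.2f}s)")
    return classify(elapsed, threshold)


# Only act when invoked with both arguments, so importing the module has no side effects.
if __name__ == "__main__" and len(sys.argv) == 3:
    sys.exit(main(sys.argv))
```

It would typically be called from a wrapper that rebuilds first, e.g. git bisect run sh -c '<build step> || exit 125; python bisect_check.py repro.py 5.0', where <build step> and the 5.0 s threshold are placeholders.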

1112114641 commented 2 months ago

Interesting: I just finished a git bisect, and it seems the slowdowns in the first and second examples (relative to 0.20.6) come from different changes. The second example shows the slow behaviour starting after 0.20.22rc1:

f005b98c579c4a9d386518a6db25fc26d0b204ac is the first bad commit
commit f005b98c579c4a9d386518a6db25fc26d0b204ac
Author: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Date:   Sat Apr 20 20:41:08 2024 +0200

    build(rust): bump rustls from 0.21.10 to 0.21.11 (#15792)

    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

 Cargo.lock | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

which corresponds to the bump of windows-targets, a dependency of libloading, from 0.48.5 to 0.52.4.

For the first example, the bisect finds:

42ba1b02730da9e83c413c2ec0d86f703b4e98cc is the first bad commit
commit 42ba1b02730da9e83c413c2ec0d86f703b4e98cc
Author: Marc Garcia <garcia.marc@gmail.com>
Date:   Mon Jun 24 14:54:34 2024 +0400

   test(rust): Add a test for AnonymousScan options (projection and slice pushdown) (#17149)

crates/polars-lazy/src/tests/io.rs | 39 ++++++++++++++++++++++++++++++++++++++
1 file changed, 39 insertions(+)

This points to fn scan_anonymous_fn_with_options() and is likely a dud, since the commit only adds a test. I would still love to know what causes the explosion in wall time I am seeing for that first example ¯\_(ツ)_/¯

Does this help with further analysis?