abstractqqq / polars_ds_extension

Polars extension for general data science use cases
MIT License

`np.linalg.lstsq` is 7x faster than `num.lstsq` #99

Closed wukan1986 closed 3 months ago

wukan1986 commented 3 months ago

np.linalg.lstsq is 7x faster than num.lstsq in the benchmark below:

from typing import List

import time
import numpy as np
import polars as pl
import polars_ds as pld  # noqa

df = pl.DataFrame({
    "A": pl.int_range(10000, eager=True),
    "B": pl.int_range(10000, eager=True)+1,
}).with_row_index()

df = df.with_columns(df.to_dummies('B'))
df = df.with_columns(pl.col('A').rolling_mean(2).alias('C'))
print(df)

def residual_multiple(cols: List[pl.Series], add_constant: bool) -> pl.Series:
    """Regress the first column (y) on the remaining columns and return the residuals."""
    cols = [list(c.struct) if isinstance(c.dtype, pl.Struct) else [c] for c in cols]
    cols = [i.to_numpy() for p in cols for i in p]
    if add_constant:
        cols += [np.ones_like(cols[0])]
    yx = np.vstack(cols).T

    # drop rows that contain NaN before fitting
    mask = np.any(np.isnan(yx), axis=1)
    yx_ = yx[~mask, :]

    y = yx_[:, 0]
    x = yx_[:, 1:]
    coef = np.linalg.lstsq(x, y, rcond=None)[0]
    y_hat = np.sum(x * coef, axis=1)
    residual = y - y_hat

    # write residuals back at their original positions; dropped rows stay NaN
    out = np.empty_like(yx[:, 0])
    out[~mask] = residual
    out[mask] = np.nan
    return pl.Series(out, nan_to_null=True)

def cs_neutralize_residual_multiple(y: pl.Expr, *more_x: pl.Expr, add_constant: bool = False) -> pl.Expr:
    return pl.map_batches([y, *more_x], lambda xx: residual_multiple(xx, add_constant))

x = df.with_columns([
    cs_neutralize_residual_multiple(pl.col('A'), pl.col('C')).alias('resid1'),
    pl.col('A').num.lstsq(pl.col('C'), return_pred=True, skip_null=True).struct.field('resid').alias('resid2'),
])
print(x)
"""
shape: (10_000, 10_006)
┌───────┬──────┬───────┬─────┬───┬────────┬────────┬───────────┬───────────┐
│ index ┆ A    ┆ B     ┆ B_1 ┆ … ┆ B_9999 ┆ C      ┆ resid1    ┆ resid2    │
│ ---   ┆ ---  ┆ ---   ┆ --- ┆   ┆ ---    ┆ ---    ┆ ---       ┆ ---       │
│ u32   ┆ i64  ┆ i64   ┆ u8  ┆   ┆ u8     ┆ f64    ┆ f64       ┆ f64       │
╞═══════╪══════╪═══════╪═════╪═══╪════════╪════════╪═══════════╪═══════════╡
│ 0     ┆ 0    ┆ 1     ┆ 1   ┆ … ┆ 0      ┆ null   ┆ null      ┆ NaN       │
│ 1     ┆ 1    ┆ 2     ┆ 0   ┆ … ┆ 0      ┆ 0.5    ┆ 0.499962  ┆ 0.499962  │
│ 2     ┆ 2    ┆ 3     ┆ 0   ┆ … ┆ 0      ┆ 1.5    ┆ 0.499887  ┆ 0.499887  │
│ 3     ┆ 3    ┆ 4     ┆ 0   ┆ … ┆ 0      ┆ 2.5    ┆ 0.499812  ┆ 0.499812  │
│ 4     ┆ 4    ┆ 5     ┆ 0   ┆ … ┆ 0      ┆ 3.5    ┆ 0.499737  ┆ 0.499737  │
│ …     ┆ …    ┆ …     ┆ …   ┆ … ┆ …      ┆ …      ┆ …         ┆ …         │
│ 9995  ┆ 9995 ┆ 9996  ┆ 0   ┆ … ┆ 0      ┆ 9994.5 ┆ -0.249662 ┆ -0.249662 │
│ 9996  ┆ 9996 ┆ 9997  ┆ 0   ┆ … ┆ 0      ┆ 9995.5 ┆ -0.249737 ┆ -0.249737 │
│ 9997  ┆ 9997 ┆ 9998  ┆ 0   ┆ … ┆ 0      ┆ 9996.5 ┆ -0.249812 ┆ -0.249812 │
│ 9998  ┆ 9998 ┆ 9999  ┆ 0   ┆ … ┆ 1      ┆ 9997.5 ┆ -0.249887 ┆ -0.249887 │
│ 9999  ┆ 9999 ┆ 10000 ┆ 0   ┆ … ┆ 0      ┆ 9998.5 ┆ -0.249962 ┆ -0.249962 │
└───────┴──────┴───────┴─────┴───┴────────┴────────┴───────────┴───────────┘
"""
t0 = time.perf_counter()
for i in range(100):
    df.select(cs_neutralize_residual_multiple(pl.col('A'), pl.col('C')).alias('resid1'))
t1 = time.perf_counter()
print(t1-t0, 'np.linalg.lstsq')

t0 = time.perf_counter()
for i in range(100):
    df.select(pl.col('A').num.lstsq(pl.col('C'), return_pred=True, skip_null=True).struct.field('resid').alias('resid2'))
t1 = time.perf_counter()
print(t1-t0, 'num.lstsq')

"""
0.08975454200117383 np.linalg.lstsq
0.5801689060026547 num.lstsq
"""
abstractqqq commented 3 months ago

I changed the algorithm, and it now runs at about the same speed as NumPy.

[image]

I think we can speed it up further for the case of a single X column, though. I will keep you updated.

abstractqqq commented 3 months ago

It turns out that short-cutting the single-variable case is not worth it, especially when nulls are present. Anyway, lstsq speed is now on par with NumPy, and even more stable on my machine.
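
The thread does not show the shortcut that was tried. For reference, a sketch of what a single-predictor shortcut could look like in NumPy: with one X column and no intercept, the least-squares coefficient has the closed form beta = sum(x*y) / sum(x*x). `single_x_residual` is a hypothetical name, and this is not the polars_ds implementation. Note that the null mask still has to be built before the closed form can be applied.

```python
import numpy as np


def single_x_residual(y, x):
    """Residuals of y regressed on one x column (no intercept), skipping NaN rows."""
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    mask = np.isnan(x) | np.isnan(y)         # rows to skip
    xv, yv = x[~mask], y[~mask]
    beta = np.dot(xv, yv) / np.dot(xv, xv)   # closed-form coefficient
    out = np.full_like(y, np.nan)            # skipped rows stay NaN
    out[~mask] = yv - beta * xv
    return beta, out


y = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
x = np.array([np.nan, 0.5, 1.5, 2.5, 3.5])
beta, resid = single_x_residual(y, x)

# Cross-check against the general solver on the non-NaN rows.
ref = np.linalg.lstsq(x[1:, None], y[1:], rcond=None)[0][0]
print(beta, ref)
```
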

FYI, I also switched to fat LTO, which may improve performance slightly but makes compilation take much longer. If you are building locally, be aware of this.

https://github.com/abstractqqq/polars_ds_extension/pull/100

abstractqqq commented 3 months ago

100k rows, 10 predictive variables

[image]
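
A problem of the size mentioned above (100k rows, 10 predictors) can be reproduced on the NumPy side with the sketch below. The synthetic data, seed, and noise level are assumptions; the data actually used in the thread is not shown. To compare against polars_ds, the same columns would go into a DataFrame and be timed with the `num.lstsq` expression from the original script.

```python
import time

import numpy as np

rng = np.random.default_rng(42)
n_rows, n_features = 100_000, 10

# Synthetic design matrix and a target with known coefficients 1..10.
X = rng.normal(size=(n_rows, n_features))
true_coef = np.arange(1, n_features + 1, dtype=float)
y = X @ true_coef + rng.normal(scale=0.1, size=n_rows)

t0 = time.perf_counter()
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
elapsed = time.perf_counter() - t0

print(f"np.linalg.lstsq on {n_rows} x {n_features}: {elapsed:.4f} s")
```
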

wukan1986 commented 3 months ago

Great job!!!