abstractqqq / polars_ds_extension

Polars extension for general data science use cases
MIT License
263 stars 17 forks source link

lstsq `skip_null` not work #97

Closed wukan1986 closed 3 months ago

wukan1986 commented 3 months ago
from typing import List

import numpy as np
import polars as pl
import polars_ds as pld  # noqa

# Minimal reproduction data: four rows, plus a row index column.
df = pl.DataFrame({
    "A": [9, 10, 11, 12],
    "B": [1, 2, 3, 4],
}).with_row_index()

# Append the one-hot encoding of B (columns B_1..B_4 in the printed frame).
df = df.with_columns(df.to_dummies('B'))
# C is the 2-row rolling mean of A; its first row is null in the printed
# frame, which is what the skip_null option below is meant to handle.
df = df.with_columns(pl.col('A').rolling_mean(2).alias('C'))
print(df)
"""
shape: (4, 8)
┌───────┬─────┬─────┬─────┬─────┬─────┬─────┬──────┐
│ index ┆ A   ┆ B   ┆ B_1 ┆ B_2 ┆ B_3 ┆ B_4 ┆ C    │
│ ---   ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ ---  │
│ u32   ┆ i64 ┆ i64 ┆ u8  ┆ u8  ┆ u8  ┆ u8  ┆ f64  │
╞═══════╪═════╪═════╪═════╪═════╪═════╪═════╪══════╡
│ 0     ┆ 9   ┆ 1   ┆ 1   ┆ 0   ┆ 0   ┆ 0   ┆ null │
│ 1     ┆ 10  ┆ 2   ┆ 0   ┆ 1   ┆ 0   ┆ 0   ┆ 9.5  │
│ 2     ┆ 11  ┆ 3   ┆ 0   ┆ 0   ┆ 1   ┆ 0   ┆ 10.5 │
│ 3     ┆ 12  ┆ 4   ┆ 0   ┆ 0   ┆ 0   ┆ 1   ┆ 11.5 │
└───────┴─────┴─────┴─────┴─────┴─────┴─────┴──────┘
"""
# Reported failure 1: regressing C on the four dummy columns raises
# polars.exceptions.ComputeError: the plugin failed with message: Lstsq: Data must have more rows than columns.
# NOTE(review): C has one null row, so presumably only 3 rows remain for
# 4 regressors once skip_null drops it — confirm against the plugin.
df = df.with_columns(pl.col('C').num.lstsq(pl.col('B_1'), pl.col('B_2'), pl.col('B_3'), pl.col('B_4'), return_pred=True, skip_null=True).struct.field('resid'))
print(df)

# Reported failure 2: a single-regressor fit with skip_null=True panics with
# "Cannot apply operation on arrays of different lengths" (traceback below).
# NOTE(review): the pasted traceback shows the call with 'A' and 'C' swapped
# relative to this line — confirm which variant was actually run.
df = df.with_columns(pl.col('C').num.lstsq(pl.col('A'), return_pred=True, skip_null=True).struct.field('resid'))
print(df)
"""
panicked at /home/kan/.cargo/registry/src/github.com-1ecc6299db9ec823/polars-core-0.38.1/src/chunked_array/ops/arity.rs:861:14:
Cannot apply operation on arrays of different lengths
Traceback (most recent call last):
  File "/home/kan/test1/b.py", line 32, in <module>
    df = df.with_columns(pl.col('A').num.lstsq(pl.col('C'), return_pred=True, skip_null=True).struct.field('resid'))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kan/miniconda3/envs/py311/lib/python3.11/site-packages/polars/dataframe/frame.py", line 8289, in with_columns
    return self.lazy().with_columns(*exprs, **named_exprs).collect(_eager=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kan/miniconda3/envs/py311/lib/python3.11/site-packages/polars/lazyframe/frame.py", line 1934, in collect
    return wrap_df(ldf.collect())
                   ^^^^^^^^^^^^^
polars.exceptions.ComputeError: the plugin panicked. The message is suppressed. Set POLARS_VERBOSE=1 to send the panic message to stderr.
└───────┴─────┴─────┴─────┴───┴─────┴─────┴──────┴───────┘
"""
wukan1986 commented 3 months ago

FYI

from typing import List

import numpy as np
import polars as pl
import polars_ds as pld  # noqa

# Same four-row frame as in the original report, plus a row index column.
df = pl.DataFrame({
    "A": [9, 10, 11, 12],
    "B": [1, 2, 3, 4],
}).with_row_index()

# Append the one-hot encoding of B (columns B_1..B_4 in the printed frame).
df = df.with_columns(df.to_dummies('B'))
# C is the 4-row rolling mean of A; in the printed result only the last row
# is non-null, so just one complete row is left for the regression.
df = df.with_columns(pl.col('A').rolling_mean(4).alias('C'))
print(df)

def residual_multiple(cols: List[pl.Series], add_constant: bool) -> pl.Series:
    """Return least-squares residuals of the first column regressed on the rest.

    Args:
        cols: Input series. Struct series are expanded into their fields, in
            order. After expansion the first column is the response ``y`` and
            every remaining column is a regressor.
        add_constant: When true, a column of ones is appended as an extra
            regressor (intercept term).

    Returns:
        A series with one value per input row: the residual ``y - y_hat`` for
        rows where every column is non-NaN, and null for rows where any
        column is NaN. Rows with NaN are excluded from the fit.
    """
    expanded = [list(c.struct) if isinstance(c.dtype, pl.Struct) else [c] for c in cols]
    arrays = [field.to_numpy() for group in expanded for field in group]
    if add_constant:
        arrays.append(np.ones_like(arrays[0]))
    yx = np.vstack(arrays).T

    # All-integer (or boolean) inputs would otherwise produce an integer
    # matrix: the output buffer would inherit that dtype, residuals would be
    # truncated on assignment, and NaN could not be stored at all. Promote to
    # float64 in that case; floating inputs keep their original dtype.
    if not np.issubdtype(yx.dtype, np.inexact):
        yx = yx.astype(np.float64)

    # Rows with a NaN in any column are left out of the fit.
    incomplete = np.isnan(yx).any(axis=1)
    complete = ~incomplete

    y = yx[complete, 0]
    x = yx[complete, 1:]
    coef = np.linalg.lstsq(x, y, rcond=None)[0]
    residual = y - np.sum(x * coef, axis=1)

    # Scatter the residuals back to their original row positions; skipped
    # rows stay NaN and become null in the returned series.
    out = np.full(yx.shape[0], np.nan, dtype=yx.dtype)
    out[complete] = residual
    return pl.Series(out, nan_to_null=True)

def cs_neutralize_residual_multiple(y: pl.Expr, *more_x: pl.Expr, add_constant: bool = False) -> pl.Expr:
    """Build an expression that runs ``residual_multiple`` on ``y`` and ``more_x``.

    Args:
        y: Expression passed first, i.e. the response column.
        *more_x: Expressions passed after ``y``, i.e. the regressors.
        add_constant: Forwarded unchanged to ``residual_multiple``.

    Returns:
        The expression produced by ``pl.map_batches`` over ``[y, *more_x]``.
    """

    def _residuals(batch):
        # Receives the evaluated inputs in the same order as the expressions.
        return residual_multiple(batch, add_constant)

    inputs = [y] + list(more_x)
    return pl.map_batches(inputs, _residuals)

# resid1: residual of A regressed on C, without an intercept
# (add_constant keeps its default of False).
df = df.with_columns(cs_neutralize_residual_multiple(pl.col('A'), pl.col('C')).alias('resid1'))
print(df)
# resid2: residual of C regressed on the B_* dummy columns and A. The dummies
# are passed as a single struct, which residual_multiple expands into fields.
df = df.with_columns(cs_neutralize_residual_multiple(pl.col('C'), pl.struct('^B_.*$'), pl.col('A')).alias('resid2'))
print(df)
"""
shape: (4, 10)
┌───────┬─────┬─────┬─────┬───┬─────┬──────┬────────┬─────────────┐
│ index ┆ A   ┆ B   ┆ B_1 ┆ … ┆ B_4 ┆ C    ┆ resid1 ┆ resid2      │
│ ---   ┆ --- ┆ --- ┆ --- ┆   ┆ --- ┆ ---  ┆ ---    ┆ ---         │
│ u32   ┆ i64 ┆ i64 ┆ u8  ┆   ┆ u8  ┆ f64  ┆ f64    ┆ f64         │
╞═══════╪═════╪═════╪═════╪═══╪═════╪══════╪════════╪═════════════╡
│ 0     ┆ 9   ┆ 1   ┆ 1   ┆ … ┆ 0   ┆ null ┆ null   ┆ null        │
│ 1     ┆ 10  ┆ 2   ┆ 0   ┆ … ┆ 0   ┆ null ┆ null   ┆ null        │
│ 2     ┆ 11  ┆ 3   ┆ 0   ┆ … ┆ 0   ┆ null ┆ null   ┆ null        │
│ 3     ┆ 12  ┆ 4   ┆ 0   ┆ … ┆ 1   ┆ 10.5 ┆ 0.0    ┆ -1.7764e-15 │
└───────┴─────┴─────┴─────┴───┴─────┴──────┴────────┴─────────────┘
"""
abstractqqq commented 3 months ago

On it!

abstractqqq commented 3 months ago

The issue with error 1 is that there can be infinitely many solutions, and the linear algebra backend I am using may have problems with that. (I am using QR to solve least squares, via the crate Faer-rs, and I got some panics when I didn't put in this guard.) I think this is a reasonable guard to have in place.

The second error is caused by a mistake which should be fixed in https://github.com/abstractqqq/polars_ds_extension/pull/98