abstractqqq / polars_ds_extension

Polars extension for general data science use cases
MIT License

Can lstsq support regular expressions and skip null? #94

Closed wukan1986 closed 3 months ago

wukan1986 commented 3 months ago

In real-world data, nulls occur so often that lstsq is basically unusable, because it rejects any input that contains a null. NaN values are easy to handle in NumPy, but there is no good way to handle nulls when calling this function from Polars. It would be great if nulls could be handled properly on the Rust side.

For reference, here is an example implementation using NumPy.

from typing import List

import numpy as np
import polars as pl
import polars_ds as pld  # noqa

df = pl.DataFrame({
    "A": [9, 10, 11, 12],
    "B": [1, 2, 3, 4],
}).with_row_index()

df = df.with_columns(df.to_dummies('B'))
df = df.with_columns(pl.col('A').rolling_mean(3).alias('C'))
print(df)
"""
shape: (4, 8)
┌───────┬─────┬─────┬─────┬─────┬─────┬─────┬──────┐
│ index ┆ A   ┆ B   ┆ B_1 ┆ B_2 ┆ B_3 ┆ B_4 ┆ C    │
│ ---   ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ ---  │
│ u32   ┆ i64 ┆ i64 ┆ u8  ┆ u8  ┆ u8  ┆ u8  ┆ f64  │
╞═══════╪═════╪═════╪═════╪═════╪═════╪═════╪══════╡
│ 0     ┆ 9   ┆ 1   ┆ 1   ┆ 0   ┆ 0   ┆ 0   ┆ null │
│ 1     ┆ 10  ┆ 2   ┆ 0   ┆ 1   ┆ 0   ┆ 0   ┆ null │
│ 2     ┆ 11  ┆ 3   ┆ 0   ┆ 0   ┆ 1   ┆ 0   ┆ 10.0 │
│ 3     ┆ 12  ┆ 4   ┆ 0   ┆ 0   ┆ 0   ┆ 1   ┆ 11.0 │
└───────┴─────┴─────┴─────┴─────┴─────┴─────┴──────┘
"""
# Calling lstsq directly (uncomment to reproduce) fails because column C contains nulls:
# df = df.with_columns(pl.col('A').num.lstsq(pl.col('B_1'), pl.col('B_2'), pl.col('B_3'), pl.col('B_4'), pl.col('C'), return_pred=True).struct.field('resid'))
# print(df)
"""
polars.exceptions.ComputeError: the plugin failed with message: Lstsq: Input must not contain nulls and must have length > 1
"""

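def residual_multiple(cols: List[pl.Series], add_constant: bool) -> pl.Series:
    # cols[0] is the struct Series built by pl.struct([y, *more_x]); unpack its fields.
    # to_numpy() turns nulls into NaN.
    cols = [c.to_numpy() for c in cols[0].struct]
    if add_constant:
        cols += [np.ones_like(cols[0])]
    # One row per observation: column 0 is y, the remaining columns are the regressors.
    # Cast to float so NaN can be stored and residuals are not truncated to ints.
    yx = np.vstack(cols).T.astype(np.float64)

    # skip rows where any value (y or x) is NaN, i.e. was null in Polars
    mask = np.any(np.isnan(yx), axis=1)
    yx_ = yx[~mask, :]

    y = yx_[:, 0]
    x = yx_[:, 1:]
    coef = np.linalg.lstsq(x, y, rcond=None)[0]
    y_hat = np.sum(x * coef, axis=1)
    residual = y - y_hat

    # refill: put residuals back in their original rows; skipped rows become NaN, then null
    out = np.empty_like(yx[:, 0])
    out[~mask] = residual
    out[mask] = np.nan
    return pl.Series(out, nan_to_null=True)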

def cs_neutralize_residual_multiple(y: pl.Expr, *more_x: pl.Expr, add_constant: bool = False) -> pl.Expr:
    return pl.map_batches(pl.struct([y, *more_x]), lambda xx: residual_multiple(xx, add_constant))

df = df.with_columns(cs_neutralize_residual_multiple(pl.col('A'), pl.col('^B_.*$'), pl.col('C')).alias('resid'))
print(df)
"""
shape: (4, 9)
┌───────┬─────┬─────┬─────┬───┬─────┬─────┬──────┬─────────────┐
│ index ┆ A   ┆ B   ┆ B_1 ┆ … ┆ B_3 ┆ B_4 ┆ C    ┆ resid       │
│ ---   ┆ --- ┆ --- ┆ --- ┆   ┆ --- ┆ --- ┆ ---  ┆ ---         │
│ u32   ┆ i64 ┆ i64 ┆ u8  ┆   ┆ u8  ┆ u8  ┆ f64  ┆ f64         │
╞═══════╪═════╪═════╪═════╪═══╪═════╪═════╪══════╪═════════════╡
│ 0     ┆ 9   ┆ 1   ┆ 1   ┆ … ┆ 0   ┆ 0   ┆ null ┆ null        │
│ 1     ┆ 10  ┆ 2   ┆ 0   ┆ … ┆ 0   ┆ 0   ┆ null ┆ null        │
│ 2     ┆ 11  ┆ 3   ┆ 0   ┆ … ┆ 1   ┆ 0   ┆ 10.0 ┆ -1.0658e-14 │
│ 3     ┆ 12  ┆ 4   ┆ 0   ┆ … ┆ 0   ┆ 1   ┆ 11.0 ┆ -8.8818e-15 │
└───────┴─────┴─────┴─────┴───┴─────┴─────┴──────┴─────────────┘
"""
abstractqqq commented 3 months ago

Just to confirm the desired behavior: skip a row if any value in that row is null? Yes, that is doable, and actually not too hard in the current Rust code base. I will add it as an additional flag.
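
The rule described here (skip a row if any of its values is null) is complete-case, or listwise, deletion. A minimal sketch of those semantics, assuming Python None stands in for a Polars null and that skipped rows should come back as null in the output. `lstsq_resid_skip_null` is a hypothetical name for illustration, not the flag or function actually added to polars_ds; the "at least 2 complete rows" check mirrors the "length > 1" requirement in the error message above.

```python
from typing import List, Optional, Sequence

import numpy as np


def lstsq_resid_skip_null(
    y: Sequence[Optional[float]],
    xs: Sequence[Sequence[Optional[float]]],
) -> List[Optional[float]]:
    """Hypothetical helper: OLS residuals, skipping any row that contains a None."""
    n = len(y)
    columns = [list(y)] + [list(x) for x in xs]
    # A row is kept only if y and every regressor are non-null in that row.
    rows = [i for i in range(n) if all(col[i] is not None for col in columns)]
    if len(rows) < 2:
        raise ValueError("need at least 2 complete rows")

    y_arr = np.array([y[i] for i in rows], dtype=float)
    x_arr = np.array([[x[i] for x in xs] for i in rows], dtype=float)
    coef = np.linalg.lstsq(x_arr, y_arr, rcond=None)[0]
    resid = y_arr - x_arr @ coef

    # Scatter residuals back to their original positions; skipped rows stay None.
    out: List[Optional[float]] = [None] * n
    for pos, i in enumerate(rows):
        out[i] = float(resid[pos])
    return out


# Row 2 has a null y and row 3 a null x1, so both come back as None.
print(lstsq_resid_skip_null(
    [1.0, 2.0, None, 4.0, 5.0],
    [[1.0, 2.0, 3.0, None, 5.0], [1.0] * 5],  # x1 and a constant column
))
```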

Regarding regex support, I currently don't see an easy way. I can provide helper functions, but that won't solve the root of the issue.
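
One workaround until regex selectors work inside the lstsq call is to expand the pattern into concrete column names first and pass them as separate expressions. A minimal sketch, where `columns_matching` is a hypothetical helper and not part of polars_ds. Within Polars itself, df.select(pl.col("^B_.*$")).columns yields the same list of names.

```python
import re
from typing import List, Sequence


def columns_matching(columns: Sequence[str], pattern: str) -> List[str]:
    """Hypothetical helper: names in `columns` matching `pattern`, in original order.

    Polars treats a column name wrapped in ^...$ as a regex; this mimics that
    with re.search on the anchored pattern.
    """
    regex = re.compile(pattern)
    return [name for name in columns if regex.search(name)]


cols = ["index", "A", "B", "B_1", "B_2", "B_3", "B_4", "C"]
print(columns_matching(cols, r"^B_.*$"))  # ['B_1', 'B_2', 'B_3', 'B_4']

# With polars_ds installed, the names could then be spread into lstsq, e.g.:
# x_exprs = [pl.col(c) for c in columns_matching(df.columns, r"^B_.*$")]
# pl.col("A").num.lstsq(*x_exprs, pl.col("C"), return_pred=True)
```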

abstractqqq commented 3 months ago

Closing this for now. Skipping nulls is implemented in: https://github.com/abstractqqq/polars_ds_extension/commit/837560d0a475c4fba5790a6f0b525e9d6da40421