abstractqqq / polars_ds_extension

Polars extension for general data science use cases
MIT License
263 stars, 17 forks

lstsq produces incorrect results in `group_by` context when `return_pred=True` #142

Status: Closed (erinov1 closed this issue 2 months ago)

erinov1 commented 2 months ago

I don't really understand this one, but `lstsq` produces different results in the `select` and `group_by` contexts when `return_pred=True`, and the `group_by` results are incorrect.

Setup:

```python
import polars as pl
import polars_ds as pds

df = pl.DataFrame(
    {
        "A": [1] * 4 + [2] * 4,
        "Y": [1] * 8,
        "X1": [1, 2, 3, 4, 5, 6, 7, 8],
        "X2": [2, 3, 4, 1, 6, 7, 8, 5],
    }
)
```

Correct results in the `select` context (these agree with NumPy):


```python
df.filter(pl.col("A").eq(1)).with_columns(
    pds.query_lstsq(
        pl.col("X1"), pl.col("X2"), target=pl.col("Y"), add_bias=False, return_pred=True
    ).alias("pred")
)
```

```
shape: (4, 5)
┌─────┬─────┬─────┬─────┬──────────────────────┐
│ A   ┆ Y   ┆ X1  ┆ X2  ┆ pred                 │
│ --- ┆ --- ┆ --- ┆ --- ┆ ---                  │
│ i64 ┆ i64 ┆ i64 ┆ i64 ┆ struct[2]            │
╞═════╪═════╪═════╪═════╪══════════════════════╡
│ 1   ┆ 1   ┆ 1   ┆ 2   ┆ {0.555556,0.444444}  │
│ 1   ┆ 1   ┆ 2   ┆ 3   ┆ {0.925926,0.074074}  │
│ 1   ┆ 1   ┆ 3   ┆ 4   ┆ {1.296296,-0.296296} │
│ 1   ┆ 1   ┆ 4   ┆ 1   ┆ {0.925926,0.074074}  │
└─────┴─────┴─────┴─────┴──────────────────────┘
```
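The `A=1` numbers above can be reproduced with NumPy's `np.linalg.lstsq`, which is presumably the comparison meant by "agrees with NumPy". The sketch below hard-codes the four `A=1` rows and fits without an intercept (matching `add_bias=False`). It assumes the two struct fields are the prediction and the residual; that is inferred from the numbers (each pair sums to `Y = 1`), not taken from the polars_ds documentation.

```python
import numpy as np

# Rows with A == 1 from the setup DataFrame; columns are X1, X2
X = np.array([[1, 2], [2, 3], [3, 4], [4, 1]], dtype=float)
y = np.ones(4)  # Y is 1 in every row

# Ordinary least squares without an intercept (add_bias=False)
coef, *_ = np.linalg.lstsq(X, y, rcond=None)  # both coefficients = 5/27 ≈ 0.185185

pred = X @ coef   # ≈ [0.555556, 0.925926, 1.296296, 0.925926] (assumed first struct field)
resid = y - pred  # ≈ [0.444444, 0.074074, -0.296296, 0.074074] (assumed second struct field)
```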

Different, incorrect results in the `group_by` context (compare, for example, the rows with `A=1` with the output above):

```python
df.group_by("A", maintain_order=True).agg(
    "Y",
    "X1",
    "X2",
    pds.query_lstsq(
        pl.col("X1"), pl.col("X2"), target=pl.col("Y"), add_bias=False, return_pred=True
    ).alias("pred"),
).explode("Y", "X1", "X2", "pred")
```

```
shape: (8, 5)
┌─────┬─────┬─────┬─────┬──────────────────────┐
│ A   ┆ Y   ┆ X1  ┆ X2  ┆ pred                 │
│ --- ┆ --- ┆ --- ┆ --- ┆ ---                  │
│ i64 ┆ i64 ┆ i64 ┆ i64 ┆ struct[2]            │
╞═════╪═════╪═════╪═════╪══════════════════════╡
│ 1   ┆ 1   ┆ 1   ┆ 2   ┆ {0.272727,0.727273}  │
│ 1   ┆ 1   ┆ 2   ┆ 3   ┆ {0.454545,0.545455}  │
│ 1   ┆ 1   ┆ 3   ┆ 4   ┆ {0.636364,0.363636}  │
│ 1   ┆ 1   ┆ 4   ┆ 1   ┆ {0.454545,0.545455}  │
│ 2   ┆ 1   ┆ 5   ┆ 6   ┆ {1.0,0.0}            │
│ 2   ┆ 1   ┆ 6   ┆ 7   ┆ {1.181818,-0.181818} │
│ 2   ┆ 1   ┆ 7   ┆ 8   ┆ {1.363636,-0.363636} │
│ 2   ┆ 1   ┆ 8   ┆ 5   ┆ {1.181818,-0.181818} │
└─────┴─────┴─────┴─────┴──────────────────────┘
```
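The incorrect numbers follow a recognizable pattern: every `pred` struct in the `group_by` output matches a single no-intercept fit over all eight rows, ignoring `A` entirely (both coefficients come out as exactly 1/11). The NumPy sketch below reproduces them. It is a diagnostic check on the arithmetic only; it suggests, but does not prove, that the regression was fit once on the full columns instead of once per group.

```python
import numpy as np

# All eight rows from the setup DataFrame, ignoring the grouping column A
X1 = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=float)
X2 = np.array([2, 3, 4, 1, 6, 7, 8, 5], dtype=float)
X = np.column_stack([X1, X2])
y = np.ones(8)

# One no-intercept least-squares fit over the whole column, not per group
coef_all, *_ = np.linalg.lstsq(X, y, rcond=None)  # both coefficients = 1/11 ≈ 0.090909

pred_all = X @ coef_all  # equals (X1 + X2) / 11
resid_all = y - pred_all
# pred_all ≈ [0.272727, 0.454545, 0.636364, 0.454545, 1.0, 1.181818, 1.363636, 1.181818]
# These match the first struct field of the incorrect group_by output row for row.
```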

On the other hand, the coefficients themselves (without `return_pred=True`) agree between the `select` and `group_by` contexts, and both are correct:

```python
df.filter(pl.col("A").eq(1)).with_columns(
    pds.query_lstsq(pl.col("X1"), pl.col("X2"), target=pl.col("Y"), add_bias=False).alias("pred")
)
```

```
shape: (4, 5)
┌─────┬─────┬─────┬─────┬──────────────────────┐
│ A   ┆ Y   ┆ X1  ┆ X2  ┆ pred                 │
│ --- ┆ --- ┆ --- ┆ --- ┆ ---                  │
│ i64 ┆ i64 ┆ i64 ┆ i64 ┆ list[f64]            │
╞═════╪═════╪═════╪═════╪══════════════════════╡
│ 1   ┆ 1   ┆ 1   ┆ 2   ┆ [0.185185, 0.185185] │
│ 1   ┆ 1   ┆ 2   ┆ 3   ┆ [0.185185, 0.185185] │
│ 1   ┆ 1   ┆ 3   ┆ 4   ┆ [0.185185, 0.185185] │
│ 1   ┆ 1   ┆ 4   ┆ 1   ┆ [0.185185, 0.185185] │
└─────┴─────┴─────┴─────┴──────────────────────┘
```

```python
df.group_by("A", maintain_order=True).agg(
    "Y",
    "X1",
    "X2",
    pds.query_lstsq(
        pl.col("X1"),
        pl.col("X2"),
        target=pl.col("Y"),
        add_bias=False,
    ).alias("pred"),
).explode("Y", "X1", "X2")
```

```
shape: (8, 5)
┌─────┬─────┬─────┬─────┬──────────────────────┐
│ A   ┆ Y   ┆ X1  ┆ X2  ┆ pred                 │
│ --- ┆ --- ┆ --- ┆ --- ┆ ---                  │
│ i64 ┆ i64 ┆ i64 ┆ i64 ┆ list[f64]            │
╞═════╪═════╪═════╪═════╪══════════════════════╡
│ 1   ┆ 1   ┆ 1   ┆ 2   ┆ [0.185185, 0.185185] │
│ 1   ┆ 1   ┆ 2   ┆ 3   ┆ [0.185185, 0.185185] │
│ 1   ┆ 1   ┆ 3   ┆ 4   ┆ [0.185185, 0.185185] │
│ 1   ┆ 1   ┆ 4   ┆ 1   ┆ [0.185185, 0.185185] │
│ 2   ┆ 1   ┆ 5   ┆ 6   ┆ [0.076023, 0.076023] │
│ 2   ┆ 1   ┆ 6   ┆ 7   ┆ [0.076023, 0.076023] │
│ 2   ┆ 1   ┆ 7   ┆ 8   ┆ [0.076023, 0.076023] │
│ 2   ┆ 1   ┆ 8   ┆ 5   ┆ [0.076023, 0.076023] │
└─────┴─────┴─────┴─────┴──────────────────────┘
```
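The per-group coefficients can also be checked exactly by solving the 2x2 normal equations (XᵀX) b = Xᵀy with rational arithmetic. The helper `normal_equation_coefs` below is hypothetical, written only for this check; it applies Cramer's rule and confirms 5/27 ≈ 0.185185 for `A=1` and 13/171 ≈ 0.076023 for `A=2`.

```python
from fractions import Fraction


def normal_equation_coefs(x1, x2, y):
    """Exact no-intercept least squares for two regressors via the normal equations."""
    s11 = sum(Fraction(u * u) for u in x1)              # sum of X1^2
    s12 = sum(Fraction(u * v) for u, v in zip(x1, x2))  # sum of X1*X2
    s22 = sum(Fraction(v * v) for v in x2)              # sum of X2^2
    t1 = sum(Fraction(u * w) for u, w in zip(x1, y))    # sum of X1*Y
    t2 = sum(Fraction(v * w) for v, w in zip(x2, y))    # sum of X2*Y
    det = s11 * s22 - s12 * s12
    # Cramer's rule on [[s11, s12], [s12, s22]] @ [b1, b2] = [t1, t2]
    return (s22 * t1 - s12 * t2) / det, (s11 * t2 - s12 * t1) / det


group1 = normal_equation_coefs([1, 2, 3, 4], [2, 3, 4, 1], [1] * 4)  # rows with A == 1
group2 = normal_equation_coefs([5, 6, 7, 8], [6, 7, 8, 5], [1] * 4)  # rows with A == 2
# group1 == (5/27, 5/27), group2 == (13/171, 13/171)
```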

Unfortunately I can't tell from the code why this is happening.

abstractqqq commented 2 months ago

https://github.com/abstractqqq/polars_ds_extension/pull/143

The reason is very subtle... I think I misunderstood what `is_elementwise` means in Polars' plugin registration function. I had set it, and somehow it broke the `group_by` results; removing it solves the problem. I still don't know exactly how `is_elementwise` affects execution, but the issue should be resolved.
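As general background (not the maintainer's stated explanation): declaring a plugin `is_elementwise=True` tells Polars that each output row depends only on the matching input row, which roughly lets the engine evaluate the expression over a whole column at once, even inside `group_by`, rather than once per group. A least-squares fit is not elementwise, because every prediction depends on all rows in its group. The toy model below uses a hypothetical `evaluate` helper (not part of Polars or polars_ds) and a simple group-dependent function, demeaning, to show how treating such a function as elementwise silently swaps group statistics for global ones.

```python
from statistics import mean


def demean(values):
    """A non-elementwise function: every output depends on the whole input."""
    m = mean(values)
    return [v - m for v in values]


def evaluate(func, values, groups, elementwise):
    """Hypothetical model of dispatching func inside a group_by.

    elementwise=True:  call func once on the full column, then split rows by group.
    elementwise=False: call func separately on each group's slice.
    """
    keys = list(dict.fromkeys(groups))  # unique group keys in order of appearance
    if elementwise:
        out = func(values)
        return {g: [o for o, k in zip(out, groups) if k == g] for g in keys}
    return {g: func([v for v, k in zip(values, groups) if k == g]) for g in keys}


groups = [1, 1, 1, 1, 2, 2, 2, 2]  # column A from the setup
x1 = [1, 2, 3, 4, 5, 6, 7, 8]      # column X1 from the setup

per_group = evaluate(demean, x1, groups, elementwise=False)
whole_col = evaluate(demean, x1, groups, elementwise=True)
# per_group[1] == [-1.5, -0.5, 0.5, 1.5]   (uses the group mean 2.5)
# whole_col[1] == [-3.5, -2.5, -1.5, -0.5] (uses the global mean 4.5)
```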

Thanks a lot for pointing it out!

erinov1 commented 2 months ago

Thanks for the quick fix!