azmyrajab / polars_ols

Polars least squares extension - enables fast linear model polar expressions
MIT License

`predict()` with `add_intercept=True` fails when features are multi-column expression #25

Closed: kstoneriv3 closed this 2 weeks ago

kstoneriv3 commented 1 month ago

Thanks for creating this awesome library! I find it very handy not having to go back and forth between the numpy and polars worlds just to fit a simple model!

When I run something like the following, I get an error, because add_intercept=True internally creates multiple constant columns instead of a single constant column, and the call therefore fails.

import polars as pl
import polars.selectors as cs
import polars_ols  # registers the .least_squares expression namespace

df = pl.DataFrame({"y": [1, 2, 3, 4], "x1": [3, 4, 5, 6], "x2": [4, 5, 6, 7], "x3": [5, 6, 7, 8]})
df = df.with_columns(
    pl.col("y").least_squares.ols(
        cs.starts_with("x"),
        add_intercept=True,
        mode="coefficients",
    )
)
# This throws an error!
df = df.with_columns(
    pl.col("coefficients").least_squares.predict(
        cs.starts_with("x"),
        add_intercept=True,
    ).alias("y_pred")
)
panicked at src/expressions.rs:649:5:
assertion `left == right` failed: number of coefficients must match number of features!
  left: 4
 right: 6
---------------------------------------------------------------------------
ComputeError                              Traceback (most recent call last)
Cell In[55], line 12
      4 df = pl.DataFrame({"y": [1, 2, 3, 4], "x1": [3, 4, 5, 6], "x2": [4, 5, 6, 7], "x3": [5, 6, 7, 8]})
      5 df = df.with_columns(
      6     pl.col("y").least_squares.ols(
      7         cs.starts_with("x"),
   (...)
     10     )
     11 )
---> 12 df = df.with_columns(
     13     pl.col("coefficients").least_squares.predict(
     14         cs.starts_with("x"),
     15         add_intercept=True,
     16     ).alias("y_pred")
     17 )

File ~/Library/Caches/pypoetry/virtualenvs/kingston-gaisan-dhDkHfdI-py3.12/lib/python3.12/site-packages/polars/dataframe/frame.py:8791, in DataFrame.with_columns(self, *exprs, **named_exprs)
   8645 def with_columns(
   8646     self,
   8647     *exprs: IntoExpr | Iterable[IntoExpr],
   8648     **named_exprs: IntoExpr,
   8649 ) -> DataFrame:
   8650     """
   8651     Add columns to this DataFrame.
   8652 
   (...)
   8789     └─────┴──────┴─────────────┘
   8790     """
-> 8791     return self.lazy().with_columns(*exprs, **named_exprs).collect(_eager=True)

File ~/Library/Caches/pypoetry/virtualenvs/kingston-gaisan-dhDkHfdI-py3.12/lib/python3.12/site-packages/polars/lazyframe/frame.py:1942, in LazyFrame.collect(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, no_optimization, streaming, background, _eager, **_kwargs)
   1939 # Only for testing purposes atm.
   1940 callback = _kwargs.get("post_opt_callback")
-> 1942 return wrap_df(ldf.collect(callback))

ComputeError: the plugin panicked

This happens at the following line (https://github.com/azmyrajab/polars_ols/blob/a5e13a07545a41d2326e0957d2717e44f3343c38/polars_ols/least_squares.py#L482C1-L483C1):

            features += (features[-1].fill_null(0.0).mul(0.0).add(1.0).alias("const"),)

I think the simplest fix would be to replace it with

            features += (pl.lit(1.0).alias("const"),)

but I am not sure whether that's the right approach, so I created this PR.

azmyrajab commented 1 month ago

Thank you for documenting this, and glad you are enjoying the library!

This sounds easy enough to fix in Python (check whether const already exists and make sure it is not duplicated); I will try to do that when I get a moment. FYI, the weird-looking features[-1].fill_null(… was needed instead of a plain literal because of some size-broadcasting issues on the Rust side the last time I worked on this.

kstoneriv3 commented 1 month ago

Thank you for the reply! I am looking forward to it (though I am in no rush)!!

azmyrajab commented 2 weeks ago

This should now be resolved; thanks for raising this!

Your proposed fix went through without issues: https://github.com/azmyrajab/polars_ols/blob/main/polars_ols/least_squares.py#L482C26-L482C32

kstoneriv3 commented 2 weeks ago

Thanks a lot for the update!!!