azmyrajab / polars_ols

Polars least squares extension - enables fast linear model polar expressions
MIT License

Assertion failed: end <= axis_len when using rolling ridge #13

Closed: mat-ej closed this issue 5 months ago

mat-ej commented 5 months ago

Hi, first off thanks for this awesome package. It is amazingly useful.

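When using a rolling ridge regression, e.g.:

import polars as pl
import polars_ols as pls  # noqa: F401 (importing registers the .least_squares namespace)

formula = "x1 + x2"  # placeholder; the actual formula was not shown

rolling_ridge = (
    pl.col("y")
    .least_squares.from_formula(
        formula,
        window_size=2000,
        min_periods=2000,
        alpha=1.0,
        mode="coefficients",
    )
    .over("datetime")
)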

I do get the following exception: assertion failed: end <= axis_len

panicked at /Users/runner/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/dimension/mod.rs:401:5

I wonder what I might be doing wrong: all the columns in the formula exist and contain no NaNs, and the shape of the DataFrame is (20_000, 36).

Thanks in advance for help.

azmyrajab commented 5 months ago

Hi @mat-ej, thank you very much, glad you are enjoying it!

Happy to take a look at this for you. My first thought is that maybe there isn't sufficient data per group (to span min_periods), which would cause that array assertion error.

Did you intend to do the .over() on "datetime", or did you mean to sort on datetime and do the .over() on a grouping column (e.g. "symbol" / "id")? I don't know the structure of your data, so I'm just guessing.
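One way to test the "too little data per group" hypothesis is to count the rows in each .over() group and compare them with min_periods. A minimal pure-Python sketch, assuming the grouping column is available as a plain list of keys (e.g. the "datetime" values); find_short_groups is a hypothetical helper, not part of polars_ols, and the data layout below is invented for illustration. In recent polars versions, df.group_by("datetime").len() gives the same per-group counts.

```python
from collections import Counter


def find_short_groups(keys, min_periods):
    """Return {group_key: row_count} for groups with fewer than min_periods rows.

    `keys` holds the values of the column passed to .over(), e.g. "datetime".
    """
    counts = Counter(keys)
    return {key: n for key, n in counts.items() if n < min_periods}


# Hypothetical layout: 20_000 rows over 1_000 timestamps (say 20 symbols each),
# so every .over("datetime") group has only 20 rows.
keys = [t for t in range(1_000) for _ in range(20)]
short = find_short_groups(keys, min_periods=2000)
print(len(short))  # 1000: every group is far smaller than min_periods=2000
```

If every group is flagged, grouping on an identifier column (and sorting by datetime) is the more likely intent.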

If the above is not the culprit, would it be possible to send me a sample of the data you are running on so I can reproduce it?

Thanks for documenting your issue!

azmyrajab commented 5 months ago

Hi @mat-ej - I could reproduce a similar error with data sizes smaller than the min_periods provided, so the next release will include (in addition to the rolling window null handling support PR):

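    // Not enough valid observations: return an (n x k) array of NaN
    // coefficients instead of panicking.
    if n_valid < n {
        println!(
            "warning: number of valid observations detected is less than 'min_periods' set, \
             returning NaN coefficients!"
        );
        return Array2::from_elem((n, k), f64::NAN);
    }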

So you'll get a clearer message, and it will return NaNs instead of breaking. The reason for doing that instead of asserting and panicking is that if you did an over() on some groups and some of them happened to be sparsely populated, it still allows you to get estimates for the rest.

It parallels the edge case in https://github.com/azmyrajab/polars_ols/issues/5, but for rolling window models: this lack-of-data case is handled by producing nulls.
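A minimal pure-Python sketch of that behaviour, mirroring the idea rather than the actual Rust implementation. It assumes a single feature with no intercept, so the ridge solution has the closed form beta = sum(x*y) / (sum(x*x) + alpha), and windows that grow until they reach window_size. rolling_ridge_1d and rolling_ridge_over are hypothetical helpers, not the polars_ols API. Any row whose window holds fewer than min_periods observations gets NaN, so a sparse group NaNs out while the other groups still get estimates.

```python
import math


def rolling_ridge_1d(x, y, window_size, min_periods, alpha):
    """Rolling single-feature ridge (no intercept): beta = sum(x*y) / (sum(x*x) + alpha).

    Rows whose window holds fewer than min_periods observations get NaN
    instead of raising an error.
    """
    coefs = []
    for i in range(len(x)):
        start = max(0, i - window_size + 1)
        xs, ys = x[start : i + 1], y[start : i + 1]
        if len(xs) < min_periods:
            coefs.append(math.nan)
            continue
        sxy = sum(a * b for a, b in zip(xs, ys))
        sxx = sum(a * a for a in xs)
        coefs.append(sxy / (sxx + alpha))
    return coefs


def rolling_ridge_over(groups, x, y, **kwargs):
    """Apply rolling_ridge_1d separately to each group, similar in spirit to .over()."""
    out = [math.nan] * len(x)
    for g in dict.fromkeys(groups):  # unique keys in first-seen order
        idx = [i for i, key in enumerate(groups) if key == g]
        coefs = rolling_ridge_1d([x[i] for i in idx], [y[i] for i in idx], **kwargs)
        for i, c in zip(idx, coefs):
            out[i] = c
    return out


# Group "a" has 4 rows, group "b" only 1: "b" (and a's first row) come back NaN,
# while the remaining rows of "a" still get coefficients.
groups = ["a", "a", "a", "a", "b"]
x = [1.0, 2.0, 3.0, 4.0, 1.0]
y = [2.0, 4.0, 6.0, 8.0, 3.0]
print(rolling_ridge_over(groups, x, y, window_size=3, min_periods=2, alpha=1.0))
```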

Closing for now, but if anything is unclear or you disagree with anything, please do reach out.

mat-ej commented 5 months ago

Thanks so much for the help. Once again, this is one of the most useful GitHub repos I have stumbled upon recently.

I am really taking notes from your code on how you wrote those custom expressions; it is so well written, it's crazy.

I do wonder how difficult it would be to write such custom expressions for non-linear models such as GAMs.

azmyrajab commented 5 months ago

Thank you, that's very kind :) Glad you are finding it useful!

GLMs / "kernel ridge" / non-linear SVMs are probably at the edge of what this package could support in future versions, absent a big refactor. That's mainly because of the fit/predict mechanics: for linear models (or models that are linear in a transformed space), the fit can be summarised by linear coefficients, which can easily be de-serialized. Passing them back to Python is therefore simple, and it lets the user fit on train data and predict on test data easily, all in polars expressions.

For more complex models, the model's internal state (e.g. decision tree nodes and weights) needs to be represented differently, and one then needs to think carefully about how to get pyo3 to deserialize this state to/from Python (and how a polars expression would 'point' to this internal state). For such arbitrary non-linear models it starts to make more sense to build scikit-learn-style objects with fit and predict and an internal state, like Rust's linfa. A future project idea could be to add a pyo3 deserialization layer on top of something like linfa so that it can be called from Python (and polars), but that's beyond the scope of my bandwidth here.

If one only ever wanted in-sample predictions/residuals, it'd be easy to support arbitrary models, but as soon as you want to allow out-of-sample testing, exposing the state nicely becomes more complex, I think.

Hope this makes sense! I will keep you posted if generalized linear models eventually get supported here.
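To illustrate the point about linear models: once fitted, the whole model is a short vector of coefficients, so "fit on train, predict on test" only needs those numbers to travel between the two steps. A minimal pure-Python sketch, assuming one feature plus an intercept fitted by ordinary least squares; fit_ols and predict are hypothetical helpers, not polars_ols functions.

```python
def fit_ols(x, y):
    """Fit y = a + b*x by ordinary least squares; the entire fitted state is [a, b]."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxx = sum((xi - mx) ** 2 for xi in x)
    sxy = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
    b = sxy / sxx
    a = my - b * mx
    return [a, b]  # plain floats: easy to serialise and pass around


def predict(coefs, x):
    """Out-of-sample prediction needs nothing but the coefficient vector."""
    a, b = coefs
    return [a + b * xi for xi in x]


train_x, train_y = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]  # y = 1 + 2x
coefs = fit_ols(train_x, train_y)
print(coefs)  # [1.0, 2.0]
print(predict(coefs, [10.0, 20.0]))  # [21.0, 41.0]
```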
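By contrast, a non-linear model's fitted state is structured rather than a flat coefficient vector. A minimal sketch of the scikit-learn-style fit/predict pattern, using a one-split regression tree ("stump") as a stand-in for the decision-tree example above; Stump is a hypothetical class, not taken from linfa or scikit-learn. Its state (a threshold plus one prediction per side) lives on the object, which is the kind of thing that would need a pyo3/Python representation before a polars expression could refer to it.

```python
class Stump:
    """One-split regression tree: predicts the mean of y on each side of a threshold."""

    def fit(self, x, y):
        best = None
        xs = sorted(set(x))
        # Try a split halfway between each pair of consecutive distinct x values.
        for lo, hi in zip(xs, xs[1:]):
            t = (lo + hi) / 2
            left = [yi for xi, yi in zip(x, y) if xi <= t]
            right = [yi for xi, yi in zip(x, y) if xi > t]
            lm, rm = sum(left) / len(left), sum(right) / len(right)
            sse = sum((v - lm) ** 2 for v in left) + sum((v - rm) ** 2 for v in right)
            if best is None or sse < best[0]:
                best = (sse, t, lm, rm)
        if best is None:
            raise ValueError("need at least two distinct x values to split")
        # Fitted state: not a coefficient vector but (threshold, left value, right value).
        _, self.threshold, self.left_value, self.right_value = best
        return self

    def predict(self, x):
        return [self.left_value if xi <= self.threshold else self.right_value for xi in x]


model = Stump().fit([1.0, 2.0, 3.0, 10.0, 11.0, 12.0], [0.0, 0.0, 0.0, 5.0, 5.0, 5.0])
print(model.threshold, model.predict([0.0, 20.0]))  # 6.5 [0.0, 5.0]
```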

mat-ej commented 5 months ago

Btw, this is, imo, a very nice SVM implementation:

lsorber/neo-ls-svm

It's not hard to register it with the DataFrame namespace API. I am still getting around to implementing it with the expression namespace API so that it is usable in group_by and other contexts. Looking into your code and learning how it's done :)