simon-hirsch / rolch

A package for online learning for distributional regression and online models for conditional heteroskedasticity
https://simon-hirsch.github.io/rolch/
MIT License

Draft: Refactor API #25

Closed: simon-hirsch closed this 3 days ago

simon-hirsch commented 2 weeks ago

Starting a PR to refactor the public API for the Estimator object.

Goals:

Discussions in #23 and #24

simon-hirsch commented 2 weeks ago

I implemented a first version that follows the proposed API. The changes are mostly cosmetic, i.e. internally most things work as before. Notable changes:

Need to do now:

simon-hirsch commented 2 weeks ago

Code snippet to play around with after building from source on this branch:

```python
import numpy as np
import rolch
from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True)

online_gamlss_new = rolch.OnlineGamlss(
    distribution=rolch.DistributionT(),
    # Parameter 0 uses all columns; parameter 1 uses columns 4 to 7
    equation={0: "all", 1: np.arange(4, 8)},
    fit_intercept=True,
    method="lasso",
    scale_inputs=True,
)
# Initial batch fit on the full data set
online_gamlss_new.fit(X, y)
# Online update with a single new observation (2-D X, 1-D y)
online_gamlss_new.update(X[[-1], :], y[[-1]])

print(online_gamlss_new.betas)
```
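
To make the `equation` and `fit_intercept` arguments in the snippet concrete, here is a minimal sketch of how a per-parameter column selection could turn one feature matrix into one design matrix per distribution parameter. The helper `build_design_matrices`, its `n_params` argument, and its choice to give parameters missing from `equation` an intercept-only design are hypothetical, for illustration only; this is not rolch's implementation.

```python
import numpy as np


def build_design_matrices(X, equation, n_params, fit_intercept=True):
    """Hypothetical helper: one design matrix per distribution parameter.

    `equation` maps a parameter index to "all" (use every column) or to an
    array of column indices. Parameters missing from the dict get no
    covariates in this sketch (an assumption, not rolch's rule).
    """
    X = np.asarray(X, dtype=float)
    n_obs = X.shape[0]
    designs = {}
    for param in range(n_params):
        cols = equation.get(param)
        if cols is None:
            # No covariates for this parameter: start from an empty block
            block = np.empty((n_obs, 0))
        elif isinstance(cols, str) and cols == "all":
            block = X
        else:
            # Integer index array selects a subset of columns
            block = X[:, np.asarray(cols, dtype=int)]
        if fit_intercept:
            # Prepend a column of ones for the intercept
            block = np.hstack([np.ones((n_obs, 1)), block])
        designs[param] = block
    return designs


# 10 features, like load_diabetes; 3 parameters, like a Student-t
X = np.random.default_rng(0).normal(size=(5, 10))
designs = build_design_matrices(X, {0: "all", 1: np.arange(4, 8)}, n_params=3)
for param, D in designs.items():
    print(param, D.shape)
# prints:
# 0 (5, 11)
# 1 (5, 5)
# 2 (5, 1)
```

The `isinstance(cols, str)` check avoids comparing a numpy index array with the string `"all"`, which would be an elementwise comparison rather than a plain boolean.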
simon-hirsch commented 2 weeks ago

> Overall, I think this is a great improvement to the interface!

Thanks!

> My main comment, or question, is about how you are handling pandas and polars. Is all that happens just a conversion to numpy? In that case, you may want to make the conversions external to the estimator. skpro also implements default conversions, if you want to use its boilerplate or just its datatypes module, though you probably do not want to take on the dependency.
>
> The "nice" solution would be for all internal calculations to be native to those libraries, though this may be out of scope.
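
One way to keep the conversions external to the estimator is a thin, duck-typed layer that runs before `fit`/`update`, so the estimator itself only ever sees numpy arrays. This is a minimal sketch that assumes only that pandas and polars objects expose a `.to_numpy()` method (both DataFrames and Series do). `to_numpy_2d`, `to_numpy_1d` and `FakeFrame` are hypothetical names, and this is not skpro's datatypes module.

```python
import numpy as np


def to_numpy_2d(X):
    """Hypothetical conversion layer for features, kept outside the estimator.

    pandas and polars objects both provide .to_numpy(), so this duck-typed
    check covers them without importing either library.
    """
    if hasattr(X, "to_numpy"):
        X = X.to_numpy()
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        # Treat a 1-D input as a single feature column
        X = X.reshape(-1, 1)
    return X


def to_numpy_1d(y):
    """Same idea for the response: always return a flat float array."""
    if hasattr(y, "to_numpy"):
        y = y.to_numpy()
    return np.asarray(y, dtype=float).ravel()


class FakeFrame:
    # Stand-in for a DataFrame so the example runs without pandas/polars
    def __init__(self, rows):
        self._rows = rows

    def to_numpy(self):
        return np.array(self._rows)


print(to_numpy_2d(FakeFrame([[1, 2], [3, 4]])).shape)  # prints (2, 2)
print(to_numpy_1d([1, 2, 3]))  # prints [1. 2. 3.]
```

Because the check is duck-typed, neither pandas nor polars has to be imported, so both can stay optional dependencies.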

Yes, it will all be handled by subsetting and conversion to numpy. We have quite a bit of numba-supported code in the deeper workings of the package, and I don't see a huge benefit in rewriting, e.g., coordinate descent for pandas or polars. Another reason is that this keeps the dependency on these libraries light, while the core numpy API is quite stable, which makes maintenance easier :)
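
For context on why numpy-only internals are convenient: coordinate descent for the lasso is a tight loop over the columns of a dense array, the kind of loop numba handles well and where pandas or polars would add nothing. Below is a plain-numpy sketch of cyclic coordinate descent for the batch objective (1 / (2n)) * ||y - X b||^2 + lam * ||b||_1, without an intercept and without numba. `lasso_coordinate_descent` and `lam` are hypothetical names, and this is not rolch's (online) solver.

```python
import numpy as np


def soft_threshold(z, gamma):
    # Proximal operator of the L1 penalty
    return np.sign(z) * max(abs(z) - gamma, 0.0)


def lasso_coordinate_descent(X, y, lam, n_iter=1000, tol=1e-10):
    """Cyclic coordinate descent for
    (1 / (2n)) * ||y - X b||^2 + lam * ||b||_1   (no intercept).
    """
    n, p = X.shape
    beta = np.zeros(p)
    col_sq = (X ** 2).sum(axis=0) / n  # x_j' x_j / n for each column
    residual = y.astype(float).copy()  # y - X @ beta, with beta = 0
    for _ in range(n_iter):
        max_change = 0.0
        for j in range(p):
            if col_sq[j] == 0.0:
                continue  # all-zero column: coefficient stays at 0
            old = beta[j]
            # Correlation of column j with the partial residual (excluding j)
            rho = X[:, j] @ residual / n + col_sq[j] * old
            beta[j] = soft_threshold(rho, lam) / col_sq[j]
            if beta[j] != old:
                # Keep the full residual in sync with the new coefficient
                residual -= X[:, j] * (beta[j] - old)
                max_change = max(max_change, abs(beta[j] - old))
        if max_change < tol:
            break
    return beta


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 5))
true_beta = np.array([2.0, 0.0, -1.0, 0.0, 0.0])
y = X @ true_beta + 0.1 * rng.normal(size=200)
print(np.round(lasso_coordinate_descent(X, y, lam=0.05), 2))
```

With `lam=0` the updates reduce to Gauss-Seidel steps on the normal equations, so the result should approach ordinary least squares; for a large enough `lam` every coefficient is thresholded to zero.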