cstjean / ScikitLearn.jl

Julia implementation of the scikit-learn API https://cstjean.github.io/ScikitLearn.jl/dev/

Passing return_std to predict #117

Closed: evolbio closed this issue 1 year ago

evolbio commented 1 year ago

Various scikit-learn models accept return_std=True when calling predict, for example BayesianRidge (see this example). With a BayesianRidgeRegressor or similar machine, I would like to call y_predict, y_std = predict(machine, X, return_std=true).

I am using MLJ to make calls through ScikitLearn.jl. I have looked through ScikitLearn.jl and MLJScikitLearnInterface.jl and do not see any way to make this work, but maybe I am missing something simple, like the right way to pass additional arguments? Thanks.

evolbio commented 1 year ago

Very helpful response from @tlienart on Discourse, copied here:

You’re not missing anything; there’s currently no way to pass that argument. It might be good to open an issue at MLJScikitLearnInterface to discuss this (and you could paste what follows).

I doubt that MLJ’s predict signature will be adapted to match this one, but I’ll let @ablaom or @samuel_okon discuss that.

What could work is to pass return_std as a new field of BayesianRidgeRegressor, here: MLJScikitLearnInterface.jl/linear-regressors.jl at commit 36882f14321e7e9889aac31447eeed0102eb052f (JuliaAI/MLJScikitLearnInterface.jl on GitHub),

and then pick that up at predict time, here: MLJScikitLearnInterface.jl/macros.jl at commit 36882f14321e7e9889aac31447eeed0102eb052f (JuliaAI/MLJScikitLearnInterface.jl on GitHub).
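A minimal, self-contained sketch of that "store it as a field, read it at predict time" idea. It deliberately uses toy stand-ins rather than the real MLJScikitLearnInterface types: ToyBayesRidge, ToyFitResult, toy_predict and backend_predict are hypothetical names for illustration only, and the constant standard deviation is a placeholder, not a real Bayesian estimate.

```julia
# Hypothetical stand-in for the Python estimator's predict (NOT a real API).
# Returns the mean prediction, plus a per-row standard deviation when return_std=true.
function backend_predict(coef::Vector{Float64}, X::Matrix{Float64}; return_std::Bool=false)
    μ = X * coef
    if return_std
        σ = fill(0.5, size(X, 1))  # placeholder uncertainty, illustration only
        return μ, σ
    end
    return μ
end

# Hypothetical model type: the keyword is stored as a hyperparameter field.
Base.@kwdef struct ToyBayesRidge
    return_std::Bool = false
end

# Hypothetical container for fitted parameters.
struct ToyFitResult
    coef::Vector{Float64}
end

# The predict step reads the field and forwards it as a keyword argument.
toy_predict(model::ToyBayesRidge, fr::ToyFitResult, X::Matrix{Float64}) =
    backend_predict(fr.coef, X; return_std=model.return_std)

# Usage
X = [1.0 2.0; 3.0 4.0]
fr = ToyFitResult([0.5, -1.0])
μ_only = toy_predict(ToyBayesRidge(), fr, X)                # mean only
μ, σ = toy_predict(ToyBayesRidge(return_std=true), fr, X)   # mean and std
```

One trade-off of this design is that the shape of the prediction output then depends on a hyperparameter, which may be part of why it would not fit MLJ's usual predict signature.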

This would also require ScikitLearn.jl to allow passing return_std=true to predict, which might also require opening an issue there. cc @cstjean

cstjean commented 1 year ago

Hi @evolbio. We generally aim to match the scikit-learn Python interface, so supporting return_std=true would be perfectly in line. It might be as simple as adding a ; kwargs... somewhere. It would be a welcome contribution.
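For readers less familiar with the idiom, here is a minimal sketch of what ; kwargs... forwarding looks like in Julia. inner_predict and wrapped_predict are hypothetical names, not ScikitLearn.jl functions; inner_predict plays roughly the role the Python estimator's method would play. The wrapper accepts any keyword arguments and splats them, unchanged, into the call it wraps.

```julia
# Hypothetical inner function standing in for a Python-side predict.
# It understands a keyword (return_std) that the wrapper below never names.
inner_predict(X::Vector{Float64}; return_std::Bool=false) =
    return_std ? (2 .* X, fill(0.1, length(X))) : 2 .* X

# Hypothetical wrapper: `; kwargs...` collects any keyword arguments and
# splats them unchanged into the inner call.
wrapped_predict(X::Vector{Float64}; kwargs...) = inner_predict(X; kwargs...)

ym = wrapped_predict([1.0, 2.0])                        # no keywords: mean only
ym2, ys = wrapped_predict([1.0, 2.0]; return_std=true)  # keyword passes through
```

Because the wrapper never names the keywords itself, it does not need updating when the inner function gains new ones; a keyword the inner function does not accept still raises an error, just at the inner call.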

tlienart commented 1 year ago

This can be closed; it already works (I should have checked). This part of the code:

https://github.com/cstjean/ScikitLearn.jl/blob/9f15da6c1daf49b30ed59c1dd0a5a30ec97ac7ab/src/Skcore.jl#L70-L104

shows that kwargs can already be passed to predict (and to all other API functions) without restrictions (L101-102), and indeed:

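```julia
using ScikitLearn
@sk_import linear_model: BayesianRidge   # imports scikit-learn's linear_model.BayesianRidge

X, y = randn(100, 2), randn(100)
reg = BayesianRidge()
fit!(reg, X, y)
# Extra keyword arguments are forwarded to the Python predict method,
# so this returns both the predictive mean and its standard deviation.
ym, ys = predict(reg, X, return_std=true)
```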

just works. Kudos to @cstjean for designing it generically from the start.