JuliaAI / MLJScikitLearnInterface.jl

MLJ Interface for ScikitLearn.jl

Passing return_std to predict #46

Open · evolbio opened this issue 1 year ago

evolbio commented 1 year ago

Discussion on Discourse and in a ScikitLearn.jl issue suggested raising an issue here. The background and suggestions are copied below:

From me @evolbio: Various scikit-learn models accept return_std=true when calling predict, for example BayesianRidgeRegressor (see this example). With a BayesianRidgeRegressor or similar machine, I would like to call y_predict, y_std = predict(machine, X, return_std=true). I am using MLJ to make calls through ScikitLearn.jl. I have looked through ScikitLearn.jl and MLJScikitLearnInterface.jl and do not see any way to make this work, but maybe I am missing something simple, like the right way to pass additional arguments? Thanks.

And the reply from @tlienart: You're not missing anything; there's currently no way to pass that argument. It might be good to open an issue at MLJScikitLearnInterface to discuss this (and you could paste what follows).

(I doubt that MLJ's predict signature will be adapted to match this one, but I'll let @ablaom or @samuel_okon discuss that.)

What could work is to pass return_std as a new field of BayesianRidgeRegressor here: linear-regressors.jl in JuliaAI/MLJScikitLearnInterface.jl (commit 36882f14321e7e9889aac31447eeed0102eb052f),

then pick that up at predict time here: macros.jl in JuliaAI/MLJScikitLearnInterface.jl (same commit).

This would also require ScikitLearn.jl to allow passing return_std=true to predict, which might also require opening an issue there (cc @cstjean).
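To make the "new field, picked up at predict time" suggestion concrete, here is a minimal pure-Julia sketch. It is not the MLJScikitLearnInterface.jl code and does not call scikit-learn: FittedBayesRidge, BayesianRidgeSketch and predict_sketch are hypothetical names, and the fitted parameters are made up. The predictive standard deviation uses the formula scikit-learn's BayesianRidge applies when return_std=True, sqrt(x' * Σ * x + 1/α), where Σ is the posterior covariance of the weights and α the noise precision.

```julia
# Hypothetical sketch only: not the MLJScikitLearnInterface.jl implementation.

# Stand-in for the learned parameters of a Bayesian ridge fit (hypothetical).
struct FittedBayesRidge
    coef::Vector{Float64}    # posterior mean of the weights
    intercept::Float64
    sigma::Matrix{Float64}   # posterior covariance of the weights
    alpha::Float64           # estimated noise precision
end

# Model hyperparameters carrying the proposed `return_std` field (hypothetical).
Base.@kwdef struct BayesianRidgeSketch
    return_std::Bool = false
end

# The `return_std` field, read at predict time, decides what is returned.
function predict_sketch(model::BayesianRidgeSketch, fitted::FittedBayesRidge, X::AbstractMatrix)
    y_mean = X * fitted.coef .+ fitted.intercept
    model.return_std || return y_mean
    # Per-row predictive variance: x' * sigma * x + 1/alpha
    y_var = vec(sum((X * fitted.sigma) .* X; dims=2)) .+ 1 / fitted.alpha
    return y_mean, sqrt.(y_var)
end

# Made-up fitted parameters and inputs, for illustration only.
fitted = FittedBayesRidge([2.0, -1.0], 0.5, [0.1 0.0; 0.0 0.2], 4.0)
X = [1.0 0.0; 0.0 1.0]

y_only = predict_sketch(BayesianRidgeSketch(), fitted, X)                      # means only
y_predict, y_std = predict_sketch(BayesianRidgeSketch(return_std=true), fitted, X)
```

One drawback of this design is that the return type of predict depends on a hyperparameter value (a vector in one case, a tuple in the other), which sits awkwardly with MLJ's convention that a model's predictions have a single, declared form.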

ablaom commented 1 year ago

@evolbio Thanks for posting. I will take a look into this after Jan 8th. Ping me if I appear to forget.

ablaom commented 1 year ago

@evolbio Yes, @tlienart is correct.

As currently implemented, the MLJ interface for BayesianRidgeRegressor views the model as making deterministic predictions (the model subtypes Deterministic), so there is no way to access the standard deviations. The idiomatic MLJ way to rectify this is to re-implement the interface to regard the model as Probabilistic, in which case predictions would take the form of distributions, specifically of type Distributions.Normal.
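As a rough illustration of what "predictions as distributions" means, the sketch below returns one Distributions.Normal per input row instead of a mean vector, so both quantities can be recovered afterwards. It assumes the Distributions.jl package is installed; predict_normal is a hypothetical helper and the parameters are made up, reusing the BayesianRidge predictive-variance formula x' * Σ * x + 1/α.

```julia
using Distributions  # provides Normal, mean, std

# Hypothetical probabilistic predict: one Normal per row of X.
function predict_normal(coef, intercept, sigma, alpha, X::AbstractMatrix)
    mu = X * coef .+ intercept                                   # predictive means
    sd = sqrt.(vec(sum((X * sigma) .* X; dims=2)) .+ 1 / alpha)  # predictive std devs
    return Normal.(mu, sd)                                       # Vector of Normal distributions
end

# Made-up parameters and inputs, for illustration only.
X = [1.0 0.0; 0.0 1.0]
yhat = predict_normal([2.0, -1.0], 0.5, [0.1 0.0; 0.0 0.2], 4.0, X)

y_predict = mean.(yhat)  # point predictions
y_std = std.(yhat)       # standard deviations
```

If the MLJ interface were changed this way, a user could get the same two quantities from a machine with predict_mean(mach, X) and std.(predict(mach, X)), without any change to MLJ's predict signature.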

I will post an issue, which will be cross-referenced below, to initiate such a change. In the meantime, perhaps the pure-Julia package SossMLJ.jl addresses your needs?

evolbio commented 1 year ago

@ablaom Thanks very much for following up. I made a direct call to ScikitLearn, which worked well enough for what I needed. I agree that in the long run there are some Deterministic models which would gain from returning Probabilistic predictions and so look forward to eventual modifications.