cstjean / ScikitLearn.jl

Julia implementation of the scikit-learn API https://cstjean.github.io/ScikitLearn.jl/dev/

Cannot pass scoring to cross_val_score #35

Open kavir1698 opened 7 years ago

kavir1698 commented 7 years ago

The following code gives the error below:

using ScikitLearn
using RDatasets

@sk_import linear_model: LinearRegression
using ScikitLearn.CrossValidation: cross_val_score

boston = dataset("MASS", "Boston")

X = Matrix(boston[1:13])
y = Array(boston[:MedV])

lr = LinearRegression()
scores = cross_val_score(lr, X, y,  scoring="r2")

ERROR: ArgumentError: r2 is not a valid scoring value. Valid options are Symbol[:mean_squared_error]
Stacktrace:
 [1] get_scorer(::Symbol) at /home/.julia/v0.6/ScikitLearn/src/scorer.jl:65
 [2] #check_scoring#95(::Bool, ::Function, ::PyCall.PyObject, ::String) at /home/.julia/v0.6/ScikitLearn/src/cross_validation.jl:432
 [3] #cross_val_score#83(::String, ::Int64, ::Int64, ::Int64, ::Void, ::Function, ::PyCall.PyObject, ::Array{Real,2}, ::Array{Float64,1}) at /home/.julia/v0.6/ScikitLearn/src/cross_validation.jl:276
 [4] (::ScikitLearn.Skcore.#kw##cross_val_score)(::Array{Any,1}, ::ScikitLearn.Skcore.#cross_val_score, ::PyCall.PyObject, ::Array{Real,2}, ::Array{Float64,1}) at ./<missing>:0

Every other scoring value gives the same error; only "mean_squared_error" is accepted. The same code works in Python.

cstjean commented 7 years ago

Thank you for the report. I seem to have neglected porting that code over. Would you like to make a PR? It should be a straightforward modification of scorer.jl: you'd have to translate this code from Python. I'm happy to help out with the details if you're interested.

If not, this may work as a workaround until that function is implemented:

using ScikitLearn
@sk_import metrics: r2_score

r2_scorer = ScikitLearn.Skcore.make_scorer(r2_score)
cross_val_score(lr, X, y,  scoring=r2_scorer)
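
For anyone picking up the port, here is a rough pure-Julia sketch of the two ideas involved: wrapping a metric of the form `metric(y_true, y_pred)` into a scorer, and looking a scorer up by name. It is not the code in scorer.jl. The names `wrap_metric`, `SCORERS` and `lookup_scorer` are hypothetical, and the scorer here takes a plain `predict` function instead of a fitted estimator so that the example runs without PyCall. It assumes Julia 1.x (hence `using Statistics`).

```julia
using Statistics: mean

# Coefficient of determination: 1 - SS_res / SS_tot.
function r2(y_true::AbstractVector{<:Real}, y_pred::AbstractVector{<:Real})
    ss_res = sum(abs2, y_true .- y_pred)
    ss_tot = sum(abs2, y_true .- mean(y_true))
    return 1 - ss_res / ss_tot
end

# Mean squared error between targets and predictions.
mse(y_true, y_pred) = mean(abs2, y_true .- y_pred)

# Hypothetical stand-in for make_scorer: turns metric(y_true, y_pred) into
# scorer(predict, X, y). Losses are negated so that a higher score is
# always better, which is the convention scikit-learn scorers follow.
function wrap_metric(metric; greater_is_better::Bool = true)
    factor = greater_is_better ? 1 : -1
    return (predict, X, y) -> factor * metric(y, predict(X))
end

# Hypothetical name-to-scorer table; a real port would list more entries.
const SCORERS = Dict(
    :r2 => wrap_metric(r2),
    :mean_squared_error => wrap_metric(mse; greater_is_better = false),
)

# Accepts a Symbol or a String, like the `scoring` keyword in the report.
function lookup_scorer(name::Union{Symbol,AbstractString})
    key = Symbol(name)
    haskey(SCORERS, key) || throw(ArgumentError(
        "$name is not a valid scoring value. " *
        "Valid options are $(sort(collect(keys(SCORERS))))"))
    return SCORERS[key]
end

# Toy usage: a fixed linear "model" standing in for a fitted estimator.
predict(X) = X * [2.0]
X = reshape([1.0, 2.0, 3.0, 4.0], :, 1)
y = [2.0, 4.0, 6.0, 9.0]

r2_value = lookup_scorer("r2")(predict, X, y)
neg_mse = lookup_scorer(:mean_squared_error)(predict, X, y)
```

With a table like this, supporting another scoring name comes down to adding one more entry, which is roughly the shape of change the port would need.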

kavir1698 commented 7 years ago

Yes, I will give it a shot.

cstjean commented 7 years ago

This may be useful.