cstjean / ScikitLearn.jl

Julia implementation of the scikit-learn API https://cstjean.github.io/ScikitLearn.jl/dev/

Add test for r2 score in cross_val_score. Pull request/2e09aad8. #40

Open kavir1698 opened 6 years ago

cstjean commented 6 years ago

Thank you. Sorry I haven't had time to look at this in detail yet; I will when I can. It's a bummer that Travis is not working with ScikitLearn.jl at the moment.

kavir1698 commented 6 years ago

No rush, thank you for the good work.

cstjean commented 6 years ago

Sorry for the delay. Travis is not working for this package right now, so I had to test this PR locally. To complicate matters further, the tests only complete with the latest (master) versions of some packages, notably DataFrames and LowRankModels; use Pkg.checkout to get them. To run the tests, use Pkg.test("ScikitLearn"). Make sure that DataFrames.jl is at a version above 0.11.0.

When I test your pull request, I get this error:

ERROR: LoadError: MethodError: no method matching mean(::Array{Union{Float64, Missings.Missing},1}, ::StatsBase.Weights{Float64,Float64,Array{Float64,1}}, ::Int64)
Closest candidates are:
  mean(::AbstractArray{T<:Number,N} where N, ::StatsBase.AbstractWeights{W<:Real,T,V} where V<:AbstractArray{T,1} where T<:Real, ::Int64) where {T<:Number, W<:Real} at /Users/cedric/.julia/v0.6/StatsBase/src/weights.jl:467
  mean(::AbstractArray, ::StatsBase.AbstractWeights) at /Users/cedric/.julia/v0.6/StatsBase/src/weights.jl:453
  mean(::AbstractArray{T,N} where N, ::Any) where T at statistics.jl:57
  ...
Stacktrace:
 [1] #r2_score#58(::Void, ::String, ::Function, ::Array{Union{Float64, Missings.Missing},1}, ::Array{Float64,1}) at /Users/cedric/.julia/v0.6/ScikitLearn/src/scorer.jl:91
 [2] r2_score(::Array{Union{Float64, Missings.Missing},1}, ::Array{Float64,1}) at /Users/cedric/.julia/v0.6/ScikitLearn/src/scorer.jl:79
 [3] #call#67(::Void, ::ScikitLearn.Skcore.PredictScorer, ::PyCall.PyObject, ::Array{Union{Float64, Missings.Missing},2}, ::Array{Union{Float64, Missings.Missing},1}) at /Users/cedric/.julia/v0.6/ScikitLearn/src/scorer.jl:166
 [4] _score(::PyCall.PyObject, ::Array{Union{Float64, Missings.Missing},2}, ::Array{Union{Float64, Missings.Missing},1}, ::ScikitLearn.Skcore.PredictScorer) at /Users/cedric/.julia/v0.6/ScikitLearn/src/cross_validation.jl:651
 [5] #_fit_and_score#105(::Bool, ::Bool, ::String, ::Function, ::PyCall.PyObject, ::Array{Union{Float64, Missings.Missing},2}, ::Array{Union{Float64, Missings.Missing},1}, ::ScikitLearn.Skcore.PredictScorer, ::Array{Int64,1}, ::Array{Int64,1}, ::Int64, ::Void, ::Void) at /Users/cedric/.julia/v0.6/ScikitLearn/src/cross_validation.jl:574
 [6] (::ScikitLearn.Skcore.##93#94{Int64,Void,PyCall.PyObject,Array{Union{Float64, Missings.Missing},2},Array{Union{Float64, Missings.Missing},1}})(::Tuple{Array{Int64,1},Array{Int64,1}}) at ./<missing>:0
 [7] copy!(::Array{Float64,1}, ::Base.Generator{Array{Tuple{Array{Int64,1},Array{Int64,1}},1},ScikitLearn.Skcore.##93#94{Int64,Void,PyCall.PyObject,Array{Union{Float64, Missings.Missing},2},Array{Union{Float64, Missings.Missing},1}}}) at ./abstractarray.jl:572
 [8] _collect(::Type{Float64}, ::Base.Generator{Array{Tuple{Array{Int64,1},Array{Int64,1}},1},ScikitLearn.Skcore.##93#94{Int64,Void,PyCall.PyObject,Array{Union{Float64, Missings.Missing},2},Array{Union{Float64, Missings.Missing},1}}}, ::Base.HasShape) at ./array.jl:363
 [9] #cross_val_score#92(::String, ::Int64, ::Int64, ::Int64, ::Void, ::Function, ::PyCall.PyObject, ::Array{Union{Float64, Missings.Missing},2}, ::Array{Union{Float64, Missings.Missing},1}) at /Users/cedric/.julia/v0.6/ScikitLearn/src/cross_validation.jl:279
 [10] (::ScikitLearn.Skcore.#kw##cross_val_score)(::Array{Any,1}, ::ScikitLearn.Skcore.#cross_val_score, ::PyCall.PyObject, ::Array{Union{Float64, Missings.Missing},2}, ::Array{Union{Float64, Missings.Missing},1}) at ./<missing>:0
 [11] test_corss_val_r2_score() at /Users/cedric/.julia/v0.6/ScikitLearn/test/test_crossvalidation.jl:234
 [12] all_test_crossvalidation() at /Users/cedric/.julia/v0.6/ScikitLearn/test/test_crossvalidation.jl:243
 [13] include_from_node1(::String) at ./loading.jl:569
 [14] include(::String) at ./sysimg.jl:14
 [15] process_options(::Base.JLOptions) at ./client.jl:305
 [16] _start() at ./client.jl:371
while loading /Users/cedric/.julia/v0.6/ScikitLearn/test/runtests.jl, in expression starting on line 19

Could you please fix it?
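The candidate list in the error points at the cause: the weighted mean method that r2_score calls is defined for AbstractArray{T} with T<:Number, and Union{Float64, Missings.Missing} is not a subtype of Number, so dispatch fails even when no value is actually missing. The sketch below illustrates one possible workaround, narrowing the element type before scoring. It assumes Julia 1.x (where Missing is part of Base rather than Missings.jl), and r2 is a hypothetical, unweighted helper, not ScikitLearn.jl's own r2_score.

```julia
# Sketch assuming Julia 1.x, where `Missing` is in Base (the thread used
# Julia 0.6 with Missings.jl). `r2` is a hypothetical, unweighted helper,
# not ScikitLearn.jl's r2_score (which also handles sample weights).

# Coefficient of determination: R² = 1 - SS_res / SS_tot.
function r2(y_true::AbstractVector{<:Real}, y_pred::AbstractVector{<:Real})
    length(y_true) == length(y_pred) || throw(DimensionMismatch("lengths differ"))
    ybar = sum(y_true) / length(y_true)
    ss_res = sum(abs2, y_true .- y_pred)   # residual sum of squares
    ss_tot = sum(abs2, y_true .- ybar)     # total sum of squares
    return 1 - ss_res / ss_tot
end

# Same element type as in the stack trace, with no value actually missing.
y_true = Union{Float64,Missing}[1.0, 2.0, 3.0, 4.0]
y_pred = [1.1, 1.9, 3.2, 3.8]

# The `Missing` in the element type alone defeats dispatch on `<:Real`.
@assert !(eltype(y_true) <: Real)
@assert !applicable(r2, y_true, y_pred)

# Narrow the element type first; this throws if any entry really is `missing`.
y_clean = convert(Vector{Float64}, y_true)
score = r2(y_clean, y_pred)
```

If the data can contain real missing values, the convert call throws, and the affected rows would need to be dropped from both X and y before calling cross_val_score.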

cstjean commented 6 years ago

You may also get an obscure error from PyPlot/PyCall when running the tests. There's no fix for it at the moment; just ignore it.