cstjean / ScikitLearn.jl

Julia implementation of the scikit-learn API https://cstjean.github.io/ScikitLearn.jl/dev/

ScikitLearn.fit!() issue with y_train decimal points #112

Closed: CharisKomos closed this issue 2 years ago

CharisKomos commented 2 years ago

Hello, please consider the code below, in which I am trying to fit a ScikitLearn.jl model. Here I use RandomForestClassifier, but the same issue appears with LogisticRegression and DecisionTreeClassifier (the only models I have tried so far).

I manually construct the datasets for the sake of demonstrating the issue.

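using ScikitLearn
@sk_import ensemble: RandomForestClassifier

model_obj = RandomForestClassifier();
train = rand(1:0.01:100, 20, 5)   # 20×5 matrix of values with two decimal places
x_train = train[1:15, 1:4]
y_train = train[1:15, 5]          # Float64 labels, almost all with a fractional part
x_test  = train[16:20, 1:4]
y_test  = train[16:20, 5]

ScikitLearn.fit!(model_obj, x_train, y_train)   # fails: labels have decimals
y_train = round.(y_train)
ScikitLearn.fit!(model_obj, x_train, y_train)   # works: every label is a whole number

y_train[1] += 1.000000
ScikitLearn.fit!(model_obj, x_train, y_train)   # works: labels are still whole numbers

y_train[1] += 1.0000001
ScikitLearn.fit!(model_obj, x_train, y_train)   # fails: "Unknown label type: 'continuous'"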

When y_train is a Vector{Int64}, a Vector{Float64} whose values are all whole numbers, or even a Vector{Any} that mixes floats with strings, the fit succeeds and the function works well. However, if even a single value has a nonzero fractional part, an error is thrown (the full message is below).
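The traceback below ends in scikit-learn's `check_classification_targets`, which asks `sklearn.utils.multiclass.type_of_target` what kind of labels it received and rejects anything that is not a class label. As a rough illustration only, the hypothetical pure-Julia helper `label_kind` below approximates the relevant part of that check: a float-typed label vector containing any non-whole value is treated as "continuous", while non-float labels are treated as discrete classes. This is a simplified sketch under that assumption, not scikit-learn's actual implementation.

```julia
# Hypothetical helper: a simplified approximation of the float check in
# scikit-learn's `type_of_target`. Not the real implementation.
function label_kind(y::AbstractVector)
    # Only float-typed vectors can be reported as "continuous",
    # and only if at least one value has a fractional part.
    if eltype(y) <: AbstractFloat && any(v -> !isinteger(v), y)
        return "continuous"
    end
    # Everything else is treated as discrete class labels.
    return length(unique(y)) <= 2 ? "binary" : "multiclass"
end

y = [3.0, 7.0, 12.0]
label_kind(y)                    # "multiclass": every value is a whole number
y[1] += 1.0000001
label_kind(y)                    # "continuous": 4.0000001 has a fractional part
label_kind(Any[1.5, "a", 2.0])   # "multiclass": eltype is Any, so the float check is skipped
```

The last call loosely mirrors the observation that a Vector{Any} mixing floats and strings is accepted: once strings are mixed in, the labels presumably no longer reach scikit-learn as a float array, so the "continuous" branch never applies.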


The full error message is:

ERROR: PyError ($(Expr(:escape, :(ccall(#= C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) <class 'ValueError'>
ValueError("Unknown label type: 'continuous'")
  File "C:\Users\C.Komodromos\.julia\Conda_env\lib\site-packages\sklearn\ensemble\_forest.py", line 367, in fit
    y, expanded_class_weight = self._validate_y_class_weight(y)
  File "C:\Users\C.Komodromos\.julia\Conda_env\lib\site-packages\sklearn\ensemble\_forest.py", line 734, in _validate_y_class_weight
    check_classification_targets(y)
  File "C:\Users\C.Komodromos\.julia\Conda_env\lib\site-packages\sklearn\utils\multiclass.py", line 197, in check_classification_targets
    raise ValueError("Unknown label type: %r" % y_type)

Stacktrace:
 [1] pyerr_check at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\exception.jl:62 [inlined]
 [2] pyerr_check at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\exception.jl:66 [inlined]
 [3] _handle_error(::String) at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\exception.jl:83
 [4] macro expansion at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\exception.jl:97 [inlined]
 [5] #109 at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\pyfncall.jl:43 [inlined]
 [6] disable_sigint at .\c.jl:446 [inlined]
 [7] __pycall! at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\pyfncall.jl:42 [inlined]
 [8] _pycall!(::PyCall.PyObject, ::PyCall.PyObject, ::Tuple{Array{Float64,2},Array{Float64,1}}, ::Int64, ::Ptr{Nothing}) at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\pyfncall.jl:29
 [9] _pycall!(::PyCall.PyObject, ::PyCall.PyObject, ::Tuple{Array{Float64,2},Array{Float64,1}}, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\pyfncall.jl:11
 [10] (::PyCall.PyObject)(::Array{Float64,2}, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\pyfncall.jl:86
 [11] (::PyCall.PyObject)(::Array{Float64,2}, ::Vararg{Any,N} where N) at C:\Users\C.Komodromos\.julia\packages\PyCall\tqyST\src\pyfncall.jl:86
 [12] fit!(::PyCall.PyObject, ::Array{Float64,2}, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\C.Komodromos\.julia\packages\ScikitLearn\NJwUf\src\Skcore.jl:102
 [13] fit!(::PyCall.PyObject, ::Array{Float64,2}, ::Array{Float64,1}) at C:\Users\C.Komodromos\.julia\packages\ScikitLearn\NJwUf\src\Skcore.jl:102
 [14] top-level scope at none:1

What am I missing here? Any help is appreciated.
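One possible way forward, sketched under assumptions: classifiers only accept discrete labels, so if the fifth column is really a category, the labels can be converted to a discrete type before calling `fit!`. The snippet below uses made-up data and hypothetical variable names (`y_str`, `classes`, `code`, `y_int`) to show two pure-Julia conversions; it does not call ScikitLearn.jl itself. If instead the target is a real-valued quantity, as the `rand(1:0.01:100, ...)` data suggests, a regressor such as scikit-learn's `RandomForestRegressor` (imported with `@sk_import ensemble: RandomForestRegressor`) is likely the more natural choice.

```julia
# Hypothetical example labels with decimals.
y_train = [23.45, 7.0, 23.45, 88.1]

# Option 1: keep the exact values, but pass them as strings,
# so no float array reaches scikit-learn.
y_str = string.(y_train)

# Option 2: map each distinct value to an integer class code.
classes = sort(unique(y_train))
code = Dict(v => i for (i, v) in enumerate(classes))
y_int = [code[v] for v in y_train]   # [2, 1, 2, 3]
```

With data like the issue's, nearly every label would become its own class, which is another hint that regression, not classification, may be the intended task.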