cstjean / ScikitLearn.jl

Julia implementation of the scikit-learn API https://cstjean.github.io/ScikitLearn.jl/dev/

bug in `StratifiedKFold` when `shuffle=true` #109

Open SimonEnsemble opened 2 years ago

SimonEnsemble commented 2 years ago

ScikitLearn v0.6.4

Julia Version 1.7.1
Commit ac5cc99908 (2021-12-22 19:35 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i5-9400 CPU @ 2.90GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)

using ScikitLearn.CrossValidation: train_test_split, StratifiedKFold, KFold

kf = KFold(10, n_folds=5, shuffle=true) # all is well

skf = StratifiedKFold([rand([-1, 1]) for i = 1:20], n_folds=5) # all is well

skf = StratifiedKFold([rand([-1, 1]) for i = 1:20], n_folds=5, shuffle=true) # error below
ERROR: ArgumentError: Random._GLOBAL_RNG() cannot be used to seed a MersenneTwister
Stacktrace:
 [1] check_random_state(seed::Random._GLOBAL_RNG)
   @ ScikitLearn.Skcore ~/.julia/packages/ScikitLearn/ssekP/src/cross_validation.jl:19
 [2] KFold(n::Int64; n_folds::Int64, shuffle::Bool, random_state::Random._GLOBAL_RNG)
   @ ScikitLearn.Skcore ~/.julia/packages/ScikitLearn/ssekP/src/cross_validation.jl:120
 [3] (::ScikitLearn.Skcore.var"#74#76"{Int64, Bool})(c::Int64)
   @ ScikitLearn.Skcore ./none:0
 [4] iterate
   @ ./generator.jl:47 [inlined]
 [5] collect(itr::Base.Generator{Vector{Int64}, ScikitLearn.Skcore.var"#74#76"{Int64, Bool}})
   @ Base ./array.jl:724
 [6] StratifiedKFold(y::Vector{Int64}; n_folds::Int64, shuffle::Bool, random_state::Nothing)
   @ ScikitLearn.Skcore ~/.julia/packages/ScikitLearn/ssekP/src/cross_validation.jl:160
 [7] top-level scope
   @ REPL[20]:1