rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License

[Bug] MLJ interfaces do not work for generic tables #78

Open ablaom opened 6 months ago

ablaom commented 6 months ago

I have updated MLJTestInterface.jl to test models on row-based tables and discovered this bug:

julia> using SIRUS, MLJTestInterface, MLJBase

julia> X, y = MLJTestInterface.make_regression(row_table=true);

julia> machine(SIRUS.StableRulesRegressor(), X, y) |> fit!
[ Info: Training machine(StableRulesRegressor(rng = Random.TaskLocalRNG(), …), …).
┌ Error: Problem fitting the machine machine(StableRulesRegressor(rng = Random.TaskLocalRNG(), …), …). 
└ @ MLJBase ~/.julia/packages/MLJBase/fEiP2/src/machines.jl:682
[ Info: Running type checks... 
[ Info: Type checks okay. 
ERROR: TaskFailedException

    nested task error: BoundsError: attempt to access 0-element Vector{String} at index [2]
    Stacktrace:
     [1] getindex
       @ ./essentials.jl:13 [inlined]
     [2] _split(rng::Random.Xoshiro, algo::SIRUS.Regression, X::Matrix{…}, y::SubArray{…}, classes::Vector{…}, colnms::Vector{…}, cps::Vector{…}; max_split_candidates::Int64)         
       @ SIRUS ~/.julia/packages/SIRUS/6Paa4/src/forest.jl:133
     [3] _split
       @ ~/.julia/packages/SIRUS/6Paa4/src/forest.jl:91 [inlined]
     [4] _tree!(rng::Random.Xoshiro, algo::SIRUS.Regression, mask::Vector{…}, X::Matrix{…}, y::SubArray{…}, classes::Vector{…}, colnms::Vector{…}; max_split_candidates::Int64, depth::Int64, max_depth::Int64, q::Int64, cps::Vector{…}, min_data_in_leaf::Int64)
       @ SIRUS ~/.julia/packages/SIRUS/6Paa4/src/forest.jl:198
     [5] _tree!
       @ ~/.julia/packages/SIRUS/6Paa4/src/forest.jl:176 [inlined]
     [6] macro expansion
       @ ~/.julia/packages/SIRUS/6Paa4/src/forest.jl:334 [inlined]
     [7] (::SIRUS.var"#33#threadsfor_fun#7"{SIRUS.var"#33#threadsfor_fun#5#8"{…}})(tid::Int64; onethread::Bool)                                                                     
       @ SIRUS ./threadingconstructs.jl:214
     [8] #33#threadsfor_fun
       @ SIRUS ./threadingconstructs.jl:181 [inlined]
     [9] (::Base.Threads.var"#1#2"{SIRUS.var"#33#threadsfor_fun#7"{SIRUS.var"#33#threadsfor_fun#5#8"{…}}, Int64})()
       @ Base.Threads ./threadingconstructs.jl:153

...and 11 more exceptions.

There is a similar bug for StableForestRegressor and probably for all SIRUS models.
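
The BoundsError on a 0-element Vector{String} suggests, though the trace alone does not confirm it, that the column names come back empty when X is a row table. As a minimal sketch of one possible direction, the Tables.jl API can normalize any Tables-compatible input to a column table before names and data are extracted. The helper `names_and_matrix` below is hypothetical and not part of SIRUS.jl:

```julia
using Tables

# A column table (NamedTuple of vectors) and the equivalent row table
# (Vector of NamedTuples), standing in for the two input kinds.
coltable = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
rowtable = Tables.rowtable(coltable)

# Hypothetical helper (an assumption for illustration, not SIRUS.jl code):
# convert any Tables.jl-compatible input to a column table first, then
# read column names and the data matrix from that common representation.
function names_and_matrix(X)
    cols = Tables.columntable(X)
    colnms = [String(nm) for nm in Tables.columnnames(cols)]
    return colnms, Tables.matrix(cols)
end

# Both input kinds should yield the same names and the same matrix.
@assert names_and_matrix(coltable) == names_and_matrix(rowtable)
```

Whether this matches how SIRUS.jl reads its inputs internally is an assumption; the point is only that going through `Tables.columntable` gives row and column tables the same treatment.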

You may want to start by duplicating tests like this one, adding the option row_table=true to make_regression(), so that the test suite catches the bug.
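
A rough sketch of what such a duplicated test might look like, modeled on the usage documented for MLJTestInterface.jl. Where it would live in the SIRUS.jl test suite, and which models to list, are assumptions; the test is expected to fail until the bug is fixed:

```julia
using Test
using MLJTestInterface
using SIRUS

# Row-table regression data, generated as in the report above.
X, y = MLJTestInterface.make_regression(row_table=true)

# Run the generic MLJ interface tests on the two regressors named in
# this issue. The choice of models here is an assumption.
failures, summary = MLJTestInterface.test(
    [SIRUS.StableRulesRegressor, SIRUS.StableForestRegressor],
    X, y;
    mod=@__MODULE__,
    verbosity=0,  # bump to debug
    throw=false,  # set to `true` to get the full stack trace
)

@test isempty(failures)
```

Setting throw=true while debugging should surface the underlying error directly instead of collecting it in `failures`.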

rikhuijzer commented 6 months ago

Thanks for reporting this, Anthony! I have other, higher-priority things to do in the near future, but I may fix this later.