rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License

Models do not handle missing classes on subsampling #43

Open ablaom opened 1 year ago

ablaom commented 1 year ago

When running MLJ integration tests:

```julia
using MLJTestIntegration
import SIRUS

X, y = MLJTestIntegration.make_binary()
MLJTestIntegration.test(
    [SIRUS.StableForestClassifier,],
    X,
    y,
    mod = @__MODULE__,
    level=4,
    throw=true,
)
```
```
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
[ Info: Converting outcome classes ["B"] to [0.0].
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
ERROR: DomainError with Value B not in pool. :

Stacktrace:
 [1] attempt(f::MLJTestIntegration.var"#5#8"{MLJBase.LogLoss{Float64}, SIRUS.MLJImplementation.StableForestClassifier, Vector{ComputationalResources.CPU1{Nothing}}, Tuple{NamedTuple{(:FL, :RW), Tuple{Vector{Float64}, Vector{Float64}}}, CategoricalArrays.CategoricalVector{String, UInt32, String, CategoricalArrays.CategoricalValue{String, UInt32}, Union{}}}}, message::String; throw::Bool)
   @ MLJTestInterface ~/.julia/packages/MLJTestInterface/K1YSy/src/attemptors.jl:17
 [2] evaluation(::MLJBase.LogLoss{Float64}, ::SIRUS.MLJImplementation.StableForestClassifier, ::Vector{ComputationalResources.CPU1{Nothing}}, ::NamedTuple{(:FL, :RW), Tuple{Vector{Float64}, Vector{Float64}}}, ::Vararg{Any}; throw::Bool, verbosity::Int64)
   @ MLJTestIntegration ~/.julia/packages/MLJTestIntegration/J5lEw/src/attemptors.jl:24
 [3] test(::Vector{DataType}, ::NamedTuple{(:FL, :RW), Tuple{Vector{Float64}, Vector{Float64}}}, ::Vararg{Any}; mod::Module, level::Int64, throw::Bool, verbosity::Int64)
   @ MLJTestIntegration ~/.julia/packages/MLJTestIntegration/J5lEw/src/test.jl:312
 [4] top-level scope
   @ REPL[43]:1

caused by: DomainError with Value B not in pool. :
```

The failing test is a cross-validation test in which some folds are presumably missing one of the two classes present in the test set, which is very small. Ideally, a model should handle this eventuality. For example, in the binary case tested here, the prediction could always be the single class present in the training target.
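The situation can be reproduced in isolation with CategoricalArrays.jl. The sketch below uses a made-up four-element target (not the data returned by `make_binary`) and illustrates that a subsample may contain only one observed class while its pool still records both, so the full class list should remain recoverable from the training target:

```julia
using CategoricalArrays

# Hypothetical tiny target, for illustration only.
y = categorical(["B", "O", "B", "B"])

# A training fold that happens to miss class "O".
train = y[[1, 3, 4]]

observed = unique(train)  # classes that actually occur in the fold
pooled = levels(train)    # classes recorded in the pool, inherited from `y`

println("number of observed classes: ", length(observed))
println("number of pooled classes:   ", length(pooled))
```

Indexing a categorical vector keeps the original pool unless levels are dropped explicitly (for example with `droplevels!`), which is why the two counts can differ.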

The StableRulesClassifier has the same issue.

I admit this is a bit of a small-data corner case, but it would be great to address. I will need to remove SIRUS classifiers from MLJTestIntegration tests, pending resolution of this issue. All other MLJ classifiers do handle this corner case (outside of our ScikitLearn models).

rikhuijzer commented 1 year ago

Thanks for opening the issue. It is indeed a weak point, which I haven't figured out how to solve yet.

> Ideally, a model should handle this eventuality.

So yes, for the binary case I could switch to returning a single class. Do you also know how I can figure out which classes to use across multiple folds? I'll try to look around in other packages and add another comment here if I find a solution.

Related to https://github.com/rikhuijzer/SIRUS.jl/issues/25.

EDIT: This problem is likely caused by how the UnivariateFinite is constructed in src/forest.jl. It should (re)use the right pool, or follow an existing implementation such as MLJXGBoostInterface as an example.
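As a rough illustration of reusing the pool, and not the code in src/forest.jl, the sketch below builds a `UnivariateFinite` from a categorical element taken from the training fold. It assumes CategoricalDistributions.jl (which provides the `UnivariateFinite` that MLJBase re-exports) and a made-up target. Because the element carries the full pool, the unseen class is expected to get probability zero instead of triggering the `DomainError`:

```julia
using CategoricalArrays
using CategoricalDistributions  # assumed installed; provides UnivariateFinite and pdf

# Hypothetical target and fold, for illustration only.
y = categorical(["B", "O", "B", "B"])
train = y[[1, 3, 4]]  # only "B" occurs, but the pool still holds "B" and "O"

# A categorical element from the fold; it carries the full pool with it.
seen = train[1]

# The support is the single observed class; the pool is inferred from `seen`.
d = UnivariateFinite([seen], [1.0])

p_B = pdf(d, "B")  # probability of the observed class
p_O = pdf(d, "O")  # "O" is in the pool but not in the support

println("p(B) = ", p_B, ", p(O) = ", p_O)
```

If the support were built from raw strings without a pool, the distribution would know only about "B", which is one plausible way to end up with a "not in pool" error when scoring against the other class.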