rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License

`StableRulesRegressor` and `StableForestRegressor` models should be `<: Deterministic` #40

Closed: OkonSamuel closed this issue 1 year ago

OkonSamuel commented 1 year ago

@rikhuijzer The `StableRulesRegressor` and `StableForestRegressor` models should subtype `Deterministic`, not `Probabilistic`, because calling `predict` on these models returns point predictions (a `Vector{Float64}`), not probability distributions, as the code below shows.

```julia
julia> using MLJ, SIRUS

julia> X1, y1 = table(rand(200,10)), rand(200);

julia> mach1 = machine(StableRulesRegressor(), X1, y1);

julia> mach2 = machine(StableForestRegressor(), X1, y1);

julia> fit!(mach1);
[ Info: Training machine(StableRulesRegressor(rng = Random.TaskLocalRNG(), …), …).

julia> fit!(mach2);
[ Info: Training machine(StableForestRegressor(rng = Random.TaskLocalRNG(), …), …).

julia> predict(mach1, X1)
200-element Vector{Float64}:
 0.4280651626586914
 0.4280651626586914
 ⋮
 0.4280651626586914
 0.4964567642211914
 0.5084392166137696

julia> predict(mach2, X1)
200-element Vector{Float64}:
 0.45261388228835897
 0.47241962903510204
 0.45548748878311535
 0.4736496795216281
 ⋮
 0.49846966238407786
 0.47304117166423704
 0.5132681542982771
 0.5329223010581148
```

The relevant code lines are https://github.com/rikhuijzer/SIRUS.jl/blob/main/src/mlj.jl#L69 and https://github.com/rikhuijzer/SIRUS.jl/blob/main/src/mlj.jl#L78.
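
A minimal sketch of the argument above, in plain Julia without MLJ: a model's abstract supertype declares the form of its predictions, so a regressor whose `predict` returns a `Vector{Float64}` belongs under `Deterministic`. The `Model`, `Deterministic`, and `Probabilistic` types below are simplified stand-ins for the MLJModelInterface types of the same name (an assumption, not the real definitions), and `prediction_matches`, `ToyRulesRegressor`, and `MislabeledRegressor` are hypothetical names used only for illustration.

```julia
# Simplified stand-ins for MLJModelInterface's abstract model types
# (assumption: the real definitions carry more structure than this).
abstract type Model end
abstract type Deterministic <: Model end
abstract type Probabilistic <: Model end

# Hypothetical check: do the predictions have the form the model's supertype promises?
# Deterministic models promise point predictions (plain numbers for a regressor).
prediction_matches(::Deterministic, yhat::AbstractVector) = eltype(yhat) <: Real
# Probabilistic models promise distributions, so plain numbers do not qualify.
prediction_matches(::Probabilistic, yhat::AbstractVector) = !(eltype(yhat) <: Real)

# Hypothetical toy regressors: both return plain Float64 values,
# like the SIRUS regressors in the transcript above.
struct ToyRulesRegressor <: Deterministic end
struct MislabeledRegressor <: Probabilistic end

# A few of the point predictions printed by `predict(mach1, X1)` above.
yhat = [0.4280651626586914, 0.4964567642211914, 0.5084392166137696]

println(prediction_matches(ToyRulesRegressor(), yhat))   # true
println(prediction_matches(MislabeledRegressor(), yhat)) # false
```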

rikhuijzer commented 1 year ago

I tried to update it, but the AUC evaluations fail because that measure only supports probabilistic predictions. Maybe we should keep the SIRUS models `Probabilistic`? They all depend on randomness inside the forests, so they are not fully deterministic (they are when an `rng` is fixed, but not otherwise).
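
The point about randomness can be illustrated with Julia's standard library alone. The sketch below uses a hypothetical `toy_ensemble_predict` (not SIRUS code) that averages bootstrap means, a crude stand-in for a randomized forest. With the same seeded `MersenneTwister`, two runs give identical results, and the output is always a single `Float64`, i.e. a point prediction. In MLJ's terminology, the `Deterministic`/`Probabilistic` distinction refers to the form of what `predict` returns (point values versus distributions) rather than to randomness during training.

```julia
using Random
using Statistics

# Hypothetical toy ensemble (not SIRUS code): each member predicts the mean of a
# bootstrap resample of `y`, and the ensemble averages the members.
function toy_ensemble_predict(rng::AbstractRNG, y::AbstractVector{<:Real}; n_members::Int=10)
    n = length(y)
    member_means = [mean(y[rand(rng, 1:n, n)]) for _ in 1:n_members]
    return mean(member_means)  # one point prediction, not a distribution
end

y = collect(1.0:10.0)

# Same seed, same result: fixing the rng makes the randomized fit reproducible.
a = toy_ensemble_predict(MersenneTwister(42), y)
b = toy_ensemble_predict(MersenneTwister(42), y)
println(a == b)  # true

# A different seed generally gives a different, but still point-valued, result.
c = toy_ensemble_predict(MersenneTwister(7), y)
println(typeof(c))  # Float64
```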

OkonSamuel commented 1 year ago

I'm confused about why you need AUC here. Is it used when fitting the model? Could you shed more light on this? Maybe we could wrap the output in a simple distribution wrapper. @ablaom, what do you think?
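
One way to read the wrapper suggestion: wrap each point prediction in a degenerate distribution that puts all of its mass on that value, so the output has the form of a distribution while its mean is still the original number. The `PointMass` type, `prob_at`, and `wrap_predictions` below are hypothetical and written in plain Julia; the Distributions.jl package offers a `Dirac` distribution for the same idea, which this sketch does not use.

```julia
using Statistics

# Hypothetical degenerate distribution: all probability mass sits at `value`.
struct PointMass{T<:Real}
    value::T
end

# Its mean is simply the wrapped value.
Statistics.mean(d::PointMass) = d.value

# Hypothetical probability of observing exactly `x`: 1 at the atom, 0 elsewhere.
prob_at(d::PointMass, x::Real) = x == d.value ? 1.0 : 0.0

# Wrap every point prediction of a regressor in a PointMass.
wrap_predictions(yhat::AbstractVector{<:Real}) = PointMass.(yhat)

yhat = [0.45261388228835897, 0.47241962903510204]
dists = wrap_predictions(yhat)

println(mean.(dists) == yhat)        # true
println(prob_at(dists[1], yhat[1]))  # 1.0
```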

rikhuijzer commented 1 year ago

Oh wait, I accidentally updated a classifier. My bad! I'll try to fix this again.