rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction
https://sirus.jl.huijzer.xyz/
MIT License

Replace `Probabilities` by `LeafContent` and define `StableRulesRegressor` #18

Closed rikhuijzer closed 1 year ago

rikhuijzer commented 1 year ago

While working on this PR, I found out the hard way that the contents of a leaf should be a `Vector` for both kinds of trees (classification and regression). The core of it is summarized by the following definition:

"""
Type which holds the values inside a leaf.
For classification, this is a vector of probabilities of each class.
For regression, this is a vector of one element.

!!! note
    Vectors of one element are not as performant as scalars, but the
    alternative is to have two different types of leaves, which
    also results in different types of trees, which in turn
    requires most functions to become parametric.
"""
const LeafContent = Vector{Float64}

So, don't define two different types for the leaf contents.
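To illustrate the trade-off, here is a minimal sketch of how a single `LeafContent` alias lets the same code handle both kinds of trees. The `ExampleLeaf` struct and the helper functions `most_likely_class`, `regression_value`, and `average_content` are hypothetical names made up for this example; they are not the actual SIRUS.jl types or API.

```julia
# Same alias as in the PR.
const LeafContent = Vector{Float64}

# Hypothetical leaf type for illustration only; not the actual SIRUS.jl struct.
struct ExampleLeaf
    content::LeafContent
end

# Classification: one probability per class.
classification_leaf = ExampleLeaf([0.2, 0.8])

# Regression: a single value wrapped in a one-element vector.
regression_leaf = ExampleLeaf([21.5])

# Hypothetical helper: index of the class with the highest probability.
most_likely_class(leaf::ExampleLeaf) = argmax(leaf.content)

# Hypothetical helper: unwrap the single regression value.
# `only` throws if the vector does not contain exactly one element.
regression_value(leaf::ExampleLeaf) = only(leaf.content)

# Hypothetical helper: element-wise average over the leaves of multiple trees.
# Because both kinds of leaves share one type, this one method covers
# classification and regression without any type parameters.
average_content(leaves::Vector{ExampleLeaf}) =
    sum(leaf.content for leaf in leaves) / length(leaves)
```

With two separate content types (say, a vector for classification and a scalar for regression), `ExampleLeaf` and every function that touches it would likely need a type parameter, which is the cost the docstring note refers to.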

Furthermore, this PR defines `StableRulesRegressor`, which is a step towards fixing #13.

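Predictive performance of `StableRulesRegressor` is poor though. On the boston dataset it scores -1400.0 in the `RMS` column, whereas the other regressors score between 0.66 and 0.70: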

15×7 DataFrame
 Row │ Dataset   Model                   Hyperparameters    nfolds  AUC     RMS      1.96*SE
     │ String    String                  String             Int64   String  String   String
─────┼───────────────────────────────────────────────────────────────────────────────────────
   1 │ blobs     LGBMClassifier          (;)                    10  0.99             0.01
   2 │ blobs     LGBMClassifier          (max_depth = 2,)       10  0.99             0.01
   3 │ blobs     StableRulesClassifier   (n_trees = 50,)        10  1.00             0.00
   4 │ titanic   LGBMClassifier          (;)                    10  0.87             0.03
   5 │ titanic   LGBMClassifier          (max_depth = 2,)       10  0.85             0.02
   6 │ titanic   StableForestClassifier  (n_trees = 1500,)      10  0.85             0.02
   7 │ titanic   StableRulesClassifier   (n_trees = 1500,)      10  0.84             0.02
   8 │ haberman  LGBMClassifier          (;)                    10  0.71             0.06
   9 │ haberman  LGBMClassifier          (max_depth = 2,)       10  0.67             0.05
  10 │ haberman  StableForestClassifier  (n_trees = 1500,)      10  0.70             0.05
  11 │ haberman  StableRulesClassifier   (n_trees = 1500,)      10  0.67             0.05
  12 │ boston    LGBMRegressor           (;)                    10          0.70     0.06
  13 │ boston    LinearRegressor         (;)                    10          0.70     0.05
  14 │ boston    StableForestRegressor   (;)                    10          0.66     0.07
  15 │ boston    StableRulesRegressor    (n_trees = 1500,)      10          -1400.0  237.45