csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License
1.35k stars, 120 forks

Add max_tree hyperparameter #153

Closed (mepland closed 1 year ago)

mepland commented 1 year ago

Addresses https://github.com/csinva/imodels/issues/152

mepland commented 1 year ago

@csinva we should be ready to merge!

In FIGS_viz_demo.ipynb you can see that setting max_trees=1 works as expected, with all 7 rules placed in a single tree:

model_figs_1 = FIGSClassifier(max_rules=7, max_trees=1)

> ------------------------------
> FIGS-Fast Interpretable Greedy-Tree Sums:
>   Predictions are made by summing the "Val" reached by traversing each tree.
>   For classifiers, a sigmoid function is then applied to the sum.
> ------------------------------
Glucose concentration test <= 99.500 (Tree #0 root)
    Val: 0.068 (leaf)
    Glucose concentration test <= 168.500 (split)
        #Pregnant <= 6.500 (split)
            Body mass index <= 30.850 (split)
                Val: 0.065 (leaf)
                Blood pressure(mmHg) <= 67.000 (split)
                    Val: 0.714 (leaf)
                    Diabetes pedigree function <= 0.282 (split)
                        Val: 0.000 (leaf)
                        Val: 0.474 (leaf)
            Diabetes pedigree function <= 0.263 (split)
                Val: 0.333 (leaf)
                Val: 0.792 (leaf)
        Val: 0.810 (leaf)
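
A quick way to check a printout like the one above against the hyperparameters is to count its root and split lines. The sketch below is only an illustration: the helper name `count_rules_and_trees` is hypothetical (not part of imodels), and it assumes that every root or split line corresponds to one rule counted by max_rules.

```python
import re

# Printed model copied from the max_trees=1 run above.
PRINTED_MODEL = """\
Glucose concentration test <= 99.500 (Tree #0 root)
    Val: 0.068 (leaf)
    Glucose concentration test <= 168.500 (split)
        #Pregnant <= 6.500 (split)
            Body mass index <= 30.850 (split)
                Val: 0.065 (leaf)
                Blood pressure(mmHg) <= 67.000 (split)
                    Val: 0.714 (leaf)
                    Diabetes pedigree function <= 0.282 (split)
                        Val: 0.000 (leaf)
                        Val: 0.474 (leaf)
            Diabetes pedigree function <= 0.263 (split)
                Val: 0.333 (leaf)
                Val: 0.792 (leaf)
        Val: 0.810 (leaf)
"""


def count_rules_and_trees(text):
    """Return (n_rules, n_trees) for a printed FIGS model.

    Hypothetical helper: assumes each root line and each split line
    is one rule, and each root line starts one tree.
    """
    roots = 0
    splits = 0
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if re.search(r"\(Tree #\d+ root\)$", line):
            roots += 1
        elif line.endswith("(split)"):
            splits += 1
    return roots + splits, roots


n_rules, n_trees = count_rules_and_trees(PRINTED_MODEL)
print(n_rules, n_trees)
```

For this printout the count is 7 rules in 1 tree, which matches max_rules=7 and max_trees=1.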

For comparison, here is the same dataset without an effective limit on the number of trees (max_trees=3 is set, but the model only grows two trees, so the cap is never reached):

model_figs = FIGSClassifier(max_rules=7, max_trees=3)

> ------------------------------
> FIGS-Fast Interpretable Greedy-Tree Sums:
>   Predictions are made by summing the "Val" reached by traversing each tree.
>   For classifiers, a sigmoid function is then applied to the sum.
> ------------------------------
Glucose concentration test <= 99.500 (Tree #0 root)
    Val: 0.068 (leaf)
    Glucose concentration test <= 168.500 (split)
        #Pregnant <= 6.500 (split)
            Body mass index <= 30.850 (split)
                Val: 0.065 (leaf)
                Blood pressure(mmHg) <= 67.000 (split)
                    Val: 0.705 (leaf)
                    Val: 0.303 (leaf)
            Val: 0.639 (leaf)
        Blood pressure(mmHg) <= 93.000 (split)
            Val: 0.860 (leaf)
            Val: -0.009 (leaf)

    +
Diabetes pedigree function <= 0.404 (Tree #1 root)
    Val: -0.088 (leaf)
    Val: 0.106 (leaf)
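
To make the "summing the Val reached by traversing each tree" description concrete, here is a minimal sketch that hand-encodes the two trees printed above and sums their leaf values for one made-up sample. It does not use imodels. The `Node` class, the short feature keys, and the sample values are hypothetical, and it assumes the first child printed under a split is the branch taken when the condition holds. The sigmoid step that the header mentions for classifiers is left out.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Node:
    # A leaf carries only `value`; a split carries the other fields.
    value: Optional[float] = None
    feature: Optional[str] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None   # assumed: taken when feature <= threshold
    right: Optional["Node"] = None  # assumed: taken otherwise


def leaf(value: float) -> Node:
    return Node(value=value)


def split(feature: str, threshold: float, left: Node, right: Node) -> Node:
    return Node(feature=feature, threshold=threshold, left=left, right=right)


# Tree #0 and Tree #1 from the max_trees=3 printout, encoded by hand.
TREE_0 = split("glucose", 99.5,
    leaf(0.068),
    split("glucose", 168.5,
        split("pregnant", 6.5,
            split("bmi", 30.85,
                leaf(0.065),
                split("blood_pressure", 67.0, leaf(0.705), leaf(0.303))),
            leaf(0.639)),
        split("blood_pressure", 93.0, leaf(0.860), leaf(-0.009))))
TREE_1 = split("pedigree", 0.404, leaf(-0.088), leaf(0.106))
TREES: List[Node] = [TREE_0, TREE_1]


def traverse(node: Node, x: Dict[str, float]) -> float:
    """Walk from the root to a leaf and return that leaf's Val."""
    while node.value is None:
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.value


def tree_sum(trees: List[Node], x: Dict[str, float]) -> float:
    """Sum the Val reached in each tree (before any sigmoid)."""
    return sum(traverse(tree, x) for tree in trees)


# Made-up sample, for illustration only.
sample = {"glucose": 120.0, "pregnant": 2.0, "bmi": 25.0,
          "blood_pressure": 70.0, "pedigree": 0.5}
print(round(tree_sum(TREES, sample), 3))  # 0.065 + 0.106
```

With max_trees=1 the same loop runs over a single tree, so the summed value is just that tree's leaf.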

csinva commented 1 year ago

Looks great!