csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License

Added Gini Importances #154

Closed: mepland closed this 1 year ago

mepland commented 1 year ago

Hi @csinva how do the new Gini importances look?

I based the calculation off sklearn's code from here and here, though it needed to be made recursive as we do not have arrays of all the nodes and their properties.
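For readers following along, the recursive accumulation described above could be sketched roughly as follows. This is a minimal illustration and not the code from this PR: `ToyNode`, `accumulate`, and `gini_importances` are hypothetical names, and the sample counts and impurities are made up. It assumes each split contributes its weighted impurity decrease to the feature it splits on, and that the totals are then normalized to sum to 1.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    """Hypothetical stand-in for a tree node; not the imodels Node class."""
    n_samples: float
    impurity: float
    feature: Optional[int] = None      # None for leaves
    left: Optional["ToyNode"] = None
    right: Optional["ToyNode"] = None


def accumulate(node: Optional[ToyNode], totals: List[float]) -> None:
    """Add each split's weighted impurity decrease to its feature's total."""
    if node is None or node.left is None or node.right is None:
        return  # leaves contribute nothing
    totals[node.feature] += (
        node.n_samples * node.impurity
        - node.left.n_samples * node.left.impurity
        - node.right.n_samples * node.right.impurity
    )
    accumulate(node.left, totals)
    accumulate(node.right, totals)


def gini_importances(root: ToyNode, n_features: int) -> List[float]:
    """Walk the tree recursively and return importances normalized to sum to 1."""
    totals = [0.0] * n_features
    accumulate(root, totals)
    total = sum(totals)
    # Only normalize when there is a positive decrease to distribute.
    return [t / total for t in totals] if total > 0 else totals


# Made-up two-split tree: feature 0 at the root, feature 1 in the right child.
root = ToyNode(
    n_samples=100, impurity=0.5, feature=0,
    left=ToyNode(n_samples=60, impurity=0.2),
    right=ToyNode(
        n_samples=40, impurity=0.4, feature=1,
        left=ToyNode(n_samples=20, impurity=0.0),
        right=ToyNode(n_samples=20, impurity=0.1),
    ),
)
importances = gini_importances(root, n_features=2)
```

In this toy tree the root split contributes 50 - 12 - 16 = 22 and the second split contributes 16 - 0 - 2 = 14, so feature 0 ends up with 22/36 of the total importance and feature 1 with 14/36.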

There is a demo of the new code in the FIGS_viz_demo.ipynb notebook. I am a bit concerned about the None impurity in the root node of the second tree:

node_id: 0, left.node_id: 1, right.node_id: 2, impurity: None

I filled it with 0 for the calculation for now:

                importance_data_tree[node.feature] += (
                    np.sum(node.value_sklearn) * (node.impurity if node.impurity is not None else 0.) -
                    np.sum(node.left.value_sklearn) * node.left.impurity -
                    np.sum(node.right.value_sklearn) * node.right.impurity
                )

Is None expected if the tree has just one split?

Also, after taking the mean and normalizing, most of the importances are negative. I think this is fine, as we only care about the relative order of the features, but I wanted to get your opinion as well: [image]

BTW I noticed that we have an unused variable in plot():

criterion = "squared_error" if isinstance(self, RegressorMixin) else "gini"

Is this needed for anything, or should we delete it?

mepland commented 1 year ago

Addresses https://github.com/csinva/imodels/issues/127 for FIGS.

mepland commented 1 year ago

Also, I didn't add support for sample_weight during this initial pass: https://github.com/csinva/imodels/issues/89

csinva commented 1 year ago

criterion = "squared_error" if isinstance(self, RegressorMixin) else "gini"

Yeah, good catch, we can drop this.

csinva commented 1 year ago

Also, after taking the mean and normalizing most of the importances are negative. I think this is fine, as we just care about the relative order of the features, but wanted to get your opinion as well:

I think we should make all the importances positive (just flip the sign at the end), to stay consistent with sklearn.

csinva commented 1 year ago

Your calculation looks right to me -- I actually don't understand the None impurity issue. It seems to me that every new added stump should be assigned an impurity (L181) and if it's None we should see an error or early stop (L235).

mepland commented 1 year ago

I think we should make all the importances positive (just flip the sign at the end), to stay consistent with sklearn.

@csinva if we do that by multiplying by -1, the order would flip too, which is not what we want, right? Unless normalizing by that negative number has basically already flipped the order. I can see that, actually: Glucose concentration test is used twice at the start of the first tree, so it should be the most important.

Maybe I should normalize with avg_feature_importances / abs(np.sum(avg_feature_importances))? But then the sum is -1... [image]
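To make the two normalization options concrete, here is a small sketch with made-up numbers (not values from the notebook). The helper names are hypothetical. It compares dividing by the raw sum with dividing by the absolute value of the sum when that sum is negative.

```python
def normalize_by_sum(values):
    """Divide by the raw sum; a negative sum flips every sign."""
    total = sum(values)
    return [v / total for v in values]


def normalize_by_abs_sum(values):
    """Divide by the absolute sum; signs and ordering are preserved."""
    total = abs(sum(values))
    return [v / total for v in values]


def ranking(values):
    """Feature indices ordered from largest to smallest value."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)


# Made-up averaged importances whose sum is negative (-3.5).
raw = [-3.0, -1.0, 0.5]

by_sum = normalize_by_sum(raw)        # sums to 1, but the ranking is reversed
by_abs = normalize_by_abs_sum(raw)    # keeps the ranking, but sums to -1
```

With these inputs, dividing by the negative sum yields values that sum to 1 but reverses the feature ranking, while dividing by the absolute sum keeps the ranking and yields values that sum to -1, which matches the concern raised above.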

mepland commented 1 year ago

BTW sklearn has an additional normalization step in the code, but it is not ultimately used.

mepland commented 1 year ago

Your calculation looks right to me -- I actually don't understand the None impurity issue. It seems to me that every new added stump should be assigned an impurity (L181) and if it's None we should see an error or early stop (L235).

Actually, I think we should try to figure this out first as the negative importance for Diabetes pedigree function is basically being caused by that None impurity being set as 0. Fixing this should also fix the negatives issue.
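A quick sketch of why filling a None root impurity with 0 can push a feature's importance negative. All numbers here are made up for illustration, and `split_contribution` is a hypothetical helper that mirrors the weighted impurity decrease formula from the PR description.

```python
def split_contribution(n_parent, imp_parent, n_left, imp_left, n_right, imp_right):
    """Weighted impurity decrease of one split.

    A parent impurity of None is replaced by 0, as in the temporary
    workaround discussed in this thread.
    """
    parent = imp_parent if imp_parent is not None else 0.0
    return n_parent * parent - n_left * imp_left - n_right * imp_right


# Made-up split: 100 samples at the root, 40 go left and 60 go right.
with_none = split_contribution(100, None, 40, 0.12, 60, 0.15)
with_value = split_contribution(100, 0.20, 40, 0.12, 60, 0.15)
```

With the root impurity zeroed out, the parent term vanishes and only the two subtracted child terms remain, so the contribution is negative (about -13.8 here). With a real root impurity of 0.20 the same split contributes about +6.2.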

mepland commented 1 year ago

@csinva I did some debugging and I think that the None impurity node at the root of the second tree is coming from this line:

                # add new root potential node
                node_new_root = Node(is_root=True, idxs=np.ones(X.shape[0], dtype=bool),
                                     tree_num=-1)
                potential_splits.append(node_new_root)

The default param for impurity in Node() is None and then it is never filled with a real value. I'll paste my debugging output below so you can see it in practice.

How would you like to fix this? Seems to be a problem in the underlying FIGS fitting algo itself, not the impurity or importance calculations.
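One way to catch this kind of problem early is a small tree walk that reports any node whose impurity is still None. The sketch below is an illustration only: `ToyNode` and `find_missing_impurity` are hypothetical, and the real Node class has more fields than shown here.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    """Hypothetical node whose impurity defaults to None, as described above."""
    tree_num: int
    impurity: Optional[float] = None
    left: Optional["ToyNode"] = None
    right: Optional["ToyNode"] = None


def find_missing_impurity(node: Optional[ToyNode], path: str = "root") -> List[str]:
    """Return the paths of all nodes whose impurity was never filled in."""
    if node is None:
        return []
    missing = [path] if node.impurity is None else []
    missing += find_missing_impurity(node.left, path + ".left")
    missing += find_missing_impurity(node.right, path + ".right")
    return missing


# Made-up second tree: the root was created without an impurity,
# while both children were assigned one.
second_tree = ToyNode(
    tree_num=1,
    left=ToyNode(tree_num=1, impurity=0.128),
    right=ToyNode(tree_num=1, impurity=0.148),
)
missing = find_missing_impurity(second_tree)
```

Here `missing` contains only the root path, the same pattern as a tree whose root impurity is None while its children have real values.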

Debugging output:

DEBUG starting new tree:
tree_num: -1, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
DEBUG in progress tree
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06320022981901752
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24828989767652213
DEBUG in progress tree
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06320022981901752
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24828989767652213
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2385204081632653
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1541950113378685
DEBUG in progress tree
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06320022981901752
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24828989767652213
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2385204081632653
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.19996537396121883
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2307098765432099
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1541950113378685
DEBUG in progress tree
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06320022981901752
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24828989767652213
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2385204081632653
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.19996537396121883
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06035379812695109
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24395061728395062
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2307098765432099
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1541950113378685
DEBUG starting new tree:
tree_num: -1, node_id: None, left.node_id: None, right.node_id: None, impurity: None
DEBUG in progress tree
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06320022981901752
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24828989767652213
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2385204081632653
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.19996537396121883
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06035379812695109
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24395061728395062
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2307098765432099
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1541950113378685
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: None
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.12839368044187147
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1477618114868653
DEBUG in progress tree
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06320022981901752
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24828989767652213
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2385204081632653
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.19996537396121883
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06035379812695109
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24395061728395062
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.21345881926419968
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.18349552685909898
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2307098765432099
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1541950113378685
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: None
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.12839368044187147
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1477618114868653
DEBUG fitted tree
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2239312065972222
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06320022981901752
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24828989767652213
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2385204081632653
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.19996537396121883
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.06035379812695109
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.24395061728395062
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.21345881926419968
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.18349552685909898
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.2307098765432099
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1541950113378685
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.09754575404204457
tree_num: 0, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.009377186611138793
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: None
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.12839368044187147
tree_num: 1, node_id: None, left.node_id: None, right.node_id: None, impurity: 0.1477618114868653

csinva commented 1 year ago

Ah, you're right: when that new potential node is updated at this line, impurity_reduction is updated but impurity is not (since it wasn't used until now). It's an easy fix, let me push it now.
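The pattern described here can be illustrated with a toy example. This is not the actual fix that was pushed: `ToyNode`, `update_reduction_only`, and `update_both` are hypothetical names, and the impurity values are made up. It only shows how an update step that sets impurity_reduction alone leaves impurity at its None default.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyNode:
    """Hypothetical node with None defaults for both fields."""
    impurity: Optional[float] = None
    impurity_reduction: Optional[float] = None


def update_reduction_only(node, parent_impurity, weighted_child_impurity):
    """Sets only impurity_reduction, so impurity stays None."""
    node.impurity_reduction = parent_impurity - weighted_child_impurity


def update_both(node, parent_impurity, weighted_child_impurity):
    """Sets impurity as well, so later importance calculations can use it."""
    node.impurity = parent_impurity
    node.impurity_reduction = parent_impurity - weighted_child_impurity


before = ToyNode()
update_reduction_only(before, 0.20, 0.14)   # impurity is still None afterwards

after = ToyNode()
update_both(after, 0.20, 0.14)              # impurity is now 0.20
```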

csinva commented 1 year ago

Okay, I think it should be fixed now.

mepland commented 1 year ago

All set! [image]

csinva commented 1 year ago

Perfect!

mepland commented 1 year ago

@csinva ready to merge!

mepland commented 1 year ago

@csinva BTW I want to make one more PR to clean up that notebook and fix spelling mistakes, but then we can make a new release if you'd like.