csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License

Rule list cutoffs are not printed in the string representation of GreedyRuleListClassifier #169

Open davidefiocco opened 1 year ago

davidefiocco commented 1 year ago

When GreedyRuleListClassifier is trained on float features and the fitted classifier clf is printed, the cutoff values are not shown, which makes the model confusing to interpret. Here's an example:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import RocCurveDisplay
import imodels

y = np.random.randint(2, size = 1000)
x1 = y + np.random.rand(1000)*1.2
x2 = y + np.random.rand(1000)*1.2

dataset = pd.DataFrame({"x1": x1, "x2": x2, "y": y})

X_train, X_test, y_train, y_test = train_test_split(dataset[["x1", "x2"]], dataset.y, test_size=0.2, random_state=42)

clf = imodels.GreedyRuleListClassifier(max_depth = 3, criterion = 'gini')
clf.fit(X_train, y_train, feature_names=X_train.columns)
y_pred = clf.predict(X_test)
RocCurveDisplay.from_estimator(clf, X_test, y_test, marker="o")

Trying to render the model with print(clf) yields something along the lines of

> ------------------------------
> Greedy Rule List
> ------------------------------
↓
10.71% risk (800 pts)
    if x1 ==> 99.7% risk (361 pts)
↓
2.0% risk (439 pts)
    if x2 ==> 100.0% risk (39 pts)
↓
0.28% risk (400 pts)
    if x1 ==> 16.3% risk (43 pts)

which I find confusing because x1 and x2 are floats, not booleans, so a bare "if x1" does not say which values trigger the rule. The entries of clf.rules_ are instead

[{'col': 'x1',
  'index_col': 0,
  'cutoff': 1.1447782516479492,
  'val': 0.09429280397022333,
  'flip': False,
  'val_right': 0.9672544080604534,
  'num_pts': 800,
  'num_pts_right': 397},
 {'col': 'x2',
  'index_col': 1,
  'cutoff': 1.2083932757377625,
  'val': 0.01881720430107527,
  'flip': False,
  'val_right': 1.0,
  'num_pts': 403,
  'num_pts_right': 31},
 {'col': 'x1',
  'index_col': 0,
  'cutoff': 1.0007766485214233,
  'val': 0.0,
  'flip': False,
  'val_right': 0.125,
  'num_pts': 372,
  'num_pts_right': 56},
 {'val': 0.0, 'num_pts': 316}] 

and contain a cutoff that is useful for model interpretation. I'm not sure what the intended behavior is, since the code starting at https://github.com/csinva/imodels/blob/1243240fec3aae33852ba680ba6aea66a4f86ca7/imodels/rule_list/greedy_rule_list.py#L143-L184 currently contains commented-out chunks (some with color codes that are not used).
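In the meantime, the cutoffs can be read directly from a rules_-style list. The following is only a minimal sketch: summarize_rules is a hypothetical helper (not part of imodels), and the dict keys it reads are assumed to match the clf.rules_ output shown above. It reports the raw flip flag rather than guessing at the comparison direction.

```python
# Hypothetical helper, not part of imodels: list feature, cutoff and flip flag
# for each entry of a rules_-style list. The dict keys are assumed to match
# the clf.rules_ output shown in this issue.

def summarize_rules(rules):
    lines = []
    for rule in rules:
        if 'col' not in rule:
            # Final entry: default value for the points no rule captured
            lines.append(f"default: val={rule['val']:.3f} ({rule['num_pts']} pts)")
            continue
        lines.append(
            f"{rule['col']}: cutoff={rule['cutoff']:.3f}, flip={rule['flip']}, "
            f"val_right={rule['val_right']:.3f} ({rule['num_pts_right']} pts)"
        )
    return lines


# Two entries copied from the rules_ output above
example_rules = [
    {'col': 'x1', 'index_col': 0, 'cutoff': 1.1447782516479492,
     'val': 0.09429280397022333, 'flip': False,
     'val_right': 0.9672544080604534, 'num_pts': 800, 'num_pts_right': 397},
    {'val': 0.0, 'num_pts': 316},
]

for line in summarize_rules(example_rules):
    print(line)
```

With a fitted classifier, the same helper could be called as summarize_rules(clf.rules_).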

davidefiocco commented 1 year ago

A possible improvement would be reworking the __str__ representation as

    def __str__(self):
        '''Print out the list in a nice way
        '''
        header = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'
        footer = '> ------------------------------\n'
        rule_template = '> {condition} => {risk}% risk ({num_pts} pts)\n'

        s = header
        for i, rule in enumerate(self.rules_):

            # The final entry has no 'col': it is the default ("else") prediction

            condition = 'else'
            # Built-in round() works for both Python floats and numpy floats
            risk = round(100 * rule['val'], 2)
            num_pts = rule['num_pts']

            if 'col' in rule:
                predicate = '>=' if not rule['flip'] else '<'
                if i == 0:
                    condition = f"if {rule['col']} {predicate} {rule['cutoff']}"
                else:
                    condition = f"else if {rule['col']} {predicate} {rule['cutoff']}"

                risk = round(100 * rule['val_right'], 2)
                num_pts = rule['num_pts_right']

            s += rule_template.format(
                condition=condition,
                risk=risk,
                num_pts=num_pts
            )

        s += footer
        return s

This would render rules such as

[{'col': 'x2',
  'index_col': 1,
  'cutoff': 0.1395193189382553,
  'val': 0.04092071611253197,
  'flip': True,
  'val_right': 0.9315403422982885,
  'num_pts': 800,
  'num_pts_right': 409},
 {'col': 'x1',
  'index_col': 0,
  'cutoff': 0.0753365887212567,
  'val': 0.010554089709762533,
  'flip': False,
  'val_right': 1.0,
  'num_pts': 391,
  'num_pts_right': 12},
 {'col': 'x2',
  'index_col': 1,
  'cutoff': 0.19506534934043884,
  'val': 0.0,
  'flip': True,
  'val_right': 0.16666666666666666,
  'num_pts': 379,
  'num_pts_right': 24},
 {'val': 0.0, 'num_pts': 355}]

as

> ------------------------------
> Greedy Rule List
> ------------------------------
> if x2 < 0.1395193189382553 => 93.15% risk (409 pts)
> else if x1 >= 0.0753365887212567 => 100.0% risk (12 pts)
> else if x2 < 0.19506534934043884 => 16.67% risk (24 pts)
> else => 0.0% risk (355 pts)
> ------------------------------

csinva commented 1 year ago

Thanks, this is a nice fix!

I'll work on making it so that it displays like this if the feature is continuous-valued and keeps the original behavior for non-continuous features. Probably also worth rounding the cutoff value to ~3 decimal places.
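A minimal sketch of what that behavior could look like, not the fix as implemented in imodels. The function format_rules and its continuous_cols and decimals arguments are hypothetical names introduced for illustration. Rendering a non-continuous feature as a bare feature name is an assumption based on the original output shown earlier in this thread. The mapping from flip to '<' or '>=' follows the proposal above.

```python
# Hypothetical sketch, not imodels code: render a rules_-style list, showing a
# rounded cutoff only for features the caller marks as continuous.

def format_rules(rules, continuous_cols, decimals=3):
    lines = []
    for i, rule in enumerate(rules):
        if 'col' not in rule:
            # Final entry: default prediction for the remaining points
            condition, risk, num_pts = 'else', rule['val'], rule['num_pts']
        else:
            prefix = 'if' if i == 0 else 'else if'
            if rule['col'] in continuous_cols:
                # flip -> operator mapping taken from the proposal above
                op = '<' if rule['flip'] else '>='
                cutoff = round(rule['cutoff'], decimals)
                condition = f"{prefix} {rule['col']} {op} {cutoff}"
            else:
                # Assumed rendering for non-continuous features: name only
                condition = f"{prefix} {rule['col']}"
            risk, num_pts = rule['val_right'], rule['num_pts_right']
        lines.append(f"> {condition} => {round(100 * risk, 2)}% risk ({num_pts} pts)")
    return '\n'.join(lines)


# Entries copied from the second rules_ listing in this thread
example_rules = [
    {'col': 'x2', 'index_col': 1, 'cutoff': 0.1395193189382553,
     'val': 0.04092071611253197, 'flip': True,
     'val_right': 0.9315403422982885, 'num_pts': 800, 'num_pts_right': 409},
    {'col': 'x1', 'index_col': 0, 'cutoff': 0.0753365887212567,
     'val': 0.010554089709762533, 'flip': False,
     'val_right': 1.0, 'num_pts': 391, 'num_pts_right': 12},
    {'val': 0.0, 'num_pts': 355},
]

print(format_rules(example_rules, continuous_cols={'x1', 'x2'}))
```

Passing the continuous columns explicitly keeps the sketch free of any guess about how the library would detect feature types; a real implementation might infer this from the training data instead.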