UBTC / gopy

0 stars 0 forks source link

how to extract the decision rules from scikit-learn decision-tree? #1

Open WMJi opened 6 years ago

WMJi commented 6 years ago

Can I extract the underlying decision-rules (or 'decision paths') from a trained tree in a decision tree - as a textual list ? something like: "if A>0.4 then if B<0.2 then if C>0.8 then class='X' etc...

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

print "{}return {}".format(indent, tree_.value[node]) should be changed to print "{}return {}".format(indent, np.argmax(tree_.value[node][0])) for the function to return the class index.

ref

WMJi commented 6 years ago

I created my own function to extract the rules from the decision trees created by sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

This function first starts with the nodes (identified by -1 in the child arrays) and then recursively finds the parents. I call this a node's 'lineage'. Along the way, I grab the values I need to create if/then/else SAS logic:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

The sets of tuples below contain everything I need to create SAS if/then/else statements. I do not like using do blocks in SAS which is why I create logic describing a node's entire path. The single integer after the tuples is the ID of the terminal node in a path. All of the preceding tuples combine to create that node.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
WMJi commented 6 years ago

Here is a function, printing rules of a scikit-learn decision tree under python 3 and with offsets for conditional blocks to make the structure more readable:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)