
How can I extract the decision paths of a trained decision tree as a textual list of rules?

Something similar to this:

if P>0.4 then if Q<0.2 then if R>0.8 then class='A'


To extract the decision rules from a scikit-learn decision tree, try the code below:

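```python
from sklearn.tree import _tree


def tree_to_code(tree, feature_names):
    """Print a fitted scikit-learn decision tree as a nested Python function."""
    tree_ = tree.tree_
    # Name of the feature each node splits on; leaves have no split feature.
    # Kept separate from feature_names so the signature below lists each feature once.
    node_names = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        ind = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: print the split, then descend into both children.
            name = node_names[node]
            th = tree_.threshold[node]
            print("{}if {} <= {}:".format(ind, name, th))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(ind, name, th))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: print the stored value (class counts or regression output).
            print("{}return {}".format(ind, tree_.value[node]))

    recurse(0, 1)
```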

The above code prints a valid Python function.

For example, this is the output for a tree that tries to return a number between 0 and 10:

```python
def tree(m0):
  if m0 <= 6.0:
    if m0 <= 1.5:
      return [[ 0.]]
    else:
      if m0 <= 4.5:
        if m0 <= 3.5:
          return [[ 3.]]
        else:
          return [[ 4.]]
      else:
        return [[ 5.]]
  else:
    if m0 <= 8.5:
      if m0 <= 7.5:
        return [[ 7.]]
      else:
        return [[ 8.]]
    else:
      return [[ 9.]]
```
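If you are on scikit-learn 0.21 or later, a similar text dump is available without touching the private `_tree` module, through the public `sklearn.tree.export_text` function. The sketch below is only an illustration: the training data (a single made-up feature `m0` whose target is its integer part) and `max_depth=3` are assumptions chosen to resemble the example above, not the data that produced it.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Hypothetical data: one feature m0 in [0, 10), target = integer part of m0.
X = np.arange(0, 10, 0.5).reshape(-1, 1)
y = np.floor(X).ravel()

reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# export_text returns the rules as an indented plain-text string.
rules = export_text(reg, feature_names=["m0"])
print(rules)
```

The output uses `|---` prefixes for each depth level instead of Python syntax, so it is meant for reading rather than executing.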


You can also write your own function to extract the rules from decision trees created by scikit-learn. Start by fitting a tree, for example on some dummy data:

```python
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)

# fit on the first two columns (DataFrame.ix was removed in pandas 1.0; use .iloc)
dt.fit(df.iloc[:, :2], df.dv)
```
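One possible version of such a function, sketched under a few assumptions: it walks the fitted tree through the private `sklearn.tree._tree` module (as the first answer does), `tree_to_rules` is a hypothetical name, and the rule format is modeled on the `if ... then class=...` example in the question. Each root-to-leaf path becomes one string, so the result is the "textual list" the question asks for.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, _tree


def tree_to_rules(clf, feature_names):
    """Hypothetical helper: one 'if ... then class=...' string per leaf."""
    tree_ = clf.tree_
    rules = []

    def recurse(node, conditions):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            # Left child receives samples with feature <= threshold, right child the rest.
            recurse(tree_.children_left[node], conditions + [f"{name} <= {threshold:.2f}"])
            recurse(tree_.children_right[node], conditions + [f"{name} > {threshold:.2f}"])
        else:
            # Leaf: the predicted class is the one with the largest stored value.
            label = clf.classes_[tree_.value[node][0].argmax()]
            path = " and ".join(conditions) or "True"  # a single-leaf tree has no splits
            rules.append(f"if {path} then class={label}")

    recurse(0, [])
    return rules


# Same dummy data as above.
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1, random_state=0)
dt.fit(df.iloc[:, :2], df.dv)

for rule in tree_to_rules(dt, list(df.columns[:2])):
    print(rule)
```

Collecting the conditions along the way (instead of printing nested `if` blocks) gives every leaf its complete path, which is easier to filter, count, or export than the indented function printed by `tree_to_code`.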
