in Python by (250 points)

How can I extract the decision path of a trained decision tree as a textual list of rules?

Something similar to this:

if P>0.4 then if Q<0.2 then if R>0.8 then class='A'

2 Answers

by (10.9k points)

To extract the decision rules from a scikit-learn decision tree, try the code below:

from sklearn.tree import _tree

def tree_to_code(tree, fnames):
    tree_ = tree.tree_
    # Feature name used at each node ("undefined!" marks leaf nodes)
    node_names = [
        fnames[n] if n != _tree.TREE_UNDEFINED else "undefined!"
        for n in tree_.feature
    ]
    # The generated function takes every input feature as an argument
    print("def tree({}):".format(", ".join(fnames)))

    def recurse(node, depth):
        ind = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: print the split, then recurse into both children
            name = node_names[node]
            th = tree_.threshold[node]
            print("{}if {} <= {}:".format(ind, name, th))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(ind, name, th))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: return the value array stored for this node
            print("{}return {}".format(ind, tree_.value[node]))

    recurse(0, 1)

The code above prints the tree as a valid Python function in which each leaf returns the value array stored for that node.

Example output for a tree trained to return a number between 0 and 10:

def tree(m0):
  if m0 <= 6.0:
    if m0 <= 1.5:
      return [[ 0.]]
    else:  # if m0 > 1.5
      if m0 <= 4.5:
        if m0 <= 3.5:
          return [[ 3.]]
        else:  # if m0 > 3.5
          return [[ 4.]]
      else:  # if m0 > 4.5
        return [[ 5.]]
  else:  # if m0 > 6.0
    if m0 <= 8.5:
      if m0 <= 7.5:
        return [[ 7.]]
      else:  # if m0 > 7.5
        return [[ 8.]]
    else:  # if m0 > 8.5
      return [[ 9.]]
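On scikit-learn 0.21 or later, the library also ships sklearn.tree.export_text, which produces a similar text dump of the splits without importing the private _tree module. A minimal sketch, assuming toy data: the feature name m0 and the targets 0 to 9 are invented here to echo the example above, so the exact splits may differ from it.

```python
# Sketch assuming scikit-learn >= 0.21, where sklearn.tree.export_text exists.
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Hypothetical toy data: a single feature m0 whose target equals the feature.
X = np.arange(10).reshape(-1, 1)
y = np.arange(10)

reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# export_text returns the rules as one plain string, one split per line.
rules = export_text(reg, feature_names=["m0"])
print(rules)
```

export_text indents each level with "|---" and prints "value: [...]" at regression leaves ("class: ..." for classifiers). Its max_depth parameter (10 by default) truncates deeper branches.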

by (41.4k points)

You can write your own function to extract the rules from a decision tree created by sklearn. Start with a tree fitted on some data:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})

# create and fit the decision tree (.ix was removed from pandas, so use .iloc)
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.iloc[:, :2], df.dv)
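One possible sketch of such a function walks the fitted tree's low-level tree_ arrays (children_left, children_right, feature, threshold, value) and emits one rule per leaf in the "if ... then class=..." style from the question. The helper name tree_to_rules is hypothetical, and the exact rule format is an assumption chosen to match the question:

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def tree_to_rules(clf, feature_names):
    """Hypothetical helper: return one 'if ... then class=...' string per leaf."""
    tree_ = clf.tree_
    rules = []

    def recurse(node, conditions):
        left, right = tree_.children_left[node], tree_.children_right[node]
        if left == right:  # leaves have no children (both entries are -1)
            # value[node][0] holds per-class counts or fractions; argmax is the prediction
            label = clf.classes_[np.argmax(tree_.value[node][0])]
            prefix = " ".join(f"if {c} then" for c in conditions)
            rules.append(f"{prefix} class='{label}'".strip())
            return
        name = feature_names[tree_.feature[node]]
        th = tree_.threshold[node]
        # scikit-learn sends samples with feature <= threshold to the left child
        recurse(left, conditions + [f"{name}<={th:.2f}"])
        recurse(right, conditions + [f"{name}>{th:.2f}"])

    recurse(0, [])
    return rules


# Same dummy data as above, fitted with .iloc
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})
X = df.iloc[:, :2]
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1, random_state=0)
dt.fit(X, df.dv)

rules = tree_to_rules(dt, list(X.columns))
for rule in rules:
    print(rule)
```

Because col1 and col2 are perfectly correlated in this dummy data, either column may appear in a given rule; on real data each rule lists the actual splits along the path to its leaf.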
