
dtreeviz : Decision Tree Visualization

Description

A python library for decision tree visualization and model interpretation.

By Terence Parr and Prince Grover

See How to visualize decision trees for a deeper discussion of our decision tree visualization library and the visual design decisions we made.

Discussion

Decision trees are the fundamental building block of gradient boosting machines and Random Forests(tm), probably the two most popular machine learning models for structured data. Visualizing decision trees is a tremendous aid when learning how these models work and when interpreting models. Unfortunately, current visualization packages are rudimentary and not immediately helpful to the novice. For example, we couldn't find a library that visualizes how decision nodes split up the feature space. It is also uncommon for libraries to support visualizing a specific feature vector as it weaves down through a tree's decision nodes; we could only find one image showing this.

So, we've created a general package for scikit-learn decision tree visualization and model interpretation, which we'll be using heavily in an upcoming machine learning book (written with Jeremy Howard).

The visualizations are inspired by an educational animation by R2D3, A visual introduction to machine learning. With dtreeviz, you can visualize how the feature space is split up at decision nodes, how the training samples get distributed in leaf nodes, and how the tree makes predictions for a specific observation. These operations are critical for understanding how classification or regression decision trees work. If you're not familiar with decision trees, check out fast.ai's Introduction to Machine Learning for Coders MOOC.

Install

Install anaconda3 on your system, if not already done.

Make sure you do not have conda-installed graphviz-related packages, because dtreeviz needs the pip versions. You can remove them from conda with:

conda uninstall python-graphviz
conda uninstall graphviz

To install (Python >=3.6 only), do this (from Anaconda Prompt on Windows!):

pip install dtreeviz

This should also pull in the graphviz Python library (>=0.9), which we use for platform-specific operations.
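As a quick check that pip pulled in a new enough graphviz library, the sketch below reads the installed version with the standard library's importlib.metadata. That module needs Python 3.8+, which is newer than the 3.6 minimum above, and the helper names are our own, not part of dtreeviz.

```python
from importlib import metadata  # Python 3.8+


def version_tuple(version):
    """Turn a version string like '0.20.1' into (0, 20, 1), ignoring non-numeric suffixes."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def graphviz_library_ok(minimum=(0, 9)):
    """True/False depending on the installed graphviz package version, None if it is missing."""
    try:
        installed = metadata.version("graphviz")
    except metadata.PackageNotFoundError:
        return None
    return version_tuple(installed) >= minimum


print(graphviz_library_ok())
```

Comparing tuples rather than strings avoids the classic trap where "0.10" sorts before "0.9".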

Please email Terence with any helpful notes on making dtreeviz work (better) on other platforms. Thanks!

For your specific platform, please see the following subsections.

Mac

The brew spec for graphviz recently removed the options we used to use, --with-librsvg --with-pango, so you now have to build graphviz from source to get this working. (Feb 4, 2018)

Make sure you have the latest Xcode and its command-line tools installed. If Xcode is already installed, run xcode-select --install from the command line to add the tools. You also have to accept the Xcode license agreement, which you can do with sudo xcodebuild -license from the command line. The graphviz build shown next compiles from source, so you need Xcode set up properly.

You need the graphviz binary for dot installed with librsvg and pango. Follow this procedure (verified on three recent Macs running macOS 10.13) to build graphviz 2.40.1 from source and install it:

brew uninstall graphviz
brew reinstall pango librsvg  # even if already there, please reinstall

cd /tmp
wget https://graphviz.gitlab.io/pub/graphviz/stable/SOURCES/graphviz.tar.gz
tar xvfz graphviz.tar.gz
cd graphviz-2.40.1/

rm -rf /usr/local/lib/graphviz # in case old stuff is there
# make sure to run ./configure after reinstalling pango, librsvg
./configure --includedir=/usr/local/include/graphviz --with-pangocairo=yes
make -j 8 # 8 threads
make install

Just to be sure, remove dot from any anaconda installation, for example:

rm ~/anaconda3/bin/dot

From command line, this command

dot -Tsvg:cairo

should work, in the sense that it just stares at you without giving an error. You can hit control-C to escape back to the shell. If you still get an error message that says something about trying -Tsvg:svg:core option, then we still have a problem. Make sure that you are using the right dot:

$ which dot
/usr/local/bin/dot
$ ls -l $(which dot)
-rwxr-xr-x  1 parrt  wheel  22920 Feb  5 09:02 /usr/local/bin/dot*
$

If instead, it says something like

$ ls -l $(which dot)
lrwxr-xr-x  1 parrt  wheel  33 Feb  4 19:54 /usr/local/bin/dot@ -> ../Cellar/graphviz/2.40.1/bin/dot

then you're still using the brew version. Do a brew uninstall graphviz and then do make install again from the /tmp/graphviz-2.40.1 dir.
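The which dot and ls -l checks above can also be run from Python, which is useful because Python is what has to find dot. This sketch uses only the standard library (shutil.which, os.path.realpath); the /Cellar/graphviz/ test is a heuristic based on the symlink shown above, and both function names are hypothetical.

```python
import os
import shutil


def resolved_dot(path_env=None):
    """Return (path found on PATH, fully resolved path) for dot, or None if it is not on PATH."""
    found = shutil.which("dot", path=path_env)  # path_env=None means use the PATH variable
    if found is None:
        return None
    return found, os.path.realpath(found)  # follows symlinks such as brew's


def looks_like_brew_dot(resolved_path):
    """Heuristic: brew-installed graphviz lives under a .../Cellar/graphviz/... directory."""
    return "/Cellar/graphviz/" in resolved_path


info = resolved_dot()
if info is None:
    print("dot is not on PATH")
else:
    found, real = info
    print(found, "->", real)
    if looks_like_brew_dot(real):
        print("Still using the brew graphviz; uninstall it and re-run make install.")
```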

The OS X version can generate and save images in any format that dot supports with the -T{format}:cairo option, so .svg and .pdf are safe bets.

Limitations. Jupyter notebook has a bug where it does not show .svg files correctly, but Jupyter Lab has no problem.

Linux (Ubuntu 18.04)

To get the dot binary do:

sudo apt install graphviz

Limitations. The view() method pops up a new window, and images appear inline in Jupyter notebook but not in Jupyter Lab (which gets an error parsing the SVG XML). The notebook images also substitute a different font for the Arial we use, so some text overlaps. Only .svg files can be generated on this platform.

Windows 10

Download graphviz-2.38.msi and update your Path environment variable. Since this is Windows, you might need to reboot after updating that environment variable. You should then see this from the Anaconda Prompt:

(base) C:\Users\Terence Parr>where dot
C:\Program Files (x86)\Graphviz2.38\bin\dot.exe

(Do not use conda install -c conda-forge python-graphviz, because it installs an old version of the graphviz Python library.)

Verify from the Anaconda Prompt that this works (capital -V not lowercase -v):

dot -V

If it doesn't work, you have a Path problem. I found the following test programs useful. The first one sees if Python can find dot:

import os
import subprocess
proc = subprocess.Popen(['dot', '-V'])  # fails with FileNotFoundError if dot is not on the Path
print(os.getenv('Path'))  # the directories Windows searches for dot

The following version does the same thing, except that it uses the graphviz Python library's backend utilities, which is what we use in dtreeviz:

import graphviz.backend as be
cmd = ["dot", "-V"]
stdout, stderr = be.run(cmd, capture_output=True, check=True, quiet=False)
print(stderr)  # dot -V prints its version banner to stderr
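If you would rather not depend on the graphviz package's backend helpers, whose location may change between library versions, here is a standard-library-only variant. It assumes Python 3.7+ for capture_output, and dot_version is a hypothetical helper name; it returns None when Python cannot find or run dot.

```python
import shutil
import subprocess


def dot_version():
    """Return the text printed by `dot -V`, or None if dot cannot be found or run."""
    exe = shutil.which("dot")
    if exe is None:
        return None  # a Path problem: Python cannot see the dot binary
    try:
        result = subprocess.run([exe, "-V"], capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    # dot writes its version banner to stderr rather than stdout
    return (result.stderr or result.stdout).strip()


print(dot_version() or "dot not found; check the Path environment variable")
```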

Jupyter Lab and Jupyter notebook both show the inline .svg images well.

Verify graphviz installation

Try making a text file t.dot with the content digraph T { A -> B } (paste that into a text editor, for example) and then run this from the command line:

dot -Tsvg:cairo -o t.svg t.dot

That should give a simple t.svg file that opens properly. If you get errors from dot, it will not work from the dtreeviz Python code either. If the shell can't find dot, then you didn't update your PATH environment variable or there is some other install issue with graphviz.
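The same t.dot check can be scripted in Python so that it runs with the PATH that dtreeviz will see. This is a sketch assuming the dot on your PATH supports the svg:cairo renderer used above; pass fmt="svg" on platforms where only plain SVG works. The function name is our own.

```python
import os
import shutil
import subprocess
import tempfile

DOT_SOURCE = "digraph T { A -> B }\n"  # same graph as t.dot above


def render_test_svg(workdir, fmt="svg:cairo"):
    """Write t.dot into workdir and ask dot to render t.svg; return the SVG path or None."""
    if shutil.which("dot") is None:
        return None  # dot is not on PATH
    dot_file = os.path.join(workdir, "t.dot")
    svg_file = os.path.join(workdir, "t.svg")
    with open(dot_file, "w") as f:
        f.write(DOT_SOURCE)
    result = subprocess.run(["dot", "-T" + fmt, "-o", svg_file, dot_file],
                            capture_output=True, text=True)
    if result.returncode != 0:
        print(result.stderr)  # the same error dtreeviz would run into
        return None
    return svg_file


with tempfile.TemporaryDirectory() as tmp:
    svg = render_test_svg(tmp)
    if svg is None:
        print("dot could not render the test graph")
    else:
        with open(svg) as f:
            print("Rendered SVG:", "<svg" in f.read())
```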

Limitations

Finally, don't use IE to view .svg files. Use Edge, as they look much better there. I suspect that IE displays them as rasterized rather than vector images. Only .svg files can be generated on this platform.

Usage

dtreeviz: Main function to create decision tree visualizations. Given a decision tree regressor or classifier, it creates and returns a tree visualization using the graphviz (DOT) language.

Required libraries

Basic libraries and imports that may be needed to generate the sample visualizations shown in the examples below.

import numpy as np  # used to pick a random observation in the prediction path example
from sklearn.datasets import *
from sklearn import tree
from dtreeviz.trees import *

Regression decision tree

The default orientation of the tree is top-down, but you can change it to left-to-right using orientation="LR". view() pops up a window with the rendered graphviz object.

regr = tree.DecisionTreeRegressor(max_depth=2)
boston = load_boston()
regr.fit(boston.data, boston.target)

viz = dtreeviz(regr,
               boston.data,
               boston.target,
               target_name='price',
               feature_names=boston.feature_names)

viz.view()
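dtreeviz produces graphviz DOT text. To see what a left-to-right layout means at the DOT level, this hand-written sketch emits a one-split tree using the standard DOT graph attribute rankdir. We assume orientation="LR" maps onto rankdir=LR, but that mapping is our assumption rather than documented dtreeviz behavior, and the node labels are placeholders, not output from the Boston model.

```python
def tiny_tree_dot(orientation="TD"):
    """Return DOT text for a hypothetical one-split tree.

    orientation="LR" sets rankdir=LR (left to right); anything else uses
    rankdir=TB, graphviz's default top-to-bottom layout.
    """
    rankdir = "LR" if orientation == "LR" else "TB"
    lines = [
        "digraph T {",
        f"    rankdir={rankdir};",
        '    split [label="feature <= threshold?"];',
        '    leaf1 [label="leaf 1"];',
        '    leaf2 [label="leaf 2"];',
        "    split -> leaf1;",
        "    split -> leaf2;",
        "}",
    ]
    return "\n".join(lines)


print(tiny_tree_dot("LR"))
```

With the graphviz Python library installed, graphviz.Source(tiny_tree_dot("LR")).view() should open it in a viewer much as viz.view() does.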

Classification decision tree

Classification trees also require a class_names argument, which maps each class value to a class name.

classifier = tree.DecisionTreeClassifier(max_depth=2)  # limit depth of tree
iris = load_iris()
classifier.fit(iris.data, iris.target)

viz = dtreeviz(classifier,
               iris.data,
               iris.target,
               target_name='variety',
               feature_names=iris.feature_names,
               class_names=["setosa", "versicolor", "virginica"])  # need class_names for classifier

viz.view()

Prediction path

Highlights the decision nodes that a single observation, passed in the argument X, travels through on its way to a leaf. It shows the observation's feature values and highlights the features the tree uses along that path.

regr = tree.DecisionTreeRegressor(max_depth=2)  # limit depth of tree
diabetes = load_diabetes()
regr.fit(diabetes.data, diabetes.target)
X = diabetes.data[np.random.randint(0, len(diabetes.data)),:]  # random sample from training

viz = dtreeviz(regr,
               diabetes.data,
               diabetes.target,
               target_name='value',
               orientation='LR',  # left-right orientation
               feature_names=diabetes.feature_names,
               X=X)  # need to give single observation for prediction

viz.view()
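To make the idea of a prediction path concrete, here is a pure-Python sketch on a hypothetical hand-built tree (not produced by scikit-learn or dtreeviz). It records every decision node an observation visits and which features those nodes test, which is the information the highlighted path conveys. The comparison uses <=, matching scikit-learn's convention of sending samples with x[feature] <= threshold to the left child.

```python
# Hypothetical hand-built tree: internal nodes test x[feature] <= threshold,
# leaves hold a predicted value.
TREE = {
    "feature": 0, "threshold": 0.5,
    "left": {
        "feature": 1, "threshold": 2.0,
        "left": {"value": 10.0},
        "right": {"value": 20.0},
    },
    "right": {"value": 30.0},
}


def prediction_path(node, x):
    """Follow observation x from the root to a leaf.

    Returns (decisions, prediction), where decisions holds one
    (feature, threshold, went_left) tuple per decision node visited.
    """
    decisions = []
    while "value" not in node:
        went_left = x[node["feature"]] <= node["threshold"]
        decisions.append((node["feature"], node["threshold"], went_left))
        node = node["left"] if went_left else node["right"]
    return decisions, node["value"]


decisions, prediction = prediction_path(TREE, [0.2, 3.0])
print(decisions)   # [(0, 0.5, True), (1, 2.0, False)]
print(prediction)  # 20.0
used = {feature for feature, _, _ in decisions}  # features to highlight
print(used)        # {0, 1}
```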

Decision tree without scatterplot or histograms for decision nodes

Simple tree without histograms or scatterplots for decision nodes. Use the argument fancy=False.

classifier = tree.DecisionTreeClassifier(max_depth=4)  # limit depth of tree
cancer = load_breast_cancer()
classifier.fit(cancer.data, cancer.target)

viz = dtreeviz(classifier,
               cancer.data,
               cancer.target,
               target_name='cancer',
               feature_names=cancer.feature_names,
               class_names=["malignant", "benign"],
               fancy=False)  # fancy=False removes histograms/scatterplots from decision nodes

viz.view()

For more examples and different implementations, please see the jupyter notebook full of examples.

Regression univariate feature-target space

import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from dtreeviz.trees import *

df_cars = pd.read_csv("data/cars.csv")
X_train, y_train = df_cars.drop('MPG', axis=1), df_cars['MPG']

fig = plt.figure()
ax = fig.gca()
t = rtreeviz_univar(ax,
                    X_train.WGT, y_train,
                    max_depth=2,
                    feature_name='Vehicle Weight',
                    target_name='MPG',
                    fontsize=14)
plt.show()
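rtreeviz_univar draws the piecewise-constant fit a regression tree makes on one feature. The sketch below shows, in plain Python, how a single split (a depth-1 tree) can be chosen: try every threshold between neighboring sorted values and keep the one minimizing total squared error, predicting each side's mean. It is a naive O(n²) illustration with made-up toy data, not dtreeviz's or scikit-learn's code, though using the midpoint between neighbors mirrors how scikit-learn places thresholds.

```python
def best_regression_split(x, y):
    """Return (threshold, left_mean, right_mean) for the split of one feature that
    minimizes total squared error, or None if no split is possible."""
    pairs = sorted(zip(x, y))
    best, best_sse = None, float("inf")
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold separates equal feature values
        left = [t for _, t in pairs[:i]]
        right = [t for _, t in pairs[i:]]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((t - left_mean) ** 2 for t in left)
               + sum((t - right_mean) ** 2 for t in right))
        if sse < best_sse:
            threshold = (pairs[i - 1][0] + pairs[i][0]) / 2  # midpoint between neighbors
            best, best_sse = (threshold, left_mean, right_mean), sse
    return best


# Made-up toy data: light "cars" get high MPG, heavy ones low MPG.
weights = [1, 2, 3, 10, 11, 12]
mpg = [30, 31, 29, 12, 11, 13]
print(best_regression_split(weights, mpg))  # (6.5, 30.0, 12.0)
```

Deeper trees repeat the same search recursively inside each side, which is why the fitted line is a staircase.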

Regression bivariate feature-target space

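import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions
from dtreeviz.trees import *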

df_cars = pd.read_csv("data/cars.csv")
X = df_cars.drop('MPG', axis=1)
y = df_cars['MPG']

features = [2, 1]
X = X.values[:,features]
figsize = (6,5)
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111, projection='3d')

t = rtreeviz_bivar_3D(ax,
                      X, y,
                      max_depth=4,
                      feature_names=['Vehicle Weight', 'Horse Power'],
                      target_name='MPG',
                      fontsize=14,
                      elev=20,
                      azim=25,
                      dist=8.2,
                      show={'splits','title'})
plt.show()

Regression bivariate feature-target space heatmap

import pandas as pd
import matplotlib.pyplot as plt
from dtreeviz.trees import *

df_cars = pd.read_csv("data/cars.csv")
X = df_cars.drop('MPG', axis=1)
y = df_cars['MPG']

features=[2, 1]
X = X.values[:, features]
figsize = (6, 5)
fig, ax = plt.subplots(1, 1, figsize=figsize)
t = rtreeviz_bivar_heatmap(ax,
                           X, y,
                           max_depth=4,
                           feature_names=['Vehicle Weight', 'Horse Power'],
                           fontsize=14)
plt.show()

Classification univariate feature-target space

import pandas as pd
import matplotlib.pyplot as plt
from dtreeviz.trees import *

know = pd.read_csv("data/knowledge.csv")
class_names = ['very_low', 'Low', 'Middle', 'High']
know['UNS'] = know['UNS'].map({n: i for i, n in enumerate(class_names)})

x_train = know.PEG
y_train = know['UNS']
figsize = (6,2)
fig, ax = plt.subplots(1, 1, figsize=figsize)
ct = ctreeviz_univar(ax, x_train, y_train, max_depth=3,
                     feature_name='PEG', class_names=class_names,
                     target_name='Knowledge',
                     nbins=40, gtype='strip',
                     show={'splits', 'title'})
plt.tight_layout()
plt.show()
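ctreeviz_univar shows how the tree carves a single feature such as PEG into class regions. scikit-learn's DecisionTreeClassifier uses Gini impurity by default (criterion='gini'), and the sketch below shows, in plain Python, how one such split could be picked: the threshold whose two sides have the lowest sample-weighted Gini impurity. It is a naive illustration on made-up toy data, not dtreeviz's or scikit-learn's implementation.

```python
from collections import Counter


def gini(labels):
    """Gini impurity, 1 - sum_k p_k**2, of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_gini_split(x, y):
    """Return (threshold, weighted_gini) for the best single split of feature x,
    or None if every x value is the same."""
    pairs = sorted(zip(x, y))
    n = len(pairs)
    best = None
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold separates equal feature values
        left = [label for _, label in pairs[:i]]
        right = [label for _, label in pairs[i:]]
        # weight each side's impurity by its share of the samples
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if best is None or score < best[1]:
            best = ((pairs[i - 1][0] + pairs[i][0]) / 2, score)
    return best


# Made-up toy data: class 0 below 5, class 1 above.
peg = [1, 2, 3, 7, 8, 9]
uns = [0, 0, 0, 1, 1, 1]
print(best_gini_split(peg, uns))  # (5.0, 0.0)
```

A weighted impurity of 0.0 means both sides are pure, which shows up in the plot as a clean boundary between two class colors.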

Classification bivariate feature-target space

import pandas as pd
import matplotlib.pyplot as plt
from dtreeviz.trees import *

know = pd.read_csv("data/knowledge.csv")
class_names = ['very_low', 'Low', 'Middle', 'High']
know['UNS'] = know['UNS'].map({n: i for i, n in enumerate(class_names)})

features=[4,3]
X_train = know.drop('UNS', axis=1)
y_train = know['UNS']
X_train = X_train.values[:, features]
figsize = (6,5)
fig, ax = plt.subplots(1, 1, figsize=figsize)
ctreeviz_bivar(ax, X_train, y_train, max_depth=3,
               feature_names=['PEG', 'LPR'],
               class_names=class_names,
               target_name='Knowledge')
plt.tight_layout()
plt.show()

Install dtreeviz locally

Make sure to follow the install guidelines above.

To install your working copy of dtreeviz into the local egg cache (forcing updates) during development, run this (from the Anaconda Prompt on Windows):

python setup.py install -f

For example, on Terence's machine, this adds /Users/parrt/anaconda3/lib/python3.6/site-packages/dtreeviz-0.3-py3.6.egg.

Authors

See also the list of contributors who participated in this project.

License

This project is licensed under the terms of the MIT license, see LICENSE.
