A better way to visualize Decision Trees with the dtreeviz library

This notebook contains the code for the accompanying blog post, "A better way to visualize Decision Trees with the dtreeviz library", by Parul Pandey.

Installation

# conda: dtreeviz relies on the pip-installed graphviz packages, so remove any conda-installed ones first
conda uninstall python-graphviz
conda uninstall graphviz

# pip
pip install dtreeviz             # install dtreeviz for sklearn
pip install dtreeviz[xgboost]    # install XGBoost-related dependencies
pip install dtreeviz[pyspark]    # install PySpark-related dependencies
pip install dtreeviz[lightgbm]   # install LightGBM-related dependencies
This should also pull in the graphviz Python library (>=0.9), which dtreeviz uses to handle platform-specific details.

For details see: https://github.com/parrt/dtreeviz
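The graphviz Python package is only a wrapper around the Graphviz command-line tools, so rendering can still fail after a successful pip install if the `dot` executable is missing. A minimal sketch, using only the standard library, that reports which pieces are present; the helper names `installed_version` and `check_environment` are hypothetical and not part of dtreeviz.

```python
import shutil
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_environment():
    """Report the components dtreeviz rendering depends on.

    The graphviz Python package shells out to the Graphviz tools,
    so the `dot` executable must also be on PATH.
    """
    return {
        "dtreeviz": installed_version("dtreeviz"),
        "graphviz (Python)": installed_version("graphviz"),
        "dot executable": shutil.which("dot"),  # full path, or None if not found
    }


if __name__ == "__main__":
    for name, found in check_environment().items():
        print(f"{name:20} {found or 'NOT FOUND'}")
```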

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn import tree
from dtreeviz.trees import *
import graphviz 

import warnings
warnings.filterwarnings("ignore") 
In [2]:
wine = pd.read_csv('winequality-red.csv')
wine.head()
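The call above assumes a comma-separated file. The copy of this dataset distributed by the UCI Machine Learning Repository is semicolon-separated, and reading it with the default separator would collapse every column into one. A minimal sketch, assuming the file may use either delimiter: with `sep=None` and the Python parsing engine, pandas sniffs the delimiter from the first line. The helper name `load_wine` and the in-memory sample are illustrative only.

```python
import io

import pandas as pd


def load_wine(path_or_buffer):
    """Read the wine-quality CSV whether it is comma- or semicolon-delimited.

    With sep=None, pandas' Python engine uses csv.Sniffer to detect the
    delimiter instead of assuming a comma.
    """
    return pd.read_csv(path_or_buffer, sep=None, engine="python")


# Usage with a tiny semicolon-delimited sample in the UCI style (quoted header).
sample = io.StringIO(
    '"fixed acidity";"volatile acidity";"quality"\n'
    "7.4;0.7;5\n"
)
df = load_wine(sample)
print(df.columns.tolist())  # ['fixed acidity', 'volatile acidity', 'quality']
```

For the real file, `wine = load_wine('winequality-red.csv')` would replace the `pd.read_csv` call above.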
Out[2]:
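|   | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7.4 | 0.70 | 0.00 | 1.9 | 0.076 | 11.0 | 34.0 | 0.9978 | 3.51 | 0.56 | 9.4 | 5 |
| 1 | 7.8 | 0.88 | 0.00 | 2.6 | 0.098 | 25.0 | 67.0 | 0.9968 | 3.20 | 0.68 | 9.8 | 5 |
| 2 | 7.8 | 0.76 | 0.04 | 2.3 | 0.092 | 15.0 | 54.0 | 0.9970 | 3.26 | 0.65 | 9.8 | 5 |
| 3 | 11.2 | 0.28 | 0.56 | 1.9 | 0.075 | 17.0 | 60.0 | 0.9980 | 3.16 | 0.58 | 9.8 | 6 |
| 4 | 7.4 | 0.70 | 0.00 | 1.9 | 0.076 | 11.0 | 34.0 | 0.9978 | 3.51 | 0.56 | 9.4 | 5 |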
In [3]:
wine['quality'].value_counts()
Out[3]:
5    681
6    638
7    199
4     53
8     18
3     10
Name: quality, dtype: int64
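The counts show how concentrated the target is: grades 5 and 6 account for 1,319 of the 1,599 wines (about 82%), while grades 3 and 8 together have only 28 samples. A minimal sketch that expresses this as proportions with `value_counts(normalize=True)`; so that it runs without the CSV, it rebuilds a stand-in `quality` Series from the counts printed above (on the real data, `wine['quality']` would be used directly).

```python
import pandas as pd

# Stand-in for wine['quality'], rebuilt from the value counts shown above.
counts = {5: 681, 6: 638, 7: 199, 4: 53, 8: 18, 3: 10}
quality = pd.Series([q for q, n in counts.items() for _ in range(n)], name="quality")

# Fraction of wines at each quality grade, ordered by grade.
shares = quality.value_counts(normalize=True).sort_index()
print(len(quality))  # 1599
print(shares.round(3))
```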
In [4]:
features = wine.drop('quality', axis=1)
target = wine['quality']

Regression decision tree

In [5]:
# dtreeviz renders through graphviz, so no matplotlib figure needs to be created here
regr = tree.DecisionTreeRegressor(max_depth=3)  # a shallow tree keeps the plot readable
regr.fit(features, target)
viz = dtreeviz(regr,
               features,
               target,
               target_name='wine quality',
               feature_names=features.columns,
               title="Wine data set regression",
               fontname="Arial",
               colors={"title": "purple"},
               scale=1.5)
viz  # renders the tree inline in Jupyter
Out[5]:
[Figure: dtreeviz rendering titled "Wine data set regression", a depth-3 regression tree with 7 split nodes and 8 leaves]
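Each leaf of a regression tree predicts a single value, and with scikit-learn's default squared-error criterion that value is the mean target of the training rows that reach the leaf. A minimal sketch that makes this concrete without Graphviz: it fits the same kind of `max_depth=3` regressor, prints the splits with `sklearn.tree.export_text`, and checks the leaf-mean property. The data here is synthetic and hypothetical (two made-up columns named after wine features), not the wine dataset.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor, export_text

# Hypothetical synthetic stand-in for the wine features and target.
rng = np.random.default_rng(0)
X = pd.DataFrame({
    "alcohol": rng.uniform(8, 15, 500),
    "sulphates": rng.uniform(0.3, 2.0, 500),
})
y = np.round(3 + 0.4 * X["alcohol"] - 1.0 * X["sulphates"] + rng.normal(0, 0.5, 500))

regr = DecisionTreeRegressor(max_depth=3, random_state=0)
regr.fit(X, y)

# Text view of the split thresholds and leaf values.
print(export_text(regr, feature_names=list(X.columns)))

# A depth-3 binary tree has at most 2**3 = 8 leaves.
print(regr.get_depth(), regr.get_n_leaves())

# Each leaf predicts the mean target of the training rows routed to it.
leaf_ids = regr.apply(X)
leaf_means = pd.Series(y.to_numpy()).groupby(leaf_ids).mean()
assert np.allclose(regr.predict(X), leaf_means.loc[leaf_ids].to_numpy())
```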