In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

import sys
sys.path.append('/Users/kaonpark/workspace/github.com/likejazz/kaon-learn')
import kaonlearn

from kaonlearn.plots import plot_decision_regions
In [2]:
from sklearn import svm, datasets

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features. We could
                      # avoid this ugly slicing by using a two-dim dataset
y = iris.target

h = .02  # step size in the mesh

# we create an instance of SVM and fit out data. We do not scale our
# data since we want to plot the support vectors
C = 1.0  # SVM regularization parameter
svc = svm.SVC(kernel='linear', C=C).fit(X, y)
rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(X, y)
poly_svc = svm.SVC(kernel='poly', degree=3, C=C).fit(X, y)
lin_svc = svm.LinearSVC(C=C).fit(X, y)
In [3]:
plt.title('SVC with linear kernel')
plot_decision_regions(X, y, clf=svc, legend=0)
Out[3]:
<matplotlib.axes._subplots.AxesSubplot at 0x1122f05f8>
In [4]:
plt.title('LinearSVC (linear kernel)')
plot_decision_regions(X, y, clf=lin_svc, legend=0)
Out[4]:
<matplotlib.axes._subplots.AxesSubplot at 0x1129df128>
In [5]:
plt.title('SVC with RBF kernel')
plot_decision_regions(X, y, clf=rbf_svc, legend=0)
Out[5]:
<matplotlib.axes._subplots.AxesSubplot at 0x112af0dd8>
In [6]:
plt.title('SVC with polynomial (degree 3) kernel')
plot_decision_regions(X, y, clf=poly_svc, legend=0)
Out[6]:
<matplotlib.axes._subplots.AxesSubplot at 0x112bb4eb8>
In [7]:
from sklearn.tree import DecisionTreeClassifier

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)
Out[7]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=2,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=42,
            splitter='best')
In [8]:
import os
from tempfile import mkstemp
import subprocess

from sklearn.tree.export import export_graphviz

def convert_decision_tree_to_ipython_image(clf, feature_names=None, class_names=None, tmp_dir=None):
    dot_filename = mkstemp(suffix='.dot', dir=tmp_dir)[1]
    with open(dot_filename, "w") as out_file:
        export_graphviz(clf, out_file=out_file,
                        feature_names=feature_names,
                        class_names=class_names,
                        filled=True, rounded=True,
                        special_characters=True)

    import graphviz
    from IPython.display import display

    with open(dot_filename) as f:
        dot_graph = f.read()
    display(graphviz.Source(dot_graph))
    os.remove(dot_filename)

convert_decision_tree_to_ipython_image(tree_clf, feature_names=iris.feature_names[2:], class_names=iris.target_names)
Tree 0 petal length (cm) ≤ 5.45 gini = 0.667 samples = 150 value = [50, 50, 50] class = setosa 1 petal width (cm) ≤ 2.8 gini = 0.237 samples = 52 value = [45, 6, 1] class = setosa 0->1 True 4 petal length (cm) ≤ 6.15 gini = 0.546 samples = 98 value = [5, 44, 49] class = virginica 0->4 False 2 gini = 0.449 samples = 7 value = [1, 5, 1] class = versicolor 1->2 3 gini = 0.043 samples = 45 value = [44, 1, 0] class = setosa 1->3 5 gini = 0.508 samples = 43 value = [5, 28, 10] class = versicolor 4->5 6 gini = 0.413 samples = 55 value = [0, 16, 39] class = virginica 4->6
In [9]:
plt.title('Decision Tree max_depth=2')
plot_decision_regions(X, y, clf=tree_clf, legend=0)
Out[9]:
<matplotlib.axes._subplots.AxesSubplot at 0x112c7ef60>
In [10]:
tree2_clf = DecisionTreeClassifier(random_state=42)
tree2_clf.fit(X, y)
convert_decision_tree_to_ipython_image(tree2_clf, feature_names=iris.feature_names[2:], class_names=iris.target_names)
Tree 0 petal length (cm) ≤ 5.45 gini = 0.667 samples = 150 value = [50, 50, 50] class = setosa 1 petal width (cm) ≤ 2.8 gini = 0.237 samples = 52 value = [45, 6, 1] class = setosa 0->1 True 14 petal length (cm) ≤ 6.15 gini = 0.546 samples = 98 value = [5, 44, 49] class = virginica 0->14 False 2 petal length (cm) ≤ 4.7 gini = 0.449 samples = 7 value = [1, 5, 1] class = versicolor 1->2 9 petal length (cm) ≤ 5.35 gini = 0.043 samples = 45 value = [44, 1, 0] class = setosa 1->9 3 gini = 0.0 samples = 1 value = [1, 0, 0] class = setosa 2->3 4 petal length (cm) ≤ 4.95 gini = 0.278 samples = 6 value = [0, 5, 1] class = versicolor 2->4 5 petal width (cm) ≤ 2.45 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 4->5 8 gini = 0.0 samples = 4 value = [0, 4, 0] class = versicolor 4->8 6 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 5->6 7 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 5->7 10 gini = 0.0 samples = 39 value = [39, 0, 0] class = setosa 9->10 11 petal width (cm) ≤ 3.2 gini = 0.278 samples = 6 value = [5, 1, 0] class = setosa 9->11 12 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 11->12 13 gini = 0.0 samples = 5 value = [5, 0, 0] class = setosa 11->13 15 petal width (cm) ≤ 3.45 gini = 0.508 samples = 43 value = [5, 28, 10] class = versicolor 14->15 52 petal length (cm) ≤ 7.05 gini = 0.413 samples = 55 value = [0, 16, 39] class = virginica 14->52 16 petal length (cm) ≤ 5.75 gini = 0.388 samples = 38 value = [0, 28, 10] class = versicolor 15->16 51 gini = 0.0 samples = 5 value = [5, 0, 0] class = setosa 15->51 17 petal width (cm) ≤ 2.85 gini = 0.208 samples = 17 value = [0, 15, 2] class = versicolor 16->17 30 petal width (cm) ≤ 3.1 gini = 0.472 samples = 21 value = [0, 13, 8] class = versicolor 16->30 18 petal length (cm) ≤ 5.55 gini = 0.278 samples = 12 value = [0, 10, 2] class = versicolor 17->18 29 gini = 0.0 samples = 5 value = [0, 5, 0] class = versicolor 17->29 19 gini = 0.0 samples = 5 value = [0, 5, 0] class = versicolor 18->19 20 petal width (cm) ≤ 2.55 gini = 0.408 samples = 7 value = [0, 5, 2] class = versicolor 18->20 21 petal length (cm) ≤ 5.65 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 20->21 24 petal length (cm) ≤ 5.65 gini = 0.32 samples = 5 value = [0, 4, 1] class = versicolor 20->24 22 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 21->22 23 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 21->23 25 petal width (cm) ≤ 2.75 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 24->25 28 gini = 0.0 samples = 3 value = [0, 3, 0] class = versicolor 24->28 26 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 25->26 27 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 25->27 31 petal width (cm) ≤ 2.95 gini = 0.488 samples = 19 value = [0, 11, 8] class = versicolor 30->31 50 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 30->50 32 petal width (cm) ≤ 2.85 gini = 0.459 samples = 14 value = [0, 9, 5] class = versicolor 31->32 45 petal length (cm) ≤ 5.95 gini = 0.48 samples = 5 value = [0, 2, 3] class = virginica 31->45 33 petal length (cm) ≤ 5.9 gini = 0.486 samples = 12 value = [0, 7, 5] class = versicolor 32->33 44 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 32->44 34 petal width (cm) ≤ 2.65 gini = 0.5 samples = 6 value = [0, 3, 3] class = versicolor 33->34 39 petal width (cm) ≤ 2.65 gini = 0.444 samples = 6 value = [0, 4, 2] class = versicolor 33->39 35 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 34->35 36 petal width (cm) ≤ 2.75 gini = 0.48 samples = 5 value = [0, 2, 3] class = virginica 34->36 37 gini = 0.5 samples = 4 value = [0, 2, 2] class = versicolor 36->37 38 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 36->38 40 petal length (cm) ≤ 6.05 gini = 0.444 samples = 3 value = [0, 1, 2] class = virginica 39->40 43 gini = 0.0 samples = 3 value = [0, 3, 0] class = versicolor 39->43 41 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 40->41 42 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 40->42 46 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 45->46 47 petal length (cm) ≤ 6.05 gini = 0.444 samples = 3 value = [0, 1, 2] class = virginica 45->47 48 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 47->48 49 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 47->49 53 petal width (cm) ≤ 2.4 gini = 0.467 samples = 43 value = [0, 16, 27] class = virginica 52->53 92 gini = 0.0 samples = 12 value = [0, 0, 12] class = virginica 52->92 54 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 53->54 55 petal length (cm) ≤ 6.95 gini = 0.45 samples = 41 value = [0, 14, 27] class = virginica 53->55 56 petal width (cm) ≤ 3.15 gini = 0.439 samples = 40 value = [0, 13, 27] class = virginica 55->56 91 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 55->91 57 petal length (cm) ≤ 6.55 gini = 0.471 samples = 29 value = [0, 11, 18] class = virginica 56->57 84 petal length (cm) ≤ 6.45 gini = 0.298 samples = 11 value = [0, 2, 9] class = virginica 56->84 58 petal width (cm) ≤ 2.95 gini = 0.375 samples = 16 value = [0, 4, 12] class = virginica 57->58 71 petal length (cm) ≤ 6.65 gini = 0.497 samples = 13 value = [0, 7, 6] class = versicolor 57->71 59 petal length (cm) ≤ 6.45 gini = 0.444 samples = 12 value = [0, 4, 8] class = virginica 58->59 70 gini = 0.0 samples = 4 value = [0, 0, 4] class = virginica 58->70 60 petal width (cm) ≤ 2.85 gini = 0.397 samples = 11 value = [0, 3, 8] class = virginica 59->60 69 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 59->69 61 petal width (cm) ≤ 2.6 gini = 0.219 samples = 8 value = [0, 1, 7] class = virginica 60->61 64 petal length (cm) ≤ 6.25 gini = 0.444 samples = 3 value = [0, 2, 1] class = versicolor 60->64 62 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 61->62 63 gini = 0.0 samples = 6 value = [0, 0, 6] class = virginica 61->63 65 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 64->65 66 petal length (cm) ≤ 6.35 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 64->66 67 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 66->67 68 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 66->68 72 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 71->72 73 petal width (cm) ≤ 2.65 gini = 0.496 samples = 11 value = [0, 5, 6] class = virginica 71->73 74 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 73->74 75 petal width (cm) ≤ 2.9 gini = 0.5 samples = 10 value = [0, 5, 5] class = versicolor 73->75 76 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 75->76 77 petal length (cm) ≤ 6.75 gini = 0.494 samples = 9 value = [0, 4, 5] class = virginica 75->77 78 petal width (cm) ≤ 3.05 gini = 0.48 samples = 5 value = [0, 3, 2] class = versicolor 77->78 81 petal width (cm) ≤ 3.05 gini = 0.375 samples = 4 value = [0, 1, 3] class = virginica 77->81 79 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 78->79 80 gini = 0.444 samples = 3 value = [0, 2, 1] class = versicolor 78->80 82 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 81->82 83 gini = 0.444 samples = 3 value = [0, 1, 2] class = virginica 81->83 85 petal width (cm) ≤ 3.35 gini = 0.444 samples = 6 value = [0, 2, 4] class = virginica 84->85 90 gini = 0.0 samples = 5 value = [0, 0, 5] class = virginica 84->90 86 petal width (cm) ≤ 3.25 gini = 0.5 samples = 4 value = [0, 2, 2] class = versicolor 85->86 89 gini = 0.0 samples = 2 value = [0, 0, 2] class = virginica 85->89 87 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 86->87 88 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 86->88
In [11]:
plt.title('Decision Tree')
plot_decision_regions(X, y, clf=tree2_clf, legend=0)
Out[11]:
<matplotlib.axes._subplots.AxesSubplot at 0x10be782e8>
In [12]:
from sklearn.ensemble import RandomForestClassifier

rf_clf = RandomForestClassifier(random_state=42)
rf_clf.fit(X, y)
Out[12]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
            oob_score=False, random_state=42, verbose=0, warm_start=False)
In [13]:
plt.title('Random Forest')
plot_decision_regions(X, y, clf=rf_clf, legend=0)
Out[13]:
<matplotlib.axes._subplots.AxesSubplot at 0x1133d09e8>