※下記、詳細は ブログを記事参照のこと

おまじない(プロットをインラインで表示するため)

In [1]:
%pylab inline
Populating the interactive namespace from numpy and matplotlib

学習データのプロット

In [2]:
import pandas as pd
import matplotlib.pyplot as plt

def x_y_species(df, species_name):
    x_array = df[df['class'] == species_name]['x'].values
    y_array = df[df['class'] == species_name]['y'].values
    return x_array, y_array

df = pd.read_csv('xor_simple.csv')
x_array, y_array = x_y_species(df, 0)
plt.plot(x_array, y_array, 'bo')

x_array, y_array = x_y_species(df, 1)
plt.plot(x_array, y_array, 'ro')
plt.xlabel('x', fontsize=20)
plt.ylabel('y', fontsize=20)
plt.axis("tight", fontsize=20)
x_min, x_max = x_array.min() - 1, x_array.max() + 1
y_min, y_max = y_array.min() - 1, y_array.max() + 1
plt.xlim((x_min, x_max))
plt.ylim((y_min, y_max))
plt.show()

教師データのロード ⇒ 学習 ⇒ 分類

In [3]:
import numpy as np
from sklearn import tree

# 教師データをロード
df = pd.read_csv('xor_simple.csv'); 
data_array = df[['x', 'y']].values
class_array = df['class'].values

# 学習(決定木)
clf = tree.DecisionTreeClassifier()
clf = clf.fit(data_array, class_array)

#学習後に、データを与えて分類。
#与えられた教師データの特徴から考えると
# x=2.0, y=1.0 であれば、クラス「0」に分類されるはず。
# x=1.0, y= -0.5であれば、クラス「1」に分類されるはず。
result = clf.predict([[2., 1.], [1., -0.5]]) 
print "result is ... ", result
result is ...  [0 1]

決定境界の可視化

In [4]:
# Parameters for plot
n_classes = 2
plot_colors = "br"
plot_step = 0.05

#グラフ描画時の説明変数 x、yの最大値&最小値を算出。
#グラフ描画のメッシュを定義
x_min, x_max = data_array[:, 0].min() - 1, data_array[:, 0].max() + 1
y_min, y_max = data_array[:, 1].min() - 1, data_array[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                     np.arange(y_min, y_max, plot_step))

#各メッシュ上での決定木による分類を計算
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

#決定木による分類を等高線フィールドプロットでプロット
cs = plt.contour(xx, yy, Z, cmap=plt.cm.Paired)
plt.xlabel('x')
plt.ylabel('y')
plt.axis("tight")

#教師データも重ねてプロット
for i, color in zip(range(n_classes), plot_colors):
    idx = np.where(class_array == i)
    plt.scatter(data_array[idx, 0], data_array[idx, 1], c=color, label=['a','b'],
               cmap=plt.cm.Paired)
plt.axis("tight")

plt.show()

決定木の中身を表示するためのdotファイルの出力

In [5]:
from sklearn.externals.six import StringIO

with open("xor_simple.dot", 'w') as f:
    f = tree.export_graphviz(clf, out_file=f)
f.close()