%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import sys
sys.path.append('/Users/kaonpark/workspace/github.com/likejazz/kaon-learn')
import kaonlearn
from kaonlearn.plots import plot_decision_regions
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import Perceptron
# Load the iris dataset and keep only the two petal measurements as features.
iris = load_iris()
X = iris.data[:, (2, 3)]  # petal length, petal width

# Binary target: 1 where the class label is 0 (Iris Setosa), 0 otherwise.
# The `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24, so the
# builtin `int` is used instead; it is the exact type `np.int` aliased.
y = (iris.target == 0).astype(int)  # Iris Setosa?

# Fit a perceptron on the two features; random_state is pinned to 42.
per_clf = Perceptron(random_state=42)
per_clf.fit(X, y)

# Classify a single sample with petal length 2 and petal width 0.5.
y_pred = per_clf.predict([[2, 0.5]])
y_pred
array([1])
# Draw the fitted perceptron's decision regions over the two petal features.
plot_decision_regions(X, y, clf=per_clf)
<matplotlib.axes._subplots.AxesSubplot at 0x114089ef0>
# Display the feature matrix (columns 2 and 3 of iris.data: petal length, petal width).
X
array([[ 1.4, 0.2], [ 1.4, 0.2], [ 1.3, 0.2], [ 1.5, 0.2], [ 1.4, 0.2], [ 1.7, 0.4], [ 1.4, 0.3], [ 1.5, 0.2], [ 1.4, 0.2], [ 1.5, 0.1], [ 1.5, 0.2], [ 1.6, 0.2], [ 1.4, 0.1], [ 1.1, 0.1], [ 1.2, 0.2], [ 1.5, 0.4], [ 1.3, 0.4], [ 1.4, 0.3], [ 1.7, 0.3], [ 1.5, 0.3], [ 1.7, 0.2], [ 1.5, 0.4], [ 1. , 0.2], [ 1.7, 0.5], [ 1.9, 0.2], [ 1.6, 0.2], [ 1.6, 0.4], [ 1.5, 0.2], [ 1.4, 0.2], [ 1.6, 0.2], [ 1.6, 0.2], [ 1.5, 0.4], [ 1.5, 0.1], [ 1.4, 0.2], [ 1.5, 0.1], [ 1.2, 0.2], [ 1.3, 0.2], [ 1.5, 0.1], [ 1.3, 0.2], [ 1.5, 0.2], [ 1.3, 0.3], [ 1.3, 0.3], [ 1.3, 0.2], [ 1.6, 0.6], [ 1.9, 0.4], [ 1.4, 0.3], [ 1.6, 0.2], [ 1.4, 0.2], [ 1.5, 0.2], [ 1.4, 0.2], [ 4.7, 1.4], [ 4.5, 1.5], [ 4.9, 1.5], [ 4. , 1.3], [ 4.6, 1.5], [ 4.5, 1.3], [ 4.7, 1.6], [ 3.3, 1. ], [ 4.6, 1.3], [ 3.9, 1.4], [ 3.5, 1. ], [ 4.2, 1.5], [ 4. , 1. ], [ 4.7, 1.4], [ 3.6, 1.3], [ 4.4, 1.4], [ 4.5, 1.5], [ 4.1, 1. ], [ 4.5, 1.5], [ 3.9, 1.1], [ 4.8, 1.8], [ 4. , 1.3], [ 4.9, 1.5], [ 4.7, 1.2], [ 4.3, 1.3], [ 4.4, 1.4], [ 4.8, 1.4], [ 5. , 1.7], [ 4.5, 1.5], [ 3.5, 1. ], [ 3.8, 1.1], [ 3.7, 1. ], [ 3.9, 1.2], [ 5.1, 1.6], [ 4.5, 1.5], [ 4.5, 1.6], [ 4.7, 1.5], [ 4.4, 1.3], [ 4.1, 1.3], [ 4. , 1.3], [ 4.4, 1.2], [ 4.6, 1.4], [ 4. , 1.2], [ 3.3, 1. ], [ 4.2, 1.3], [ 4.2, 1.2], [ 4.2, 1.3], [ 4.3, 1.3], [ 3. , 1.1], [ 4.1, 1.3], [ 6. , 2.5], [ 5.1, 1.9], [ 5.9, 2.1], [ 5.6, 1.8], [ 5.8, 2.2], [ 6.6, 2.1], [ 4.5, 1.7], [ 6.3, 1.8], [ 5.8, 1.8], [ 6.1, 2.5], [ 5.1, 2. ], [ 5.3, 1.9], [ 5.5, 2.1], [ 5. , 2. ], [ 5.1, 2.4], [ 5.3, 2.3], [ 5.5, 1.8], [ 6.7, 2.2], [ 6.9, 2.3], [ 5. , 1.5], [ 5.7, 2.3], [ 4.9, 2. ], [ 6.7, 2. ], [ 4.9, 1.8], [ 5.7, 2.1], [ 6. , 1.8], [ 4.8, 1.8], [ 4.9, 1.8], [ 5.6, 2.1], [ 5.8, 1.6], [ 6.1, 1.9], [ 6.4, 2. ], [ 5.6, 2.2], [ 5.1, 1.5], [ 5.6, 1.4], [ 6.1, 2.3], [ 5.6, 2.4], [ 5.5, 1.8], [ 4.8, 1.8], [ 5.4, 2.1], [ 5.6, 2.4], [ 5.1, 2.3], [ 5.1, 1.9], [ 5.9, 2.3], [ 5.7, 2.5], [ 5.2, 2.3], [ 5. , 1.9], [ 5.2, 2. ], [ 5.4, 2.3], [ 5.1, 1.8]])
# Display the full iris feature matrix that X was sliced from.
iris.data
array([[ 5.1, 3.5, 1.4, 0.2], [ 4.9, 3. , 1.4, 0.2], [ 4.7, 3.2, 1.3, 0.2], [ 4.6, 3.1, 1.5, 0.2], [ 5. , 3.6, 1.4, 0.2], [ 5.4, 3.9, 1.7, 0.4], [ 4.6, 3.4, 1.4, 0.3], [ 5. , 3.4, 1.5, 0.2], [ 4.4, 2.9, 1.4, 0.2], [ 4.9, 3.1, 1.5, 0.1], [ 5.4, 3.7, 1.5, 0.2], [ 4.8, 3.4, 1.6, 0.2], [ 4.8, 3. , 1.4, 0.1], [ 4.3, 3. , 1.1, 0.1], [ 5.8, 4. , 1.2, 0.2], [ 5.7, 4.4, 1.5, 0.4], [ 5.4, 3.9, 1.3, 0.4], [ 5.1, 3.5, 1.4, 0.3], [ 5.7, 3.8, 1.7, 0.3], [ 5.1, 3.8, 1.5, 0.3], [ 5.4, 3.4, 1.7, 0.2], [ 5.1, 3.7, 1.5, 0.4], [ 4.6, 3.6, 1. , 0.2], [ 5.1, 3.3, 1.7, 0.5], [ 4.8, 3.4, 1.9, 0.2], [ 5. , 3. , 1.6, 0.2], [ 5. , 3.4, 1.6, 0.4], [ 5.2, 3.5, 1.5, 0.2], [ 5.2, 3.4, 1.4, 0.2], [ 4.7, 3.2, 1.6, 0.2], [ 4.8, 3.1, 1.6, 0.2], [ 5.4, 3.4, 1.5, 0.4], [ 5.2, 4.1, 1.5, 0.1], [ 5.5, 4.2, 1.4, 0.2], [ 4.9, 3.1, 1.5, 0.1], [ 5. , 3.2, 1.2, 0.2], [ 5.5, 3.5, 1.3, 0.2], [ 4.9, 3.1, 1.5, 0.1], [ 4.4, 3. , 1.3, 0.2], [ 5.1, 3.4, 1.5, 0.2], [ 5. , 3.5, 1.3, 0.3], [ 4.5, 2.3, 1.3, 0.3], [ 4.4, 3.2, 1.3, 0.2], [ 5. , 3.5, 1.6, 0.6], [ 5.1, 3.8, 1.9, 0.4], [ 4.8, 3. , 1.4, 0.3], [ 5.1, 3.8, 1.6, 0.2], [ 4.6, 3.2, 1.4, 0.2], [ 5.3, 3.7, 1.5, 0.2], [ 5. , 3.3, 1.4, 0.2], [ 7. , 3.2, 4.7, 1.4], [ 6.4, 3.2, 4.5, 1.5], [ 6.9, 3.1, 4.9, 1.5], [ 5.5, 2.3, 4. , 1.3], [ 6.5, 2.8, 4.6, 1.5], [ 5.7, 2.8, 4.5, 1.3], [ 6.3, 3.3, 4.7, 1.6], [ 4.9, 2.4, 3.3, 1. ], [ 6.6, 2.9, 4.6, 1.3], [ 5.2, 2.7, 3.9, 1.4], [ 5. , 2. , 3.5, 1. ], [ 5.9, 3. , 4.2, 1.5], [ 6. , 2.2, 4. , 1. ], [ 6.1, 2.9, 4.7, 1.4], [ 5.6, 2.9, 3.6, 1.3], [ 6.7, 3.1, 4.4, 1.4], [ 5.6, 3. , 4.5, 1.5], [ 5.8, 2.7, 4.1, 1. ], [ 6.2, 2.2, 4.5, 1.5], [ 5.6, 2.5, 3.9, 1.1], [ 5.9, 3.2, 4.8, 1.8], [ 6.1, 2.8, 4. , 1.3], [ 6.3, 2.5, 4.9, 1.5], [ 6.1, 2.8, 4.7, 1.2], [ 6.4, 2.9, 4.3, 1.3], [ 6.6, 3. , 4.4, 1.4], [ 6.8, 2.8, 4.8, 1.4], [ 6.7, 3. , 5. , 1.7], [ 6. , 2.9, 4.5, 1.5], [ 5.7, 2.6, 3.5, 1. ], [ 5.5, 2.4, 3.8, 1.1], [ 5.5, 2.4, 3.7, 1. ], [ 5.8, 2.7, 3.9, 1.2], [ 6. , 2.7, 5.1, 1.6], [ 5.4, 3. , 4.5, 1.5], [ 6. 
, 3.4, 4.5, 1.6], [ 6.7, 3.1, 4.7, 1.5], [ 6.3, 2.3, 4.4, 1.3], [ 5.6, 3. , 4.1, 1.3], [ 5.5, 2.5, 4. , 1.3], [ 5.5, 2.6, 4.4, 1.2], [ 6.1, 3. , 4.6, 1.4], [ 5.8, 2.6, 4. , 1.2], [ 5. , 2.3, 3.3, 1. ], [ 5.6, 2.7, 4.2, 1.3], [ 5.7, 3. , 4.2, 1.2], [ 5.7, 2.9, 4.2, 1.3], [ 6.2, 2.9, 4.3, 1.3], [ 5.1, 2.5, 3. , 1.1], [ 5.7, 2.8, 4.1, 1.3], [ 6.3, 3.3, 6. , 2.5], [ 5.8, 2.7, 5.1, 1.9], [ 7.1, 3. , 5.9, 2.1], [ 6.3, 2.9, 5.6, 1.8], [ 6.5, 3. , 5.8, 2.2], [ 7.6, 3. , 6.6, 2.1], [ 4.9, 2.5, 4.5, 1.7], [ 7.3, 2.9, 6.3, 1.8], [ 6.7, 2.5, 5.8, 1.8], [ 7.2, 3.6, 6.1, 2.5], [ 6.5, 3.2, 5.1, 2. ], [ 6.4, 2.7, 5.3, 1.9], [ 6.8, 3. , 5.5, 2.1], [ 5.7, 2.5, 5. , 2. ], [ 5.8, 2.8, 5.1, 2.4], [ 6.4, 3.2, 5.3, 2.3], [ 6.5, 3. , 5.5, 1.8], [ 7.7, 3.8, 6.7, 2.2], [ 7.7, 2.6, 6.9, 2.3], [ 6. , 2.2, 5. , 1.5], [ 6.9, 3.2, 5.7, 2.3], [ 5.6, 2.8, 4.9, 2. ], [ 7.7, 2.8, 6.7, 2. ], [ 6.3, 2.7, 4.9, 1.8], [ 6.7, 3.3, 5.7, 2.1], [ 7.2, 3.2, 6. , 1.8], [ 6.2, 2.8, 4.8, 1.8], [ 6.1, 3. , 4.9, 1.8], [ 6.4, 2.8, 5.6, 2.1], [ 7.2, 3. , 5.8, 1.6], [ 7.4, 2.8, 6.1, 1.9], [ 7.9, 3.8, 6.4, 2. ], [ 6.4, 2.8, 5.6, 2.2], [ 6.3, 2.8, 5.1, 1.5], [ 6.1, 2.6, 5.6, 1.4], [ 7.7, 3. , 6.1, 2.3], [ 6.3, 3.4, 5.6, 2.4], [ 6.4, 3.1, 5.5, 1.8], [ 6. , 3. , 4.8, 1.8], [ 6.9, 3.1, 5.4, 2.1], [ 6.7, 3.1, 5.6, 2.4], [ 6.9, 3.1, 5.1, 2.3], [ 5.8, 2.7, 5.1, 1.9], [ 6.8, 3.2, 5.9, 2.3], [ 6.7, 3.3, 5.7, 2.5], [ 6.7, 3. , 5.2, 2.3], [ 6.3, 2.5, 5. , 1.9], [ 6.5, 3. , 5.2, 2. ], [ 6.2, 3.4, 5.4, 2.3], [ 5.9, 3. , 5.1, 1.8]])
from sklearn.neural_network import MLPClassifier
# Multi-layer perceptron on the same two petal features, using logistic
# activations and the Adam solver.
# NOTE(review): scikit-learn documents every entry of hidden_layer_sizes as a
# hidden layer, so [100, 1] presumably adds a 1-unit hidden layer after the
# 100-unit one rather than describing the output layer - confirm this is
# intended, given that the Keras model later in this file has a single
# 100-unit hidden layer.
mlp = MLPClassifier(
    hidden_layer_sizes=[100, 1],
    activation='logistic',
    solver='adam',
    max_iter=2000,
    random_state=42,
)
mlp.fit(X, y)
plot_decision_regions(X, y, clf=mlp)
<matplotlib.axes._subplots.AxesSubplot at 0x114421710>
from keras.models import Sequential
from keras.layers import Dense
# Keras network: 2 inputs -> 100 sigmoid units -> 1 sigmoid output (100 * 1).
model = Sequential([
    Dense(100, input_dim=2, activation='sigmoid'),
    Dense(1, activation='sigmoid'),
])

# Binary cross-entropy loss with the Adam optimizer; accuracy is tracked as a metric.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.summary()
Using TensorFlow backend.
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 100) 300 _________________________________________________________________ dense_2 (Dense) (None, 1) 101 ================================================================= Total params: 401.0 Trainable params: 401 Non-trainable params: 0.0 _________________________________________________________________
from keras_tqdm import TQDMNotebookCallback
# Train for 2000 epochs. verbose=0 is passed alongside the tqdm notebook
# callback, presumably so the callback is the only progress display.
model.fit(
    X,
    y,
    epochs=2000,
    verbose=0,
    callbacks=[TQDMNotebookCallback(show_inner=False)],
)
<keras.callbacks.History at 0x11e18cda0>
# Draw the trained Keras model's decision regions over the two petal features.
# NOTE(review): a Keras model's predict() presumably returns sigmoid
# probabilities rather than class labels - confirm plot_decision_regions
# handles that as intended.
plot_decision_regions(X, y, clf=model)
<matplotlib.axes._subplots.AxesSubplot at 0x11e2d9748>