from sklearn import datasets
digits = datasets.load_digits()
import pandas as pd
pd.DataFrame(digits.data).head()
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 0.0 | 5.0 | 13.0 | 9.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 13.0 | 10.0 | 0.0 | 0.0 | 0.0 |
1 | 0.0 | 0.0 | 0.0 | 12.0 | 13.0 | 5.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 11.0 | 16.0 | 10.0 | 0.0 | 0.0 |
2 | 0.0 | 0.0 | 0.0 | 4.0 | 15.0 | 12.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 5.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 11.0 | 16.0 | 9.0 | 0.0 |
3 | 0.0 | 0.0 | 7.0 | 15.0 | 13.0 | 1.0 | 0.0 | 0.0 | 0.0 | 8.0 | ... | 9.0 | 0.0 | 0.0 | 0.0 | 7.0 | 13.0 | 13.0 | 9.0 | 0.0 | 0.0 |
4 | 0.0 | 0.0 | 0.0 | 1.0 | 11.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 16.0 | 4.0 | 0.0 | 0.0 |
5 rows × 64 columns
from sklearn.model_selection import train_test_split
x_train, x_test, t_train, t_test = train_test_split(digits.data, digits.target, test_size=0.2)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
(1437, 64)
(1437,)
(360, 64)
(360,)
import numpy as np
def scale(X, eps=0.001):
    # scale the data points so that the columns of the feature space
    # (i.e. the predictors) fall within the range [0, 1];
    # eps guards against division by zero for constant columns
    return (X - np.min(X, axis=0)) / (np.max(X, axis=0) - np.min(X, axis=0) + eps)
x_train2 = scale(x_train)
x_test2 = scale(x_test)
t_train2 = scale(t_train)
t_test2 = scale(t_test)
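A quick optional sanity check that scale() really maps the training features into the [0, 1] range:
print(x_train2.min(), x_train2.max())  # expect values very close to 0.0 and 1.0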
%matplotlib inline
from PIL import Image
def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
img = x_train2[0]
img.shape
(64,)
img = img.reshape(8, 8)
img_show(img)
import matplotlib.pyplot as plt
plt.imshow(x_train[2].reshape((8, 8)))
<matplotlib.image.AxesImage at 0x11712a350>
x_train2.shape
(1437, 64)
def init_network():
with open("sample_weight.pkl", 'rb') as f:
network = pickle.load(f)
return network
__file__ = "."
# coding: utf-8
#try:
# import urllib.request
#except ImportError:
# raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np
url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
'train_img':'train-images-idx3-ubyte.gz',
'train_label':'train-labels-idx1-ubyte.gz',
'test_img':'t10k-images-idx3-ubyte.gz',
'test_label':'t10k-labels-idx1-ubyte.gz'
}
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"
train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784
def _download(file_name):
file_path = dataset_dir + "/" + file_name
if os.path.exists(file_path):
return
print("Downloading " + file_name + " ... ")
urllib.request.urlretrieve(url_base + file_name, file_path)
print("Done")
def download_mnist():
for v in key_file.values():
_download(v)
def _load_label(file_name):
file_path = dataset_dir + "/" + file_name
print("Converting " + file_name + " to NumPy Array ...")
with gzip.open(file_path, 'rb') as f:
labels = np.frombuffer(f.read(), np.uint8, offset=8)
print("Done")
return labels
def _load_img(file_name):
file_path = dataset_dir + "/" + file_name
print("Converting " + file_name + " to NumPy Array ...")
with gzip.open(file_path, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
data = data.reshape(-1, img_size)
print("Done")
return data
def _convert_numpy():
dataset = {}
dataset['train_img'] = _load_img(key_file['train_img'])
dataset['train_label'] = _load_label(key_file['train_label'])
dataset['test_img'] = _load_img(key_file['test_img'])
dataset['test_label'] = _load_label(key_file['test_label'])
return dataset
def init_mnist():
download_mnist()
dataset = _convert_numpy()
print("Creating pickle file ...")
with open(save_file, 'wb') as f:
pickle.dump(dataset, f, -1)
print("Done!")
def _change_one_hot_label(X):
T = np.zeros((X.size, 10))
for idx, row in enumerate(T):
row[X[idx]] = 1
return T
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""MNISTデータセットの読み込み
Parameters
----------
normalize : 画像のピクセル値を0.0~1.0に正規化する
one_hot_label :
one_hot_labelがTrueの場合、ラベルはone-hot配列として返す
one-hot配列とは、たとえば[0,0,1,0,0,0,0,0,0,0]のような配列
flatten : 画像を一次元配列に平にするかどうか
Returns
-------
(訓練画像, 訓練ラベル), (テスト画像, テストラベル)
"""
if not os.path.exists(save_file):
init_mnist()
with open(save_file, 'rb') as f:
dataset = pickle.load(f)
if normalize:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255.0
if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
if not flatten:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
if __name__ == '__main__':
init_mnist()
Converting train-images-idx3-ubyte.gz to NumPy Array ...
Done
Converting train-labels-idx1-ubyte.gz to NumPy Array ...
Done
Converting t10k-images-idx3-ubyte.gz to NumPy Array ...
Done
Converting t10k-labels-idx1-ubyte.gz to NumPy Array ...
Done
Creating pickle file ...
Done!
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=True)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
(60000, 784)
(60000,)
(10000, 784)
(10000,)
from PIL import Image
def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
img = x_train[0]
label = t_train[0]
print(label) # 5
print(img.shape) # (784,)
img = img.reshape(28, 28)  # reshape back to the original image size
print(img.shape) # (28, 28)
img_show(img)
5
(784,)
(28, 28)
# coding: utf-8
import numpy as np
def identity_function(x):
return x
def step_function(x):
    return np.array(x > 0, dtype=int)  # np.int is deprecated; plain int behaves the same here
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def sigmoid_grad(x):
return (1.0 - sigmoid(x)) * sigmoid(x)
def relu(x):
return np.maximum(0, x)
def relu_grad(x):
    grad = np.zeros_like(x)  # np.zeros(x) was a bug: np.zeros expects a shape, not an array
    grad[x >= 0] = 1
    return grad
def softmax(x):
if x.ndim == 2:
x = x.T
x = x - np.max(x, axis=0)
y = np.exp(x) / np.sum(np.exp(x), axis=0)
return y.T
    x = x - np.max(x)  # guard against overflow
return np.exp(x) / np.sum(np.exp(x))
def mean_squared_error(y, t):
return 0.5 * np.sum((y-t)**2)
def cross_entropy_error(y, t):
if y.ndim == 1:
t = t.reshape(1, t.size)
y = y.reshape(1, y.size)
    # if the teacher data is a one-hot vector, convert it to correct-class indices
if t.size == y.size:
t = t.argmax(axis=1)
batch_size = y.shape[0]
    return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size  # 1e-7 avoids log(0)
def softmax_loss(X, t):
y = softmax(X)
return cross_entropy_error(y, t)
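As a small, made-up illustration of how softmax and cross_entropy_error fit together: the loss is just the average negative log-probability assigned to the correct class of each row.
y_demo = softmax(np.array([[0.3, 2.9, 4.0], [0.1, 0.2, 0.7]]))  # toy scores for a batch of two samples
t_demo = np.array([2, 2])                                       # integer class labels (not one-hot)
print(cross_entropy_error(y_demo, t_demo))                      # mean of -log(probability of the correct class)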
import numpy as np
import pickle
#from common.functions import sigmoid, softmax
def get_data():
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
return x_test, t_test
def init_network():
with open("sample_weight.pkl", 'rb') as f:
network = pickle.load(f)
return network
def predict(network, x):
W1, W2, W3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, W3) + b3
y = softmax(a3)
return y
x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
y = predict(network, x[i])
    p = np.argmax(y)  # index of the element with the highest probability
if p == t[i]:
accuracy_cnt += 1
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
Accuracy:0.9352
x, _ = get_data()
network = init_network()
W1, W2, W3 = network['W1'], network['W2'], network['W3']
x.shape
(10000, 784)
x[0].shape
(784,)
W1.shape
(784, 50)
W2.shape
(50, 100)
W3.shape
(100, 10)
x, t = get_data()
network = init_network()
batch_size = 100
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
x_batch = x[i:i+batch_size]
y_batch = predict(network, x_batch)
p = np.argmax(y_batch, axis=1)
accuracy_cnt += np.sum(p == t[i:i+batch_size])
print ("Accuracy:" + str(float(accuracy_cnt) / len(x)))
Accuracy:0.9352
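Running the evaluation in batches of 100 gives exactly the same accuracy (0.9352) as the one-image-at-a-time loop above; the difference is that each predict() call now processes a (100, 784) matrix at once, so the loop body runs 100 times instead of 10,000 times.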
x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3], [0.8, 0.1, 0.1]])
print (np.argmax(x, axis=1))
[1 2 1 0]
y = np.array([1, 2, 1, 0])
t = np.array([1, 2, 0, 0])
print (y == t)
[ True True False True]
np.sum(y==t)
3
(x_train, t_train), (x_test, t_test)= load_mnist(normalize=True, one_hot_label=True)
print(x_train.shape)
print(t_train.shape)
(60000, 784)
(60000, 10)
train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
batch_mask
array([14262, 57152, 12582, 53075, 17635, 26045, 43431, 58703, 29177, 48331])
#def cross_entropy_error(y, t):
# if y.ndim == 1:
# t = t.reshape(1, t.size)
# y = y.reshape(1, y.size)
# batch_size = y.shape[0]
# return -np.sum(t * np.log(y)) / batch_size
def gradient_descent(f, init_x, lr=0.01, step_num=100):
x = init_x
for i in range(step_num):
grad = numerical_gradient(f, x)
x -= lr * grad
return x
def numerical_gradient(f, x):
h = 1e-4
grad = np.zeros_like(x)
for idx in range(x.size):
tmp_val = x[idx]
x[idx] = tmp_val + h
fxh1 = f(x)
x[idx] = tmp_val - h
fxh2 = f(x)
grad[idx] = (fxh1 - fxh2) / (2*h)
x[idx] = tmp_val
return grad
def function_2(x):
return x[0]**2 + x[1]**2
init_x = np.array([-3.0, 4.0])
gradient_descent(function_2, init_x=init_x, lr=0.1, step_num=100)
array([ -6.11110793e-10, 8.14814391e-10])
init_x = np.array([-3.0, 4.0])
gradient_descent(function_2, init_x=init_x, lr=10, step_num=100)
array([ -2.58983747e+13, -1.29524862e+12])
init_x = np.array([-3.0, 4.0])
gradient_descent(function_2, init_x=init_x, lr=1e-10, step_num=100)
array([-2.99999994, 3.99999992])
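The three runs above show how sensitive gradient descent is to the learning rate: lr=0.1 reaches (essentially) the minimum at the origin, lr=10 overshoots and diverges to huge values, and lr=1e-10 takes steps so small that x barely moves from the initial point (-3.0, 4.0).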
class simpleNet:
def __init__(self):
self.W = np.random.randn(2,3)
def predict(self, x):
return np.dot(x, self.W)
def loss(self, x, t):
z = self.predict(x)
y = softmax(z)
loss = cross_entropy_error(y, t)
return loss
net = simpleNet()
print(net.W)
[[-0.58865965  2.01375848  1.74139982]
 [-0.91928797  0.07140906 -1.07484407]]
x = np.array([0.6, 0.9])
p = net.predict(x)
print(p)
[-1.18055496 1.27252325 0.07748023]
np.argmax(p)
1
t = np.array([0, 0, 1])
net.loss(x, t)
1.5234249725309608
def f(W):
return net.loss(x, t)
#dW = #numerical_gradient(f, x)
#print(dW)
def numerical_gradient(f, X):
if X.ndim == 1:
return _numerical_gradient_no_batch(f, X)
else:
grad = np.zeros_like(X)
for idx, x in enumerate(X):
grad[idx] = _numerical_gradient_no_batch(f, x)
return grad
def _numerical_gradient_no_batch(f, x):
h = 1e-4 # 0.0001
grad = np.zeros_like(x)
for idx in range(x.size):
tmp_val = x[idx]
x[idx] = float(tmp_val) + h
fxh1 = f(x) # f(x+h)
x[idx] = tmp_val - h
fxh2 = f(x) # f(x-h)
grad[idx] = (fxh1 - fxh2) / (2*h)
        x[idx] = tmp_val  # restore the original value
return grad
dW = numerical_gradient(f, net.W)
print(dW)
[[ 0.03716879  0.43205276 -0.46922155]
 [ 0.05575319  0.64807913 -0.70383232]]
f = lambda w: net.loss(x, t)
dW = numerical_gradient(f, net.W)
dW
array([[ 0.03716879,  0.43205276, -0.46922155],
       [ 0.05575319,  0.64807913, -0.70383232]])
class TwoLayerNet:
def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):
        # initialize the weights
self.params = {}
self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
self.params['b1'] = np.zeros(hidden_size)
self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)
self.params['b2'] = np.zeros(output_size)
def predict(self, x):
W1, W2 = self.params['W1'], self.params['W2']
b1, b2 = self.params['b1'], self.params['b2']
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
y = softmax(a2)
return y
    # x: input data, t: teacher (target) data
def loss(self, x, t):
y = self.predict(x)
return cross_entropy_error(y, t)
def accuracy(self, x, t):
y = self.predict(x)
y = np.argmax(y, axis=1)
t = np.argmax(t, axis=1)
accuracy = np.sum(y == t) / float(x.shape[0])
return accuracy
    # x: input data, t: teacher (target) data
def numerical_gradient(self, x, t):
loss_W = lambda W: self.loss(x, t)
grads = {}
grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
return grads
def gradient(self, x, t):
W1, W2 = self.params['W1'], self.params['W2']
b1, b2 = self.params['b1'], self.params['b2']
grads = {}
batch_num = x.shape[0]
# forward
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
y = softmax(a2)
# backward
dy = (y - t) / batch_num
grads['W2'] = np.dot(z1.T, dy)
grads['b2'] = np.sum(dy, axis=0)
da1 = np.dot(dy, W2.T)
dz1 = sigmoid_grad(a1) * da1
grads['W1'] = np.dot(x.T, dz1)
grads['b1'] = np.sum(dz1, axis=0)
return grads
net = TwoLayerNet(input_size=784, hidden_size=100, output_size=10)
print(net.params['W1'].shape)
print(net.params['b1'].shape)
print(net.params['W2'].shape)
print(net.params['b2'].shape)
(784, 100)
(100,)
(100, 10)
(10,)
x = np.random.rand(100, 784)
y = net.predict(x)
x = np.random.rand(100, 784)
t = np.random.rand(100, 10)
!date
grads = net.numerical_gradient(x, t)
!date
Wed Feb 1 15:40:30 JST 2017
Wed Feb 1 15:43:16 JST 2017
grads
{'W1': array of shape (784, 100), 'b1': array of shape (100,), 'W2': array of shape (100, 10), 'b2': array of shape (10,)}  (full numerical output omitted)
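The surrounding !date calls show that a single numerical_gradient evaluation for this network takes close to three minutes (15:40:30 to 15:43:16), which is why the analytic gradient() method is used for the actual training loop below.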
%matplotlib inline
import matplotlib.pyplot as plt
!date
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)
iters_num = 10000  # set the number of iterations as appropriate
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1
train_loss_list = []
train_acc_list = []
test_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)
for i in range(iters_num):
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
    # compute the gradient
#grad = network.numerical_gradient(x_batch, t_batch)
grad = network.gradient(x_batch, t_batch)
    # update the parameters
for key in ('W1', 'b1', 'W2', 'b2'):
network.params[key] -= learning_rate * grad[key]
loss = network.loss(x_batch, t_batch)
train_loss_list.append(loss)
if i % iter_per_epoch == 0:
train_acc = network.accuracy(x_train, t_train)
test_acc = network.accuracy(x_test, t_test)
train_acc_list.append(train_acc)
test_acc_list.append(test_acc)
print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))
# draw the graph
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
!date
Wed Feb 1 15:47:16 JST 2017
train acc, test acc | 0.0993, 0.1032
train acc, test acc | 0.789066666667, 0.796
train acc, test acc | 0.875516666667, 0.8806
train acc, test acc | 0.898766666667, 0.9007
train acc, test acc | 0.905566666667, 0.9065
train acc, test acc | 0.91295, 0.9158
train acc, test acc | 0.9173, 0.9192
train acc, test acc | 0.921083333333, 0.9225
train acc, test acc | 0.92585, 0.9271
train acc, test acc | 0.92865, 0.9295
train acc, test acc | 0.932033333333, 0.9329
train acc, test acc | 0.935066666667, 0.9348
train acc, test acc | 0.937766666667, 0.937
train acc, test acc | 0.9405, 0.9393
train acc, test acc | 0.94185, 0.9411
train acc, test acc | 0.943516666667, 0.9418
train acc, test acc | 0.94555, 0.9423
Wed Feb 1 15:47:52 JST 2017
class MulLayer:
def __init__(self):
self.x = None
self.y = None
def forward(self, x, y):
self.x = x
self.y = y
out = x * y
return out
def backward(self, dout):
dx = dout * self.y
dy = dout * self.x
return dx, dy
class AddLayer:
def __init__(self):
pass
def forward(self, x, y):
out = x + y
return out
def backward(self, dout):
dx = dout * 1
dy = dout * 1
return dx, dy
apple = 100
apple_num = 2
tax = 1.1
mul_apple_layer = MulLayer()
mul_tax_layer = MulLayer()
# forward
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)
# backward
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)
print("price:", int(price))
print("dApple:", dapple)
print("dApple_num:", int(dapple_num))
print("dTax:", dtax)
price: 220
dApple: 2.2
dApple_num: 110
dTax: 200
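These derivatives can be checked by hand: price = apple * apple_num * tax, so dApple = apple_num * tax = 2 * 1.1 = 2.2, dApple_num = apple * tax = 100 * 1.1 = 110, and dTax = apple * apple_num = 200, exactly what the backward passes return.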
apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1
# layer
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()
# forward
apple_price = mul_apple_layer.forward(apple, apple_num) # (1)
orange_price = mul_orange_layer.forward(orange, orange_num) # (2)
all_price = add_apple_orange_layer.forward(apple_price, orange_price) # (3)
price = mul_tax_layer.forward(all_price, tax) # (4)
# backward
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice) # (4)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price) # (3)
dorange, dorange_num = mul_orange_layer.backward(dorange_price) # (2)
dapple, dapple_num = mul_apple_layer.backward(dapple_price) # (1)
print("price:", int(price))
print("dApple:", dapple)
print("dApple_num:", int(dapple_num))
print("dOrange:", dorange)
print("dOrange_num:", int(dorange_num))
print("dTax:", dtax)
price: 715
dApple: 2.2
dApple_num: 110
dOrange: 3.3000000000000003
dOrange_num: 165
dTax: 650
class Relu:
def __init__(self):
self.mask = None
def forward(self, x):
self.mask = (x <= 0)
out = x.copy()
out[self.mask] = 0
return out
def backward(self, dout):
dout[self.mask] = 0
dx = dout
return dx
class Sigmoid:
def __init__(self):
self.out = None
def forward(self, x):
out = sigmoid(x)
self.out = out
return out
def backward(self, dout):
dx = dout * (1.0 - self.out) * self.out
return dx
class Affine:
def __init__(self, W, b):
        self.W = W
self.b = b
self.x = None
self.original_x_shape = None
        # gradients of the weight and bias parameters
self.dW = None
self.db = None
def forward(self, x):
        # support tensor input
self.original_x_shape = x.shape
x = x.reshape(x.shape[0], -1)
self.x = x
out = np.dot(self.x, self.W) + self.b
return out
def backward(self, dout):
dx = np.dot(dout, self.W.T)
self.dW = np.dot(self.x.T, dout)
self.db = np.sum(dout, axis=0)
        dx = dx.reshape(*self.original_x_shape)  # restore the shape of the input data (tensor support)
return dx
class SoftmaxWithLoss:
def __init__(self):
self.loss = None
        self.y = None    # output of softmax
        self.t = None    # teacher (target) data
def forward(self, x, t):
self.t = t
self.y = softmax(x)
self.loss = cross_entropy_error(self.y, self.t)
return self.loss
def backward(self, dout=1):
batch_size = self.t.shape[0]
        if self.t.size == self.y.size:  # when the teacher data is a one-hot vector
dx = (self.y - self.t) / batch_size
else:
dx = self.y.copy()
dx[np.arange(batch_size), self.t] -= 1
dx = dx / batch_size
return dx
from collections import OrderedDict
class TwoLayerNet:
def __init__(self, input_size, hidden_size, output_size, weight_init_std = 0.01):
        # initialize the weights
self.params = {}
self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
self.params['b1'] = np.zeros(hidden_size)
self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)
self.params['b2'] = np.zeros(output_size)
        # create the layers
self.layers = OrderedDict()
self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])
self.layers['Relu1'] = Relu()
self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2'])
self.lastLayer = SoftmaxWithLoss()
def predict(self, x):
for layer in self.layers.values():
x = layer.forward(x)
return x
    # x: input data, t: teacher (target) data
def loss(self, x, t):
y = self.predict(x)
return self.lastLayer.forward(y, t)
def accuracy(self, x, t):
y = self.predict(x)
y = np.argmax(y, axis=1)
if t.ndim != 1 : t = np.argmax(t, axis=1)
accuracy = np.sum(y == t) / float(x.shape[0])
return accuracy
    # x: input data, t: teacher (target) data
def numerical_gradient(self, x, t):
loss_W = lambda W: self.loss(x, t)
grads = {}
grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
return grads
def gradient(self, x, t):
# forward
self.loss(x, t)
# backward
dout = 1
dout = self.lastLayer.backward(dout)
layers = list(self.layers.values())
layers.reverse()
for layer in layers:
dout = layer.backward(dout)
        # collect the gradients
grads = {}
grads['W1'], grads['b1'] = self.layers['Affine1'].dW, self.layers['Affine1'].db
grads['W2'], grads['b2'] = self.layers['Affine2'].dW, self.layers['Affine2'].db
return grads
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)
x_batch = x_train[:3]
t_batch = t_train[:3]
grad_numerical = network.numerical_gradient(x_batch, t_batch)
grad_backprop = network.gradient(x_batch, t_batch)
for key in grad_numerical.keys():
diff = np.average( np.abs(grad_backprop[key] - grad_numerical[key]) )
print(key + ":" + str(diff))
W1:2.52951812473e-13
b1:1.09754104026e-12
W2:9.1445718356e-13
b2:1.20348174482e-10
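The differences between the backprop gradients and the numerical gradients are on the order of 1e-10 to 1e-13, essentially floating-point noise, so the backward passes of the layer implementations can be considered correct.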
!date
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)
iters_num = 10000
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1
train_loss_list = []
train_acc_list = []
test_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)
for i in range(iters_num):
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
    # gradient
#grad = network.numerical_gradient(x_batch, t_batch)
grad = network.gradient(x_batch, t_batch)
    # update the parameters
for key in ('W1', 'b1', 'W2', 'b2'):
network.params[key] -= learning_rate * grad[key]
loss = network.loss(x_batch, t_batch)
train_loss_list.append(loss)
if i % iter_per_epoch == 0:
train_acc = network.accuracy(x_train, t_train)
test_acc = network.accuracy(x_test, t_test)
train_acc_list.append(train_acc)
test_acc_list.append(test_acc)
print(train_acc, test_acc)
!date
Wed Feb 1 16:27:22 JST 2017
0.14105 0.1497
0.902283333333 0.9043
0.9222 0.9235
0.935183333333 0.9334
0.943066666667 0.94
0.949916666667 0.9477
0.955633333333 0.9539
0.9594 0.957
0.961283333333 0.9593
0.963433333333 0.9583
0.966233333333 0.9627
0.9692 0.9653
0.9713 0.9656
0.972316666667 0.9667
0.974016666667 0.9668
0.97575 0.9678
0.975566666667 0.9676
Wed Feb 1 16:28:06 JST 2017
# draw the graph
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
!date
Wed Feb 1 16:28:15 JST 2017
class SGD:
"""確率的勾配降下法(Stochastic Gradient Descent)"""
def __init__(self, lr=0.01):
self.lr = lr
def update(self, params, grads):
for key in params.keys():
params[key] -= self.lr * grads[key]
class Momentum:
"""Momentum SGD"""
def __init__(self, lr=0.01, momentum=0.9):
self.lr = lr
self.momentum = momentum
self.v = None
def update(self, params, grads):
if self.v is None:
self.v = {}
for key, val in params.items():
self.v[key] = np.zeros_like(val)
for key in params.keys():
self.v[key] = self.momentum*self.v[key] - self.lr*grads[key]
params[key] += self.v[key]
class AdaGrad:
"""AdaGrad"""
def __init__(self, lr=0.01):
self.lr = lr
self.h = None
def update(self, params, grads):
if self.h is None:
self.h = {}
for key, val in params.items():
self.h[key] = np.zeros_like(val)
for key in params.keys():
self.h[key] += grads[key] * grads[key]
params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)
class Adam:
"""Adam (http://arxiv.org/abs/1412.6980v8)"""
def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):
self.lr = lr
self.beta1 = beta1
self.beta2 = beta2
self.iter = 0
self.m = None
self.v = None
def update(self, params, grads):
if self.m is None:
self.m, self.v = {}, {}
for key, val in params.items():
self.m[key] = np.zeros_like(val)
self.v[key] = np.zeros_like(val)
self.iter += 1
lr_t = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter)
for key in params.keys():
#self.m[key] = self.beta1*self.m[key] + (1-self.beta1)*grads[key]
#self.v[key] = self.beta2*self.v[key] + (1-self.beta2)*(grads[key]**2)
self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key])
self.v[key] += (1 - self.beta2) * (grads[key]**2 - self.v[key])
params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)
#unbias_m += (1 - self.beta1) * (grads[key] - self.m[key]) # correct bias
#unbisa_b += (1 - self.beta2) * (grads[key]*grads[key] - self.v[key]) # correct bias
#params[key] += self.lr * unbias_m / (np.sqrt(unbisa_b) + 1e-7)
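Note that this Adam implementation folds the bias-correction factors into lr_t instead of keeping explicit unbiased estimates of m and v (the commented-out lines sketch that alternative); up to where the 1e-7 term enters the denominator, the two forms of the update are equivalent.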
def f(x, y):
return x**2 / 20.0 + y**2
def df(x, y):
return x / 10.0, 2.0*y
init_pos = (-7.0, 2.0)
params = {}
params['x'], params['y'] = init_pos[0], init_pos[1]
grads = {}
grads['x'], grads['y'] = 0, 0
optimizers = OrderedDict()
optimizers["SGD"] = SGD(lr=0.95)
optimizers["Momentum"] = Momentum(lr=0.1)
optimizers["AdaGrad"] = AdaGrad(lr=1.5)
optimizers["Adam"] = Adam(lr=0.3)
idx = 1
for key in optimizers:
optimizer = optimizers[key]
x_history = []
y_history = []
params['x'], params['y'] = init_pos[0], init_pos[1]
for i in range(30):
x_history.append(params['x'])
y_history.append(params['y'])
grads['x'], grads['y'] = df(params['x'], params['y'])
optimizer.update(params, grads)
x = np.arange(-10, 10, 0.01)
y = np.arange(-5, 5, 0.01)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
# for simple contour line
mask = Z > 7
Z[mask] = 0
# plot
plt.subplot(2, 2, idx)
idx += 1
plt.plot(x_history, y_history, 'o-', color="red")
plt.contour(X, Y, Z)
plt.ylim(-10, 10)
plt.xlim(-10, 10)
plt.plot(0, 0, '+')
#colorbar()
#spring()
plt.title(key)
plt.xlabel("x")
plt.ylabel("y")
plt.show()
class MultiLayerNet:
"""全結合による多層ニューラルネットワーク
Parameters
----------
input_size : 入力サイズ(MNISTの場合は784)
hidden_size_list : 隠れ層のニューロンの数のリスト(e.g. [100, 100, 100])
output_size : 出力サイズ(MNISTの場合は10)
activation : 'relu' or 'sigmoid'
weight_init_std : 重みの標準偏差を指定(e.g. 0.01)
'relu'または'he'を指定した場合は「Heの初期値」を設定
'sigmoid'または'xavier'を指定した場合は「Xavierの初期値」を設定
weight_decay_lambda : Weight Decay(L2ノルム)の強さ
"""
def __init__(self, input_size, hidden_size_list, output_size,
activation='relu', weight_init_std='relu', weight_decay_lambda=0):
self.input_size = input_size
self.output_size = output_size
self.hidden_size_list = hidden_size_list
self.hidden_layer_num = len(hidden_size_list)
self.weight_decay_lambda = weight_decay_lambda
self.params = {}
        # initialize the weights
        self.__init_weight(weight_init_std)
        # create the layers
activation_layer = {'sigmoid': Sigmoid, 'relu': Relu}
self.layers = OrderedDict()
for idx in range(1, self.hidden_layer_num+1):
self.layers['Affine' + str(idx)] = Affine(self.params['W' + str(idx)],
self.params['b' + str(idx)])
self.layers['Activation_function' + str(idx)] = activation_layer[activation]()
idx = self.hidden_layer_num + 1
self.layers['Affine' + str(idx)] = Affine(self.params['W' + str(idx)],
self.params['b' + str(idx)])
self.last_layer = SoftmaxWithLoss()
def __init_weight(self, weight_init_std):
"""重みの初期値設定
Parameters
----------
weight_init_std : 重みの標準偏差を指定(e.g. 0.01)
'relu'または'he'を指定した場合は「Heの初期値」を設定
'sigmoid'または'xavier'を指定した場合は「Xavierの初期値」を設定
"""
all_size_list = [self.input_size] + self.hidden_size_list + [self.output_size]
for idx in range(1, len(all_size_list)):
scale = weight_init_std
if str(weight_init_std).lower() in ('relu', 'he'):
            scale = np.sqrt(2.0 / all_size_list[idx - 1])  # recommended initial value when using ReLU
elif str(weight_init_std).lower() in ('sigmoid', 'xavier'):
            scale = np.sqrt(1.0 / all_size_list[idx - 1])  # recommended initial value when using sigmoid
self.params['W' + str(idx)] = scale * np.random.randn(all_size_list[idx-1], all_size_list[idx])
self.params['b' + str(idx)] = np.zeros(all_size_list[idx])
def predict(self, x):
for layer in self.layers.values():
x = layer.forward(x)
return x
def loss(self, x, t):
"""損失関数を求める
Parameters
----------
x : 入力データ
t : 教師ラベル
Returns
-------
損失関数の値
"""
y = self.predict(x)
weight_decay = 0
for idx in range(1, self.hidden_layer_num + 2):
W = self.params['W' + str(idx)]
weight_decay += 0.5 * self.weight_decay_lambda * np.sum(W ** 2)
return self.last_layer.forward(y, t) + weight_decay
def accuracy(self, x, t):
y = self.predict(x)
y = np.argmax(y, axis=1)
if t.ndim != 1 : t = np.argmax(t, axis=1)
accuracy = np.sum(y == t) / float(x.shape[0])
return accuracy
def numerical_gradient(self, x, t):
"""勾配を求める(数値微分)
Parameters
----------
x : 入力データ
t : 教師ラベル
Returns
-------
各層の勾配を持ったディクショナリ変数
grads['W1']、grads['W2']、...は各層の重み
grads['b1']、grads['b2']、...は各層のバイアス
"""
loss_W = lambda W: self.loss(x, t)
grads = {}
for idx in range(1, self.hidden_layer_num+2):
grads['W' + str(idx)] = numerical_gradient(loss_W, self.params['W' + str(idx)])
grads['b' + str(idx)] = numerical_gradient(loss_W, self.params['b' + str(idx)])
return grads
def gradient(self, x, t):
"""勾配を求める(誤差逆伝搬法)
Parameters
----------
x : 入力データ
t : 教師ラベル
Returns
-------
各層の勾配を持ったディクショナリ変数
grads['W1']、grads['W2']、...は各層の重み
grads['b1']、grads['b2']、...は各層のバイアス
"""
# forward
self.loss(x, t)
# backward
dout = 1
dout = self.last_layer.backward(dout)
layers = list(self.layers.values())
layers.reverse()
for layer in layers:
dout = layer.backward(dout)
        # collect the gradients
grads = {}
for idx in range(1, self.hidden_layer_num+2):
grads['W' + str(idx)] = self.layers['Affine' + str(idx)].dW + self.weight_decay_lambda * self.layers['Affine' + str(idx)].W
grads['b' + str(idx)] = self.layers['Affine' + str(idx)].db
return grads
def smooth_curve(x):
"""損失関数のグラフを滑らかにするために用いる
参考:http://glowingpython.blogspot.jp/2012/02/convolution-with-numpy.html
"""
window_len = 11
s = np.r_[x[window_len-1:0:-1], x, x[-1:-window_len:-1]]
w = np.kaiser(window_len, 2)
y = np.convolve(w/w.sum(), s, mode='valid')
return y[5:len(y)-5]
# 0: load the MNIST data ==========
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
train_size = x_train.shape[0]
batch_size = 128
max_iterations = 2000
# 1: experiment setup ==========
optimizers = {}
optimizers['SGD'] = SGD()
optimizers['Momentum'] = Momentum()
optimizers['AdaGrad'] = AdaGrad()
optimizers['Adam'] = Adam()
#optimizers['RMSprop'] = RMSprop()
networks = {}
train_loss = {}
for key in optimizers.keys():
networks[key] = MultiLayerNet(
input_size=784, hidden_size_list=[100, 100, 100, 100],
output_size=10)
train_loss[key] = []
# 2: start training ==========
for i in range(max_iterations):
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
for key in optimizers.keys():
grads = networks[key].gradient(x_batch, t_batch)
optimizers[key].update(networks[key].params, grads)
loss = networks[key].loss(x_batch, t_batch)
train_loss[key].append(loss)
if i % 100 == 0:
print( "===========" + "iteration:" + str(i) + "===========")
for key in optimizers.keys():
loss = networks[key].loss(x_batch, t_batch)
print(key + ":" + str(loss))
# 3: draw the graph ==========
markers = {"SGD": "o", "Momentum": "x", "AdaGrad": "s", "Adam": "D"}
x = np.arange(max_iterations)
for key in optimizers.keys():
plt.plot(x, smooth_curve(train_loss[key]), marker=markers[key], markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
plt.ylim(0, 1)
plt.legend()
plt.show()
===========iteration:0=========== SGD:2.36758287946 Momentum:2.35676778844 AdaGrad:2.28180587728 Adam:2.21227121869 ===========iteration:100=========== SGD:1.38855685254 Momentum:0.461870640267 AdaGrad:0.211115915906 Adam:0.315237686933 ===========iteration:200=========== SGD:0.742795168146 Momentum:0.357434278872 AdaGrad:0.202908190167 Adam:0.298717102536 ===========iteration:300=========== SGD:0.510175628236 Momentum:0.243544097013 AdaGrad:0.115050226278 Adam:0.130153141426 ===========iteration:400=========== SGD:0.331651212545 Momentum:0.168092310964 AdaGrad:0.0878949825684 Adam:0.139117592448 ===========iteration:500=========== SGD:0.334500266159 Momentum:0.141800168398 AdaGrad:0.0582819956473 Adam:0.0924050059469 ===========iteration:600=========== SGD:0.374285241135 Momentum:0.18735173572 AdaGrad:0.0843068950526 Adam:0.113304021182 ===========iteration:700=========== SGD:0.392008858143 Momentum:0.25524217641 AdaGrad:0.0809886331434 Adam:0.149813645512 ===========iteration:800=========== SGD:0.359701837535 Momentum:0.158389731597 AdaGrad:0.0627873663215 Adam:0.0977498236121 ===========iteration:900=========== SGD:0.355107337708 Momentum:0.145141113182 AdaGrad:0.0792770906671 Adam:0.0922878877236 ===========iteration:1000=========== SGD:0.299876043291 Momentum:0.101840715714 AdaGrad:0.0507441246263 Adam:0.0556904090103 ===========iteration:1100=========== SGD:0.155368625734 Momentum:0.10524641913 AdaGrad:0.0575273175541 Adam:0.0409443183873 ===========iteration:1200=========== SGD:0.404946580285 Momentum:0.0887576738966 AdaGrad:0.0483943120411 Adam:0.0516832823355 ===========iteration:1300=========== SGD:0.161848101422 Momentum:0.0572392158164 AdaGrad:0.016754523541 Adam:0.0439077479572 ===========iteration:1400=========== SGD:0.306202349484 Momentum:0.107992380205 AdaGrad:0.0547485069623 Adam:0.0654920911761 ===========iteration:1500=========== SGD:0.259560793717 Momentum:0.111217613058 AdaGrad:0.0409710197983 Adam:0.043987273782 ===========iteration:1600=========== SGD:0.242831500432 Momentum:0.0525726596907 AdaGrad:0.0418354154041 Adam:0.0456010996183 ===========iteration:1700=========== SGD:0.352313053763 Momentum:0.0850253266804 AdaGrad:0.0285429751321 Adam:0.0450659573381 ===========iteration:1800=========== SGD:0.185902877271 Momentum:0.0916553598803 AdaGrad:0.0270357861851 Adam:0.0126373417587 ===========iteration:1900=========== SGD:0.149110393347 Momentum:0.0417008407212 AdaGrad:0.0286249824354 Adam:0.0254952933551
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def ReLU(x):
return np.maximum(0, x)
def tanh(x):
return np.tanh(x)
input_data = np.random.randn(1000, 100)  # 1000 data points
node_num = 100  # number of nodes (neurons) in each hidden layer
hidden_layer_size = 5  # five hidden layers
activations = {}  # store the activation results here
x = input_data
for i in range(hidden_layer_size):
if i != 0:
x = activations[i-1]
    # experiment with different initial weight values!
w = np.random.randn(node_num, node_num) * 1
# w = np.random.randn(node_num, node_num) * 0.01
# w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)
# w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)
a = np.dot(x, w)
    # also try changing the activation function!
z = sigmoid(a)
# z = ReLU(a)
# z = tanh(a)
activations[i] = z
# plot the histograms
for i, a in activations.items():
plt.subplot(1, len(activations), i+1)
plt.title(str(i+1) + "-layer")
if i != 0: plt.yticks([], [])
# plt.xlim(0.1, 1)
# plt.ylim(0, 7000)
plt.hist(a.flatten(), 30, range=(0,1))
plt.show()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def ReLU(x):
return np.maximum(0, x)
def tanh(x):
return np.tanh(x)
input_data = np.random.randn(1000, 100)  # 1000 data points
node_num = 100  # number of nodes (neurons) in each hidden layer
hidden_layer_size = 5  # five hidden layers
activations = {}  # store the activation results here
x = input_data
for i in range(hidden_layer_size):
if i != 0:
x = activations[i-1]
    # experiment with different initial weight values!
# w = np.random.randn(node_num, node_num) * 1
w = np.random.randn(node_num, node_num) * 0.01
# w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)
# w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)
a = np.dot(x, w)
    # also try changing the activation function!
z = sigmoid(a)
# z = ReLU(a)
# z = tanh(a)
activations[i] = z
# plot the histograms
for i, a in activations.items():
plt.subplot(1, len(activations), i+1)
plt.title(str(i+1) + "-layer")
if i != 0: plt.yticks([], [])
# plt.xlim(0.1, 1)
# plt.ylim(0, 7000)
plt.hist(a.flatten(), 30, range=(0,1))
plt.show()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def ReLU(x):
return np.maximum(0, x)
def tanh(x):
return np.tanh(x)
input_data = np.random.randn(1000, 100)  # 1000 data points
node_num = 100  # number of nodes (neurons) in each hidden layer
hidden_layer_size = 5  # five hidden layers
activations = {}  # store the activation results here
x = input_data
for i in range(hidden_layer_size):
if i != 0:
x = activations[i-1]
    # experiment with different initial weight values!
# w = np.random.randn(node_num, node_num) * 1
# w = np.random.randn(node_num, node_num) * 0.01
w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)
# w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)
a = np.dot(x, w)
    # also try changing the activation function!
z = sigmoid(a)
# z = ReLU(a)
# z = tanh(a)
activations[i] = z
# plot the histograms
for i, a in activations.items():
plt.subplot(1, len(activations), i+1)
plt.title(str(i+1) + "-layer")
if i != 0: plt.yticks([], [])
# plt.xlim(0.1, 1)
# plt.ylim(0, 7000)
plt.hist(a.flatten(), 30, range=(0,1))
plt.show()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def ReLU(x):
return np.maximum(0, x)
def tanh(x):
return np.tanh(x)
input_data = np.random.randn(1000, 100)  # 1000 data points
node_num = 100  # number of nodes (neurons) in each hidden layer
hidden_layer_size = 5  # five hidden layers
activations = {}  # store the activation results here
x = input_data
for i in range(hidden_layer_size):
if i != 0:
x = activations[i-1]
    # experiment with different initial weight values!
# w = np.random.randn(node_num, node_num) * 1
# w = np.random.randn(node_num, node_num) * 0.01
# w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)
w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)
a = np.dot(x, w)
    # also try changing the activation function!
z = sigmoid(a)
# z = ReLU(a)
# z = tanh(a)
activations[i] = z
# plot the histograms
for i, a in activations.items():
plt.subplot(1, len(activations), i+1)
plt.title(str(i+1) + "-layer")
if i != 0: plt.yticks([], [])
# plt.xlim(0.1, 1)
# plt.ylim(0, 7000)
plt.hist(a.flatten(), 30, range=(0,1))
plt.show()
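Taken together, the four histogram experiments show why the weight scale matters for a sigmoid network: with std = 1 the activations saturate near 0 and 1 (vanishing gradients), with std = 0.01 they collapse around 0.5 (poor expressiveness), while the Xavier initialization sqrt(1/n) and the He initialization sqrt(2/n) keep the activations reasonably spread out across all five layers.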
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
train_size = x_train.shape[0]
batch_size = 128
max_iterations = 2000
# 1: experiment setup ==========
weight_init_types = {'std=0.01': 0.01, 'Xavier': 'sigmoid', 'He': 'relu'}
optimizer = SGD(lr=0.01)
networks = {}
train_loss = {}
for key, weight_type in weight_init_types.items():
networks[key] = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100],
output_size=10, weight_init_std=weight_type)
train_loss[key] = []
# 2: start training ==========
for i in range(max_iterations):
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
for key in weight_init_types.keys():
grads = networks[key].gradient(x_batch, t_batch)
optimizer.update(networks[key].params, grads)
loss = networks[key].loss(x_batch, t_batch)
train_loss[key].append(loss)
if i % 100 == 0:
print("===========" + "iteration:" + str(i) + "===========")
for key in weight_init_types.keys():
loss = networks[key].loss(x_batch, t_batch)
print(key + ":" + str(loss))
# 3: draw the graph ==========
markers = {'std=0.01': 'o', 'Xavier': 's', 'He': 'D'}
x = np.arange(max_iterations)
for key in weight_init_types.keys():
plt.plot(x, smooth_curve(train_loss[key]), marker=markers[key], markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
plt.ylim(0, 2.5)
plt.legend()
plt.show()
===========iteration:0=========== std=0.01:2.30247970912 Xavier:2.31785863418 He:2.39733102841 ===========iteration:100=========== std=0.01:2.30265410099 Xavier:2.2260916212 He:1.56083317506 ===========iteration:200=========== std=0.01:2.30127066276 Xavier:2.02348908912 He:0.877107765076 ===========iteration:300=========== std=0.01:2.30014016663 Xavier:1.47061707486 He:0.503355627607 ===========iteration:400=========== std=0.01:2.30317732165 Xavier:1.01448959406 He:0.46532958735 ===========iteration:500=========== std=0.01:2.2999843498 Xavier:0.629062750472 He:0.356675670914 ===========iteration:600=========== std=0.01:2.30476419675 Xavier:0.56507387768 He:0.317460311184 ===========iteration:700=========== std=0.01:2.30169904812 Xavier:0.423849159431 He:0.237480920752 ===========iteration:800=========== std=0.01:2.30462699762 Xavier:0.482986535346 He:0.315248057758 ===========iteration:900=========== std=0.01:2.2939850946 Xavier:0.335377873382 He:0.243736603279 ===========iteration:1000=========== std=0.01:2.3019330985 Xavier:0.331399077165 He:0.251106529254 ===========iteration:1100=========== std=0.01:2.30603321804 Xavier:0.454444741848 He:0.358556085083 ===========iteration:1200=========== std=0.01:2.30795123028 Xavier:0.242326469941 He:0.168018775658 ===========iteration:1300=========== std=0.01:2.29942535812 Xavier:0.383945484494 He:0.292905878093 ===========iteration:1400=========== std=0.01:2.30093330509 Xavier:0.339810814097 He:0.274116232034 ===========iteration:1500=========== std=0.01:2.30644293511 Xavier:0.282457581235 He:0.198203544087 ===========iteration:1600=========== std=0.01:2.29968665162 Xavier:0.305759304047 He:0.192240097938 ===========iteration:1700=========== std=0.01:2.30052529681 Xavier:0.359125766982 He:0.321194203252 ===========iteration:1800=========== std=0.01:2.30342191874 Xavier:0.318089330676 He:0.218468892217 ===========iteration:1900=========== std=0.01:2.29904458402 Xavier:0.17868654387 He:0.103663457598
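The loss values above show that the std=0.01 network stays stuck around 2.3 (it never starts learning), while the Xavier- and He-initialized networks both train, with He reaching the lowest loss throughout; this matches the activation histograms from the previous experiments.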
class MultiLayerNetExtend:
"""拡張版の全結合による多層ニューラルネットワーク
Weiht Decay、Dropout、Batch Normalizationの機能を持つ
Parameters
----------
input_size : 入力サイズ(MNISTの場合は784)
hidden_size_list : 隠れ層のニューロンの数のリスト(e.g. [100, 100, 100])
output_size : 出力サイズ(MNISTの場合は10)
activation : 'relu' or 'sigmoid'
weight_init_std : 重みの標準偏差を指定(e.g. 0.01)
'relu'または'he'を指定した場合は「Heの初期値」を設定
'sigmoid'または'xavier'を指定した場合は「Xavierの初期値」を設定
weight_decay_lambda : Weight Decay(L2ノルム)の強さ
use_dropout: Dropoutを使用するかどうか
dropout_ration : Dropoutの割り合い
use_batchNorm: Batch Normalizationを使用するかどうか
"""
def __init__(self, input_size, hidden_size_list, output_size,
activation='relu', weight_init_std='relu', weight_decay_lambda=0,
use_dropout = False, dropout_ration = 0.5, use_batchnorm=False):
self.input_size = input_size
self.output_size = output_size
self.hidden_size_list = hidden_size_list
self.hidden_layer_num = len(hidden_size_list)
self.use_dropout = use_dropout
self.weight_decay_lambda = weight_decay_lambda
self.use_batchnorm = use_batchnorm
self.params = {}
# Initialize the weights
self.__init_weight(weight_init_std)
# Build the layers
activation_layer = {'sigmoid': Sigmoid, 'relu': Relu}
self.layers = OrderedDict()
for idx in range(1, self.hidden_layer_num+1):
self.layers['Affine' + str(idx)] = Affine(self.params['W' + str(idx)],
self.params['b' + str(idx)])
if self.use_batchnorm:
self.params['gamma' + str(idx)] = np.ones(hidden_size_list[idx-1])
self.params['beta' + str(idx)] = np.zeros(hidden_size_list[idx-1])
self.layers['BatchNorm' + str(idx)] = BatchNormalization(self.params['gamma' + str(idx)], self.params['beta' + str(idx)])
self.layers['Activation_function' + str(idx)] = activation_layer[activation]()
if self.use_dropout:
self.layers['Dropout' + str(idx)] = Dropout(dropout_ration)
idx = self.hidden_layer_num + 1
self.layers['Affine' + str(idx)] = Affine(self.params['W' + str(idx)], self.params['b' + str(idx)])
self.last_layer = SoftmaxWithLoss()
def __init_weight(self, weight_init_std):
"""重みの初期値設定
Parameters
----------
weight_init_std : 重みの標準偏差を指定(e.g. 0.01)
'relu'または'he'を指定した場合は「Heの初期値」を設定
'sigmoid'または'xavier'を指定した場合は「Xavierの初期値」を設定
"""
all_size_list = [self.input_size] + self.hidden_size_list + [self.output_size]
for idx in range(1, len(all_size_list)):
scale = weight_init_std
if str(weight_init_std).lower() in ('relu', 'he'):
scale = np.sqrt(2.0 / all_size_list[idx - 1]) # recommended initial value when using ReLU (He initialization)
elif str(weight_init_std).lower() in ('sigmoid', 'xavier'):
scale = np.sqrt(1.0 / all_size_list[idx - 1]) # recommended initial value when using sigmoid (Xavier initialization)
self.params['W' + str(idx)] = scale * np.random.randn(all_size_list[idx-1], all_size_list[idx])
self.params['b' + str(idx)] = np.zeros(all_size_list[idx])
def predict(self, x, train_flg=False):
for key, layer in self.layers.items():
if "Dropout" in key or "BatchNorm" in key:
x = layer.forward(x, train_flg)
else:
x = layer.forward(x)
return x
def loss(self, x, t, train_flg=False):
"""損失関数を求める
引数のxは入力データ、tは教師ラベル
"""
y = self.predict(x, train_flg)
weight_decay = 0
for idx in range(1, self.hidden_layer_num + 2):
W = self.params['W' + str(idx)]
weight_decay += 0.5 * self.weight_decay_lambda * np.sum(W**2)
return self.last_layer.forward(y, t) + weight_decay
def accuracy(self, X, T):
Y = self.predict(X, train_flg=False)
Y = np.argmax(Y, axis=1)
if T.ndim != 1 : T = np.argmax(T, axis=1)
accuracy = np.sum(Y == T) / float(X.shape[0])
return accuracy
def numerical_gradient(self, X, T):
"""勾配を求める(数値微分)
Parameters
----------
X : 入力データ
T : 教師ラベル
Returns
-------
各層の勾配を持ったディクショナリ変数
grads['W1']、grads['W2']、...は各層の重み
grads['b1']、grads['b2']、...は各層のバイアス
"""
loss_W = lambda W: self.loss(X, T, train_flg=True)
grads = {}
for idx in range(1, self.hidden_layer_num+2):
grads['W' + str(idx)] = numerical_gradient(loss_W, self.params['W' + str(idx)])
grads['b' + str(idx)] = numerical_gradient(loss_W, self.params['b' + str(idx)])
if self.use_batchnorm and idx != self.hidden_layer_num+1:
grads['gamma' + str(idx)] = numerical_gradient(loss_W, self.params['gamma' + str(idx)])
grads['beta' + str(idx)] = numerical_gradient(loss_W, self.params['beta' + str(idx)])
return grads
def gradient(self, x, t):
# forward
self.loss(x, t, train_flg=True)
# backward
dout = 1
dout = self.last_layer.backward(dout)
layers = list(self.layers.values())
layers.reverse()
for layer in layers:
dout = layer.backward(dout)
# collect the gradients
grads = {}
for idx in range(1, self.hidden_layer_num+2):
grads['W' + str(idx)] = self.layers['Affine' + str(idx)].dW + self.weight_decay_lambda * self.params['W' + str(idx)]
grads['b' + str(idx)] = self.layers['Affine' + str(idx)].db
if self.use_batchnorm and idx != self.hidden_layer_num+1:
grads['gamma' + str(idx)] = self.layers['BatchNorm' + str(idx)].dgamma
grads['beta' + str(idx)] = self.layers['BatchNorm' + str(idx)].dbeta
return grads
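As a sanity check of the backpropagation implemented in gradient() above (including the gamma/beta terms when Batch Normalization is enabled), the analytic gradients can be compared against numerical_gradient() on a tiny network and batch. This is only a sketch; it assumes the layer classes (Affine, Relu, SoftmaxWithLoss, BatchNormalization, Dropout) and the numerical_gradient helper used by the class are already defined in the notebook:
import numpy as np

np.random.seed(0)
# tiny network so the numerical gradient stays cheap
net = MultiLayerNetExtend(input_size=10, hidden_size_list=[5, 5], output_size=3,
                          use_batchnorm=True)
x_check = np.random.rand(4, 10)                   # batch of 4 samples, 10 features
t_check = np.eye(3)[np.random.randint(0, 3, 4)]   # one-hot labels
grad_numerical = net.numerical_gradient(x_check, t_check)
grad_backprop = net.gradient(x_check, t_check)
for key in grad_numerical.keys():
    diff = np.mean(np.abs(grad_numerical[key] - grad_backprop[key]))
    print(key + ": " + str(diff))                 # each difference should be close to 0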
class BatchNormalization:
"""
http://arxiv.org/abs/1502.03167
"""
def __init__(self, gamma, beta, momentum=0.9, running_mean=None, running_var=None):
self.gamma = gamma
self.beta = beta
self.momentum = momentum
self.input_shape = None # 4D for conv layers, 2D for fully connected layers
# mean and variance used at test time
self.running_mean = running_mean
self.running_var = running_var
# intermediate data used during backward
self.batch_size = None
self.xc = None
self.std = None
self.dgamma = None
self.dbeta = None
def forward(self, x, train_flg=True):
self.input_shape = x.shape
if x.ndim != 2:
N, C, H, W = x.shape
x = x.reshape(N, -1)
out = self.__forward(x, train_flg)
return out.reshape(*self.input_shape)
def __forward(self, x, train_flg):
if self.running_mean is None:
N, D = x.shape
self.running_mean = np.zeros(D)
self.running_var = np.zeros(D)
if train_flg:
mu = x.mean(axis=0)
xc = x - mu
var = np.mean(xc**2, axis=0)
std = np.sqrt(var + 10e-7)
xn = xc / std
self.batch_size = x.shape[0]
self.xc = xc
self.xn = xn
self.std = std
self.running_mean = self.momentum * self.running_mean + (1-self.momentum) * mu
self.running_var = self.momentum * self.running_var + (1-self.momentum) * var
else:
xc = x - self.running_mean
xn = xc / ((np.sqrt(self.running_var + 10e-7)))
out = self.gamma * xn + self.beta
return out
def backward(self, dout):
if dout.ndim != 2:
N, C, H, W = dout.shape
dout = dout.reshape(N, -1)
dx = self.__backward(dout)
dx = dx.reshape(*self.input_shape)
return dx
def __backward(self, dout):
dbeta = dout.sum(axis=0)
dgamma = np.sum(self.xn * dout, axis=0)
dxn = self.gamma * dout
dxc = dxn / self.std
dstd = -np.sum((dxn * self.xc) / (self.std * self.std), axis=0)
dvar = 0.5 * dstd / self.std
dxc += (2.0 / self.batch_size) * self.xc * dvar
dmu = np.sum(dxc, axis=0)
dx = dxc - dmu / self.batch_size
self.dgamma = dgamma
self.dbeta = dbeta
return dx
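A small standalone check of the BatchNormalization layer above: with gamma = 1 and beta = 0, a training-mode forward pass should map every feature of the batch to roughly zero mean and unit variance, while a test-mode pass reuses the running statistics instead. A minimal sketch, assuming only NumPy and the class as defined here:
import numpy as np

D = 5
bn = BatchNormalization(gamma=np.ones(D), beta=np.zeros(D))
x_demo = np.random.randn(16, D) * 3.0 + 7.0     # batch with non-zero mean and non-unit variance
out_train = bn.forward(x_demo, train_flg=True)
print(np.round(out_train.mean(axis=0), 6))      # approximately 0 for every feature
print(np.round(out_train.std(axis=0), 6))       # approximately 1 for every feature
out_test = bn.forward(x_demo, train_flg=False)  # uses running_mean / running_var instead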
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# reduce the training data
x_train = x_train[:1000]
t_train = t_train[:1000]
max_epochs = 20
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.01
def __train(weight_init_std):
bn_network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
weight_init_std=weight_init_std, use_batchnorm=True)
network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
weight_init_std=weight_init_std)
optimizer = SGD(lr=learning_rate)
train_acc_list = []
bn_train_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)
epoch_cnt = 0
for i in range(1000000000):
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
for _network in (bn_network, network):
grads = _network.gradient(x_batch, t_batch)
optimizer.update(_network.params, grads)
if i % iter_per_epoch == 0:
train_acc = network.accuracy(x_train, t_train)
bn_train_acc = bn_network.accuracy(x_train, t_train)
train_acc_list.append(train_acc)
bn_train_acc_list.append(bn_train_acc)
print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - " + str(bn_train_acc))
epoch_cnt += 1
if epoch_cnt >= max_epochs:
break
return train_acc_list, bn_train_acc_list
# 3. Plot the results ==========
weight_scale_list = np.logspace(0, -4, num=16)
x = np.arange(max_epochs)
for i, w in enumerate(weight_scale_list):
print( "============== " + str(i+1) + "/16" + " ==============")
train_acc_list, bn_train_acc_list = __train(w)
plt.subplot(4,4,i+1)
plt.title("W:" + str(w))
if i == 15:
plt.plot(x, bn_train_acc_list, label='Batch Normalization', markevery=2)
plt.plot(x, train_acc_list, linestyle = "--", label='Normal(without BatchNorm)', markevery=2)
else:
plt.plot(x, bn_train_acc_list, markevery=2)
plt.plot(x, train_acc_list, linestyle="--", markevery=2)
plt.ylim(0, 1.0)
if i % 4:
plt.yticks([])
else:
plt.ylabel("accuracy")
if i < 12:
plt.xticks([])
else:
plt.xlabel("epochs")
plt.legend(loc='lower right')
plt.show()
============== 1/16 ==============
epoch:0 | 0.116 - 0.113
.:56: RuntimeWarning: divide by zero encountered in log
.:6: RuntimeWarning: invalid value encountered in less_equal
.:91: RuntimeWarning: overflow encountered in square
.:91: RuntimeWarning: invalid value encountered in double_scalars
epoch:1 | 0.097 - 0.099 epoch:2 | 0.097 - 0.126 epoch:3 | 0.097 - 0.151 epoch:4 | 0.097 - 0.17 epoch:5 | 0.097 - 0.182 epoch:6 | 0.097 - 0.198 epoch:7 | 0.097 - 0.214 epoch:8 | 0.097 - 0.236 epoch:9 | 0.097 - 0.251 epoch:10 | 0.097 - 0.274 epoch:11 | 0.097 - 0.298 epoch:12 | 0.097 - 0.31 epoch:13 | 0.097 - 0.323 epoch:14 | 0.097 - 0.345 epoch:15 | 0.097 - 0.353 epoch:16 | 0.097 - 0.36 epoch:17 | 0.097 - 0.366 epoch:18 | 0.097 - 0.386 epoch:19 | 0.097 - 0.399 ============== 2/16 ==============
/Users/kot/miniconda2/envs/python3/lib/python3.6/site-packages/matplotlib/axes/_axes.py:545: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots. warnings.warn("No labelled objects found. "
epoch:0 | 0.099 - 0.11 epoch:1 | 0.097 - 0.134 epoch:2 | 0.097 - 0.169 epoch:3 | 0.097 - 0.172 epoch:4 | 0.097 - 0.197 epoch:5 | 0.097 - 0.223 epoch:6 | 0.097 - 0.237 epoch:7 | 0.097 - 0.252 epoch:8 | 0.097 - 0.269 epoch:9 | 0.097 - 0.293 epoch:10 | 0.097 - 0.317 epoch:11 | 0.097 - 0.34 epoch:12 | 0.097 - 0.357 epoch:13 | 0.097 - 0.379 epoch:14 | 0.097 - 0.398 epoch:15 | 0.097 - 0.419 epoch:16 | 0.097 - 0.449 epoch:17 | 0.097 - 0.45 epoch:18 | 0.097 - 0.46 epoch:19 | 0.097 - 0.479 ============== 3/16 ============== epoch:0 | 0.114 - 0.093 epoch:1 | 0.336 - 0.123 epoch:2 | 0.496 - 0.147 epoch:3 | 0.555 - 0.18 epoch:4 | 0.653 - 0.225 epoch:5 | 0.71 - 0.249 epoch:6 | 0.77 - 0.287 epoch:7 | 0.783 - 0.32 epoch:8 | 0.824 - 0.363 epoch:9 | 0.856 - 0.39 epoch:10 | 0.875 - 0.426 epoch:11 | 0.903 - 0.45 epoch:12 | 0.921 - 0.476 epoch:13 | 0.927 - 0.508 epoch:14 | 0.937 - 0.528 epoch:15 | 0.953 - 0.555 epoch:16 | 0.967 - 0.578 epoch:17 | 0.972 - 0.597 epoch:18 | 0.979 - 0.604 epoch:19 | 0.981 - 0.628 ============== 4/16 ============== epoch:0 | 0.117 - 0.095 epoch:1 | 0.238 - 0.097 epoch:2 | 0.375 - 0.16 epoch:3 | 0.464 - 0.238 epoch:4 | 0.569 - 0.305 epoch:5 | 0.606 - 0.365 epoch:6 | 0.665 - 0.434 epoch:7 | 0.709 - 0.497 epoch:8 | 0.728 - 0.542 epoch:9 | 0.759 - 0.578 epoch:10 | 0.779 - 0.61 epoch:11 | 0.793 - 0.634 epoch:12 | 0.818 - 0.663 epoch:13 | 0.825 - 0.691 epoch:14 | 0.843 - 0.716 epoch:15 | 0.854 - 0.728 epoch:16 | 0.861 - 0.747 epoch:17 | 0.859 - 0.759 epoch:18 | 0.865 - 0.786 epoch:19 | 0.877 - 0.796 ============== 5/16 ============== epoch:0 | 0.077 - 0.113 epoch:1 | 0.079 - 0.155 epoch:2 | 0.095 - 0.32 epoch:3 | 0.117 - 0.449 epoch:4 | 0.138 - 0.529 epoch:5 | 0.151 - 0.602 epoch:6 | 0.165 - 0.657 epoch:7 | 0.18 - 0.712 epoch:8 | 0.196 - 0.743 epoch:9 | 0.22 - 0.769 epoch:10 | 0.238 - 0.789 epoch:11 | 0.255 - 0.811 epoch:12 | 0.273 - 0.825 epoch:13 | 0.291 - 0.838 epoch:14 | 0.301 - 0.857 epoch:15 | 0.317 - 0.864 epoch:16 | 0.329 - 0.872 epoch:17 | 0.336 - 0.877 epoch:18 | 0.357 - 0.886 epoch:19 | 0.359 - 0.898 ============== 6/16 ============== epoch:0 | 0.133 - 0.089 epoch:1 | 0.111 - 0.203 epoch:2 | 0.118 - 0.374 epoch:3 | 0.117 - 0.542 epoch:4 | 0.117 - 0.684 epoch:5 | 0.117 - 0.747 epoch:6 | 0.118 - 0.79 epoch:7 | 0.117 - 0.813 epoch:8 | 0.104 - 0.839 epoch:9 | 0.119 - 0.847 epoch:10 | 0.1 - 0.865 epoch:11 | 0.116 - 0.878 epoch:12 | 0.118 - 0.892 epoch:13 | 0.124 - 0.905 epoch:14 | 0.115 - 0.919 epoch:15 | 0.115 - 0.922 epoch:16 | 0.118 - 0.929 epoch:17 | 0.131 - 0.937 epoch:18 | 0.12 - 0.941 epoch:19 | 0.117 - 0.945 ============== 7/16 ============== epoch:0 | 0.099 - 0.113 epoch:1 | 0.1 - 0.302 epoch:2 | 0.116 - 0.566 epoch:3 | 0.116 - 0.715 epoch:4 | 0.116 - 0.778 epoch:5 | 0.116 - 0.808 epoch:6 | 0.116 - 0.838 epoch:7 | 0.116 - 0.867 epoch:8 | 0.116 - 0.884 epoch:9 | 0.116 - 0.906 epoch:10 | 0.116 - 0.925 epoch:11 | 0.116 - 0.946 epoch:12 | 0.116 - 0.953 epoch:13 | 0.116 - 0.971 epoch:14 | 0.116 - 0.978 epoch:15 | 0.116 - 0.983 epoch:16 | 0.116 - 0.985 epoch:17 | 0.116 - 0.991 epoch:18 | 0.116 - 0.994 epoch:19 | 0.117 - 0.994 ============== 8/16 ============== epoch:0 | 0.117 - 0.122 epoch:1 | 0.117 - 0.561 epoch:2 | 0.099 - 0.755 epoch:3 | 0.117 - 0.814 epoch:4 | 0.099 - 0.849 epoch:5 | 0.117 - 0.887 epoch:6 | 0.117 - 0.903 epoch:7 | 0.117 - 0.917 epoch:8 | 0.117 - 0.946 epoch:9 | 0.117 - 0.968 epoch:10 | 0.117 - 0.977 epoch:11 | 0.117 - 0.983 epoch:12 | 0.117 - 0.989 epoch:13 | 0.116 - 0.994 epoch:14 | 0.116 - 0.996 epoch:15 | 0.116 - 0.995 epoch:16 | 0.116 - 0.998 epoch:17 | 
0.116 - 0.997 epoch:18 | 0.116 - 0.999 epoch:19 | 0.116 - 0.999 ============== 9/16 ============== epoch:0 | 0.094 - 0.129 epoch:1 | 0.117 - 0.416 epoch:2 | 0.117 - 0.658 epoch:3 | 0.117 - 0.867 epoch:4 | 0.117 - 0.921 epoch:5 | 0.117 - 0.946 epoch:6 | 0.117 - 0.965 epoch:7 | 0.117 - 0.973 epoch:8 | 0.117 - 0.986 epoch:9 | 0.117 - 0.992 epoch:10 | 0.117 - 0.995 epoch:11 | 0.117 - 0.997 epoch:12 | 0.117 - 0.996 epoch:13 | 0.117 - 0.998 epoch:14 | 0.117 - 0.998 epoch:15 | 0.117 - 0.998 epoch:16 | 0.117 - 0.999 epoch:17 | 0.117 - 0.999 epoch:18 | 0.117 - 0.999 epoch:19 | 0.117 - 0.999 ============== 10/16 ============== epoch:0 | 0.117 - 0.156 epoch:1 | 0.116 - 0.581 epoch:2 | 0.117 - 0.682 epoch:3 | 0.117 - 0.797 epoch:4 | 0.117 - 0.868 epoch:5 | 0.117 - 0.912 epoch:6 | 0.117 - 0.845 epoch:7 | 0.117 - 0.864 epoch:8 | 0.117 - 0.882 epoch:9 | 0.117 - 0.889 epoch:10 | 0.117 - 0.902 epoch:11 | 0.117 - 0.905 epoch:12 | 0.117 - 0.936 epoch:13 | 0.117 - 0.969 epoch:14 | 0.117 - 0.984 epoch:15 | 0.117 - 0.989 epoch:16 | 0.117 - 0.989 epoch:17 | 0.117 - 0.977 epoch:18 | 0.117 - 0.995 epoch:19 | 0.117 - 0.995 ============== 11/16 ============== epoch:0 | 0.097 - 0.13 epoch:1 | 0.116 - 0.567 epoch:2 | 0.116 - 0.778 epoch:3 | 0.116 - 0.711 epoch:4 | 0.116 - 0.837 epoch:5 | 0.116 - 0.833 epoch:6 | 0.116 - 0.881 epoch:7 | 0.116 - 0.824 epoch:8 | 0.116 - 0.877 epoch:9 | 0.116 - 0.908 epoch:10 | 0.116 - 0.921 epoch:11 | 0.116 - 0.98 epoch:12 | 0.116 - 0.989 epoch:13 | 0.116 - 0.99 epoch:14 | 0.116 - 0.99 epoch:15 | 0.116 - 0.983 epoch:16 | 0.116 - 0.992 epoch:17 | 0.116 - 0.992 epoch:18 | 0.116 - 0.994 epoch:19 | 0.116 - 0.995 ============== 12/16 ============== epoch:0 | 0.094 - 0.122 epoch:1 | 0.117 - 0.534 epoch:2 | 0.117 - 0.632 epoch:3 | 0.117 - 0.656 epoch:4 | 0.117 - 0.691 epoch:5 | 0.117 - 0.791 epoch:6 | 0.117 - 0.766 epoch:7 | 0.117 - 0.835 epoch:8 | 0.117 - 0.834 epoch:9 | 0.117 - 0.784 epoch:10 | 0.117 - 0.875 epoch:11 | 0.117 - 0.826 epoch:12 | 0.117 - 0.873 epoch:13 | 0.117 - 0.894 epoch:14 | 0.117 - 0.897 epoch:15 | 0.117 - 0.882 epoch:16 | 0.117 - 0.895 epoch:17 | 0.117 - 0.899 epoch:18 | 0.117 - 0.901 epoch:19 | 0.117 - 0.902 ============== 13/16 ============== epoch:0 | 0.087 - 0.156 epoch:1 | 0.117 - 0.476 epoch:2 | 0.117 - 0.505 epoch:3 | 0.105 - 0.372 epoch:4 | 0.105 - 0.57 epoch:5 | 0.117 - 0.56 epoch:6 | 0.117 - 0.57 epoch:7 | 0.117 - 0.589 epoch:8 | 0.116 - 0.607 epoch:9 | 0.116 - 0.605 epoch:10 | 0.116 - 0.597 epoch:11 | 0.116 - 0.618 epoch:12 | 0.116 - 0.676 epoch:13 | 0.117 - 0.701 epoch:14 | 0.116 - 0.692 epoch:15 | 0.117 - 0.687 epoch:16 | 0.117 - 0.704 epoch:17 | 0.117 - 0.695 epoch:18 | 0.117 - 0.691 epoch:19 | 0.117 - 0.7 ============== 14/16 ============== epoch:0 | 0.116 - 0.14 epoch:1 | 0.116 - 0.391 epoch:2 | 0.116 - 0.433 epoch:3 | 0.116 - 0.503 epoch:4 | 0.116 - 0.564 epoch:5 | 0.116 - 0.561 epoch:6 | 0.116 - 0.542 epoch:7 | 0.116 - 0.587 epoch:8 | 0.116 - 0.594 epoch:9 | 0.116 - 0.597 epoch:10 | 0.116 - 0.599 epoch:11 | 0.116 - 0.6 epoch:12 | 0.116 - 0.608 epoch:13 | 0.116 - 0.605 epoch:14 | 0.116 - 0.608 epoch:15 | 0.116 - 0.603 epoch:16 | 0.116 - 0.601 epoch:17 | 0.116 - 0.609 epoch:18 | 0.116 - 0.612 epoch:19 | 0.116 - 0.61 ============== 15/16 ============== epoch:0 | 0.116 - 0.097 epoch:1 | 0.116 - 0.317 epoch:2 | 0.116 - 0.325 epoch:3 | 0.116 - 0.324 epoch:4 | 0.116 - 0.325 epoch:5 | 0.116 - 0.358 epoch:6 | 0.116 - 0.408 epoch:7 | 0.116 - 0.417 epoch:8 | 0.116 - 0.424 epoch:9 | 0.116 - 0.422 epoch:10 | 0.116 - 0.401 epoch:11 | 0.116 - 0.432 epoch:12 | 0.116 - 
0.429 epoch:13 | 0.116 - 0.374 epoch:14 | 0.116 - 0.426 epoch:15 | 0.116 - 0.432 epoch:16 | 0.116 - 0.432 epoch:17 | 0.116 - 0.432 epoch:18 | 0.116 - 0.429 epoch:19 | 0.116 - 0.428 ============== 16/16 ============== epoch:0 | 0.097 - 0.097 epoch:1 | 0.116 - 0.159 epoch:2 | 0.117 - 0.412 epoch:3 | 0.116 - 0.401 epoch:4 | 0.116 - 0.413 epoch:5 | 0.116 - 0.417 epoch:6 | 0.116 - 0.414 epoch:7 | 0.116 - 0.416 epoch:8 | 0.116 - 0.428 epoch:9 | 0.116 - 0.432 epoch:10 | 0.116 - 0.432 epoch:11 | 0.116 - 0.432 epoch:12 | 0.116 - 0.519 epoch:13 | 0.116 - 0.519 epoch:14 | 0.116 - 0.517 epoch:15 | 0.116 - 0.526 epoch:16 | 0.116 - 0.523 epoch:17 | 0.116 - 0.527 epoch:18 | 0.116 - 0.53 epoch:19 | 0.117 - 0.529
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# reduce the training data to reproduce overfitting
x_train = x_train[:300]
t_train = t_train[:300]
# weight decay setting =======================
#weight_decay_lambda = 0 # use this to disable weight decay
weight_decay_lambda = 0.1
# ====================================================
network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100], output_size=10,
weight_decay_lambda=weight_decay_lambda)
optimizer = SGD(lr=0.01)
max_epochs = 201
train_size = x_train.shape[0]
batch_size = 100
train_loss_list = []
train_acc_list = []
test_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)
epoch_cnt = 0
for i in range(1000000000):
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
grads = network.gradient(x_batch, t_batch)
optimizer.update(network.params, grads)
if i % iter_per_epoch == 0:
train_acc = network.accuracy(x_train, t_train)
test_acc = network.accuracy(x_test, t_test)
train_acc_list.append(train_acc)
test_acc_list.append(test_acc)
print("epoch:" + str(epoch_cnt) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc))
epoch_cnt += 1
if epoch_cnt >= max_epochs:
break
# 3. Plot the results ==========
markers = {'train': 'o', 'test': 's'}
x = np.arange(max_epochs)
plt.plot(x, train_acc_list, marker='o', label='train', markevery=10)
plt.plot(x, test_acc_list, marker='s', label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
epoch:0, train acc:0.106666666667, test acc:0.112 epoch:1, train acc:0.113333333333, test acc:0.1232 epoch:2, train acc:0.15, test acc:0.1383 epoch:3, train acc:0.163333333333, test acc:0.1517 epoch:4, train acc:0.193333333333, test acc:0.166 epoch:5, train acc:0.236666666667, test acc:0.1995 epoch:6, train acc:0.263333333333, test acc:0.2215 epoch:7, train acc:0.313333333333, test acc:0.2465 epoch:8, train acc:0.35, test acc:0.2623 epoch:9, train acc:0.39, test acc:0.2796 epoch:10, train acc:0.413333333333, test acc:0.2973 epoch:11, train acc:0.416666666667, test acc:0.3035 epoch:12, train acc:0.426666666667, test acc:0.3084 epoch:13, train acc:0.463333333333, test acc:0.3176 epoch:14, train acc:0.493333333333, test acc:0.3334 epoch:15, train acc:0.493333333333, test acc:0.3368 epoch:16, train acc:0.493333333333, test acc:0.343 epoch:17, train acc:0.5, test acc:0.3451 epoch:18, train acc:0.53, test acc:0.3601 epoch:19, train acc:0.526666666667, test acc:0.3712 epoch:20, train acc:0.52, test acc:0.3679 epoch:21, train acc:0.543333333333, test acc:0.3823 epoch:22, train acc:0.54, test acc:0.3897 epoch:23, train acc:0.553333333333, test acc:0.3967 epoch:24, train acc:0.556666666667, test acc:0.4065 epoch:25, train acc:0.563333333333, test acc:0.4194 epoch:26, train acc:0.576666666667, test acc:0.4337 epoch:27, train acc:0.583333333333, test acc:0.4293 epoch:28, train acc:0.586666666667, test acc:0.4332 epoch:29, train acc:0.573333333333, test acc:0.4335 epoch:30, train acc:0.603333333333, test acc:0.4406 epoch:31, train acc:0.603333333333, test acc:0.4386 epoch:32, train acc:0.6, test acc:0.4418 epoch:33, train acc:0.6, test acc:0.4485 epoch:34, train acc:0.593333333333, test acc:0.4542 epoch:35, train acc:0.626666666667, test acc:0.4725 epoch:36, train acc:0.67, test acc:0.5008 epoch:37, train acc:0.663333333333, test acc:0.5165 epoch:38, train acc:0.656666666667, test acc:0.5096 epoch:39, train acc:0.65, test acc:0.5134 epoch:40, train acc:0.653333333333, test acc:0.5166 epoch:41, train acc:0.69, test acc:0.5334 epoch:42, train acc:0.7, test acc:0.5327 epoch:43, train acc:0.703333333333, test acc:0.5421 epoch:44, train acc:0.703333333333, test acc:0.5277 epoch:45, train acc:0.7, test acc:0.5426 epoch:46, train acc:0.706666666667, test acc:0.5558 epoch:47, train acc:0.69, test acc:0.555 epoch:48, train acc:0.71, test acc:0.5631 epoch:49, train acc:0.716666666667, test acc:0.561 epoch:50, train acc:0.723333333333, test acc:0.5544 epoch:51, train acc:0.74, test acc:0.5658 epoch:52, train acc:0.753333333333, test acc:0.5852 epoch:53, train acc:0.76, test acc:0.5933 epoch:54, train acc:0.73, test acc:0.5916 epoch:55, train acc:0.736666666667, test acc:0.5902 epoch:56, train acc:0.74, test acc:0.5681 epoch:57, train acc:0.736666666667, test acc:0.5757 epoch:58, train acc:0.753333333333, test acc:0.5897 epoch:59, train acc:0.763333333333, test acc:0.588 epoch:60, train acc:0.763333333333, test acc:0.6019 epoch:61, train acc:0.753333333333, test acc:0.5983 epoch:62, train acc:0.746666666667, test acc:0.5884 epoch:63, train acc:0.77, test acc:0.611 epoch:64, train acc:0.77, test acc:0.6102 epoch:65, train acc:0.763333333333, test acc:0.6084 epoch:66, train acc:0.773333333333, test acc:0.6055 epoch:67, train acc:0.766666666667, test acc:0.6016 epoch:68, train acc:0.756666666667, test acc:0.5974 epoch:69, train acc:0.786666666667, test acc:0.6124 epoch:70, train acc:0.793333333333, test acc:0.616 epoch:71, train acc:0.783333333333, test acc:0.6179 epoch:72, train acc:0.793333333333, test acc:0.6107 
epoch:73, train acc:0.81, test acc:0.64 epoch:74, train acc:0.793333333333, test acc:0.6308 epoch:75, train acc:0.793333333333, test acc:0.6415 epoch:76, train acc:0.786666666667, test acc:0.651 epoch:77, train acc:0.816666666667, test acc:0.6498 epoch:78, train acc:0.796666666667, test acc:0.6312 epoch:79, train acc:0.79, test acc:0.6216 epoch:80, train acc:0.78, test acc:0.6201 epoch:81, train acc:0.813333333333, test acc:0.6437 epoch:82, train acc:0.796666666667, test acc:0.6366 epoch:83, train acc:0.81, test acc:0.6619 epoch:84, train acc:0.826666666667, test acc:0.6603 epoch:85, train acc:0.836666666667, test acc:0.6599 epoch:86, train acc:0.823333333333, test acc:0.6513 epoch:87, train acc:0.813333333333, test acc:0.6485 epoch:88, train acc:0.816666666667, test acc:0.6616 epoch:89, train acc:0.84, test acc:0.663 epoch:90, train acc:0.846666666667, test acc:0.6699 epoch:91, train acc:0.833333333333, test acc:0.6716 epoch:92, train acc:0.813333333333, test acc:0.6567 epoch:93, train acc:0.82, test acc:0.6609 epoch:94, train acc:0.84, test acc:0.679 epoch:95, train acc:0.823333333333, test acc:0.6676 epoch:96, train acc:0.843333333333, test acc:0.684 epoch:97, train acc:0.856666666667, test acc:0.6867 epoch:98, train acc:0.85, test acc:0.688 epoch:99, train acc:0.826666666667, test acc:0.6773 epoch:100, train acc:0.833333333333, test acc:0.6796 epoch:101, train acc:0.84, test acc:0.6829 epoch:102, train acc:0.846666666667, test acc:0.6797 epoch:103, train acc:0.84, test acc:0.6783 epoch:104, train acc:0.85, test acc:0.6846 epoch:105, train acc:0.826666666667, test acc:0.6719 epoch:106, train acc:0.843333333333, test acc:0.6875 epoch:107, train acc:0.853333333333, test acc:0.6845 epoch:108, train acc:0.856666666667, test acc:0.6875 epoch:109, train acc:0.86, test acc:0.6922 epoch:110, train acc:0.843333333333, test acc:0.6956 epoch:111, train acc:0.86, test acc:0.6957 epoch:112, train acc:0.843333333333, test acc:0.7025 epoch:113, train acc:0.863333333333, test acc:0.71 epoch:114, train acc:0.86, test acc:0.6901 epoch:115, train acc:0.853333333333, test acc:0.6846 epoch:116, train acc:0.853333333333, test acc:0.69 epoch:117, train acc:0.863333333333, test acc:0.6928 epoch:118, train acc:0.856666666667, test acc:0.6812 epoch:119, train acc:0.85, test acc:0.6987 epoch:120, train acc:0.86, test acc:0.6972 epoch:121, train acc:0.856666666667, test acc:0.6938 epoch:122, train acc:0.87, test acc:0.7056 epoch:123, train acc:0.87, test acc:0.7052 epoch:124, train acc:0.856666666667, test acc:0.7109 epoch:125, train acc:0.86, test acc:0.6864 epoch:126, train acc:0.86, test acc:0.6993 epoch:127, train acc:0.856666666667, test acc:0.7031 epoch:128, train acc:0.856666666667, test acc:0.6939 epoch:129, train acc:0.866666666667, test acc:0.7051 epoch:130, train acc:0.86, test acc:0.7166 epoch:131, train acc:0.863333333333, test acc:0.6946 epoch:132, train acc:0.85, test acc:0.7039 epoch:133, train acc:0.846666666667, test acc:0.7078 epoch:134, train acc:0.85, test acc:0.7144 epoch:135, train acc:0.863333333333, test acc:0.7085 epoch:136, train acc:0.866666666667, test acc:0.7034 epoch:137, train acc:0.87, test acc:0.7016 epoch:138, train acc:0.876666666667, test acc:0.7134 epoch:139, train acc:0.866666666667, test acc:0.7087 epoch:140, train acc:0.86, test acc:0.7056 epoch:141, train acc:0.856666666667, test acc:0.7134 epoch:142, train acc:0.876666666667, test acc:0.7079 epoch:143, train acc:0.886666666667, test acc:0.7114 epoch:144, train acc:0.846666666667, test acc:0.7071 epoch:145, train 
acc:0.866666666667, test acc:0.7211 epoch:146, train acc:0.863333333333, test acc:0.715 epoch:147, train acc:0.866666666667, test acc:0.712 epoch:148, train acc:0.86, test acc:0.7155 epoch:149, train acc:0.85, test acc:0.721 epoch:150, train acc:0.87, test acc:0.7121 epoch:151, train acc:0.866666666667, test acc:0.7234 epoch:152, train acc:0.89, test acc:0.7236 epoch:153, train acc:0.88, test acc:0.7244 epoch:154, train acc:0.86, test acc:0.7216 epoch:155, train acc:0.836666666667, test acc:0.7133 epoch:156, train acc:0.856666666667, test acc:0.7174 epoch:157, train acc:0.87, test acc:0.7129 epoch:158, train acc:0.863333333333, test acc:0.7077 epoch:159, train acc:0.866666666667, test acc:0.7092 epoch:160, train acc:0.88, test acc:0.7194 epoch:161, train acc:0.86, test acc:0.7162 epoch:162, train acc:0.856666666667, test acc:0.6953 epoch:163, train acc:0.863333333333, test acc:0.7132 epoch:164, train acc:0.876666666667, test acc:0.7096 epoch:165, train acc:0.87, test acc:0.7134 epoch:166, train acc:0.853333333333, test acc:0.7106 epoch:167, train acc:0.86, test acc:0.7119 epoch:168, train acc:0.88, test acc:0.7107 epoch:169, train acc:0.843333333333, test acc:0.7174 epoch:170, train acc:0.873333333333, test acc:0.7188 epoch:171, train acc:0.876666666667, test acc:0.7214 epoch:172, train acc:0.866666666667, test acc:0.7121 epoch:173, train acc:0.866666666667, test acc:0.718 epoch:174, train acc:0.886666666667, test acc:0.7193 epoch:175, train acc:0.873333333333, test acc:0.7247 epoch:176, train acc:0.863333333333, test acc:0.7228 epoch:177, train acc:0.876666666667, test acc:0.71 epoch:178, train acc:0.89, test acc:0.7136 epoch:179, train acc:0.866666666667, test acc:0.7158 epoch:180, train acc:0.886666666667, test acc:0.7179 epoch:181, train acc:0.863333333333, test acc:0.7185 epoch:182, train acc:0.87, test acc:0.7129 epoch:183, train acc:0.863333333333, test acc:0.7216 epoch:184, train acc:0.88, test acc:0.7206 epoch:185, train acc:0.886666666667, test acc:0.7309 epoch:186, train acc:0.893333333333, test acc:0.7189 epoch:187, train acc:0.883333333333, test acc:0.7161 epoch:188, train acc:0.866666666667, test acc:0.725 epoch:189, train acc:0.873333333333, test acc:0.7204 epoch:190, train acc:0.866666666667, test acc:0.7208 epoch:191, train acc:0.88, test acc:0.7211 epoch:192, train acc:0.88, test acc:0.7156 epoch:193, train acc:0.883333333333, test acc:0.7214 epoch:194, train acc:0.873333333333, test acc:0.7292 epoch:195, train acc:0.89, test acc:0.7234 epoch:196, train acc:0.886666666667, test acc:0.7268 epoch:197, train acc:0.886666666667, test acc:0.7214 epoch:198, train acc:0.883333333333, test acc:0.7242 epoch:199, train acc:0.876666666667, test acc:0.721 epoch:200, train acc:0.87, test acc:0.7165
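For reference, the only two places weight decay enters training are the loss (an extra 0.5 * lambda * sum(W**2) term per weight matrix) and the weight gradients (an extra lambda * W term), exactly as in the loss() and gradient() methods of MultiLayerNetExtend earlier. A minimal numeric sketch of that bookkeeping, with placeholder values, assuming only NumPy:
import numpy as np

lam = 0.1                                   # same role as weight_decay_lambda above
W_demo = np.random.randn(4, 3)
data_loss = 1.234                           # placeholder for the data-dependent loss
loss_with_decay = data_loss + 0.5 * lam * np.sum(W_demo ** 2)
dW_data = np.random.randn(4, 3)             # placeholder for the backpropagated gradient
dW_with_decay = dW_data + lam * W_demo      # the decay term pulls each weight towards 0
print(loss_with_decay, np.abs(dW_with_decay - dW_data).max())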
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# reduce the training data to reproduce overfitting
x_train = x_train[:300]
t_train = t_train[:300]
# weight decay setting =======================
#weight_decay_lambda = 0 # use this to disable weight decay
weight_decay_lambda = 0.1
# ====================================================
network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100], output_size=10,
weight_decay_lambda=weight_decay_lambda)
optimizer = SGD(lr=0.01)
max_epochs = 201
train_size = x_train.shape[0]
batch_size = 100
train_loss_list = []
train_acc_list = []
test_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)
epoch_cnt = 0
for i in range(1000000000):
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
grads = network.gradient(x_batch, t_batch)
optimizer.update(network.params, grads)
if i % iter_per_epoch == 0:
train_acc = network.accuracy(x_train, t_train)
test_acc = network.accuracy(x_test, t_test)
train_acc_list.append(train_acc)
test_acc_list.append(test_acc)
print("epoch:" + str(epoch_cnt) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc))
epoch_cnt += 1
if epoch_cnt >= max_epochs:
break
# 3. Plot the results ==========
markers = {'train': 'o', 'test': 's'}
x = np.arange(max_epochs)
plt.plot(x, train_acc_list, marker='o', label='train', markevery=10)
plt.plot(x, test_acc_list, marker='s', label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
epoch:0, train acc:0.0833333333333, test acc:0.1107 epoch:1, train acc:0.103333333333, test acc:0.1234 epoch:2, train acc:0.126666666667, test acc:0.1395 epoch:3, train acc:0.163333333333, test acc:0.1542 epoch:4, train acc:0.22, test acc:0.1746 epoch:5, train acc:0.24, test acc:0.1998 epoch:6, train acc:0.276666666667, test acc:0.2196 epoch:7, train acc:0.296666666667, test acc:0.2362 epoch:8, train acc:0.326666666667, test acc:0.2536 epoch:9, train acc:0.373333333333, test acc:0.2771 epoch:10, train acc:0.37, test acc:0.2845 epoch:11, train acc:0.37, test acc:0.303 epoch:12, train acc:0.4, test acc:0.3285 epoch:13, train acc:0.453333333333, test acc:0.341 epoch:14, train acc:0.47, test acc:0.3554 epoch:15, train acc:0.5, test acc:0.3704 epoch:16, train acc:0.513333333333, test acc:0.3856 epoch:17, train acc:0.536666666667, test acc:0.4005 epoch:18, train acc:0.56, test acc:0.407 epoch:19, train acc:0.58, test acc:0.4176 epoch:20, train acc:0.566666666667, test acc:0.4143 epoch:21, train acc:0.553333333333, test acc:0.4161 epoch:22, train acc:0.593333333333, test acc:0.4325 epoch:23, train acc:0.583333333333, test acc:0.4383 epoch:24, train acc:0.583333333333, test acc:0.4475 epoch:25, train acc:0.61, test acc:0.4485 epoch:26, train acc:0.633333333333, test acc:0.4598 epoch:27, train acc:0.646666666667, test acc:0.4768 epoch:28, train acc:0.68, test acc:0.4813 epoch:29, train acc:0.663333333333, test acc:0.4908 epoch:30, train acc:0.68, test acc:0.4926 epoch:31, train acc:0.676666666667, test acc:0.4929 epoch:32, train acc:0.7, test acc:0.503 epoch:33, train acc:0.696666666667, test acc:0.5116 epoch:34, train acc:0.7, test acc:0.5141 epoch:35, train acc:0.693333333333, test acc:0.5026 epoch:36, train acc:0.73, test acc:0.5204 epoch:37, train acc:0.72, test acc:0.529 epoch:38, train acc:0.746666666667, test acc:0.5226 epoch:39, train acc:0.723333333333, test acc:0.5358 epoch:40, train acc:0.76, test acc:0.5372 epoch:41, train acc:0.77, test acc:0.5668 epoch:42, train acc:0.766666666667, test acc:0.5584 epoch:43, train acc:0.76, test acc:0.5628 epoch:44, train acc:0.78, test acc:0.5625 epoch:45, train acc:0.79, test acc:0.5857 epoch:46, train acc:0.786666666667, test acc:0.574 epoch:47, train acc:0.803333333333, test acc:0.5899 epoch:48, train acc:0.81, test acc:0.5975 epoch:49, train acc:0.806666666667, test acc:0.5863 epoch:50, train acc:0.816666666667, test acc:0.5955 epoch:51, train acc:0.8, test acc:0.5803 epoch:52, train acc:0.803333333333, test acc:0.5986 epoch:53, train acc:0.81, test acc:0.5909 epoch:54, train acc:0.803333333333, test acc:0.5842 epoch:55, train acc:0.813333333333, test acc:0.601 epoch:56, train acc:0.783333333333, test acc:0.601 epoch:57, train acc:0.806666666667, test acc:0.6134 epoch:58, train acc:0.803333333333, test acc:0.606 epoch:59, train acc:0.8, test acc:0.6082 epoch:60, train acc:0.793333333333, test acc:0.6104 epoch:61, train acc:0.83, test acc:0.6156 epoch:62, train acc:0.833333333333, test acc:0.6299 epoch:63, train acc:0.823333333333, test acc:0.6293 epoch:64, train acc:0.836666666667, test acc:0.633 epoch:65, train acc:0.836666666667, test acc:0.6321 epoch:66, train acc:0.84, test acc:0.6443 epoch:67, train acc:0.846666666667, test acc:0.6463 epoch:68, train acc:0.836666666667, test acc:0.6346 epoch:69, train acc:0.83, test acc:0.6478 epoch:70, train acc:0.856666666667, test acc:0.655 epoch:71, train acc:0.823333333333, test acc:0.6289 epoch:72, train acc:0.823333333333, test acc:0.6406 epoch:73, train acc:0.85, test acc:0.6507 epoch:74, train 
acc:0.843333333333, test acc:0.6457 epoch:75, train acc:0.843333333333, test acc:0.644 epoch:76, train acc:0.843333333333, test acc:0.6492 epoch:77, train acc:0.85, test acc:0.6501 epoch:78, train acc:0.846666666667, test acc:0.6452 epoch:79, train acc:0.84, test acc:0.6521 epoch:80, train acc:0.846666666667, test acc:0.6546 epoch:81, train acc:0.85, test acc:0.6555 epoch:82, train acc:0.846666666667, test acc:0.6484 epoch:83, train acc:0.85, test acc:0.6574 epoch:84, train acc:0.863333333333, test acc:0.6652 epoch:85, train acc:0.85, test acc:0.6531 epoch:86, train acc:0.853333333333, test acc:0.6587 epoch:87, train acc:0.853333333333, test acc:0.6551 epoch:88, train acc:0.846666666667, test acc:0.6633 epoch:89, train acc:0.843333333333, test acc:0.6629 epoch:90, train acc:0.863333333333, test acc:0.6715 epoch:91, train acc:0.86, test acc:0.6703 epoch:92, train acc:0.87, test acc:0.6735 epoch:93, train acc:0.866666666667, test acc:0.6749 epoch:94, train acc:0.86, test acc:0.6734 epoch:95, train acc:0.876666666667, test acc:0.6834 epoch:96, train acc:0.87, test acc:0.6795 epoch:97, train acc:0.86, test acc:0.6806 epoch:98, train acc:0.856666666667, test acc:0.6801 epoch:99, train acc:0.86, test acc:0.6792 epoch:100, train acc:0.856666666667, test acc:0.6765 epoch:101, train acc:0.856666666667, test acc:0.6754 epoch:102, train acc:0.886666666667, test acc:0.683 epoch:103, train acc:0.87, test acc:0.6883 epoch:104, train acc:0.876666666667, test acc:0.6894 epoch:105, train acc:0.863333333333, test acc:0.6859 epoch:106, train acc:0.88, test acc:0.6853 epoch:107, train acc:0.86, test acc:0.6763 epoch:108, train acc:0.9, test acc:0.6943 epoch:109, train acc:0.88, test acc:0.6947 epoch:110, train acc:0.896666666667, test acc:0.6932 epoch:111, train acc:0.883333333333, test acc:0.6958 epoch:112, train acc:0.886666666667, test acc:0.6934 epoch:113, train acc:0.866666666667, test acc:0.687 epoch:114, train acc:0.87, test acc:0.6848 epoch:115, train acc:0.88, test acc:0.688 epoch:116, train acc:0.87, test acc:0.6863 epoch:117, train acc:0.87, test acc:0.6876 epoch:118, train acc:0.87, test acc:0.6945 epoch:119, train acc:0.893333333333, test acc:0.7008 epoch:120, train acc:0.87, test acc:0.695 epoch:121, train acc:0.893333333333, test acc:0.6987 epoch:122, train acc:0.886666666667, test acc:0.7031 epoch:123, train acc:0.893333333333, test acc:0.706 epoch:124, train acc:0.903333333333, test acc:0.709 epoch:125, train acc:0.906666666667, test acc:0.7057 epoch:126, train acc:0.906666666667, test acc:0.7029 epoch:127, train acc:0.91, test acc:0.702 epoch:128, train acc:0.89, test acc:0.7062 epoch:129, train acc:0.88, test acc:0.7019 epoch:130, train acc:0.87, test acc:0.6898 epoch:131, train acc:0.883333333333, test acc:0.7008 epoch:132, train acc:0.903333333333, test acc:0.7058 epoch:133, train acc:0.89, test acc:0.7036 epoch:134, train acc:0.89, test acc:0.7054 epoch:135, train acc:0.896666666667, test acc:0.7051 epoch:136, train acc:0.883333333333, test acc:0.6928 epoch:137, train acc:0.9, test acc:0.7094 epoch:138, train acc:0.9, test acc:0.7059 epoch:139, train acc:0.9, test acc:0.7027 epoch:140, train acc:0.9, test acc:0.706 epoch:141, train acc:0.91, test acc:0.7118 epoch:142, train acc:0.9, test acc:0.7092 epoch:143, train acc:0.916666666667, test acc:0.715 epoch:144, train acc:0.896666666667, test acc:0.7048 epoch:145, train acc:0.903333333333, test acc:0.7132 epoch:146, train acc:0.886666666667, test acc:0.7035 epoch:147, train acc:0.89, test acc:0.7119 epoch:148, train acc:0.873333333333, 
test acc:0.7037 epoch:149, train acc:0.876666666667, test acc:0.7073 epoch:150, train acc:0.896666666667, test acc:0.7106 epoch:151, train acc:0.906666666667, test acc:0.7161 epoch:152, train acc:0.906666666667, test acc:0.7198 epoch:153, train acc:0.896666666667, test acc:0.7198 epoch:154, train acc:0.9, test acc:0.7169 epoch:155, train acc:0.893333333333, test acc:0.7041 epoch:156, train acc:0.886666666667, test acc:0.7051 epoch:157, train acc:0.906666666667, test acc:0.7105 epoch:158, train acc:0.893333333333, test acc:0.7107 epoch:159, train acc:0.883333333333, test acc:0.7071 epoch:160, train acc:0.906666666667, test acc:0.7171 epoch:161, train acc:0.92, test acc:0.7195 epoch:162, train acc:0.906666666667, test acc:0.7189 epoch:163, train acc:0.923333333333, test acc:0.7205 epoch:164, train acc:0.9, test acc:0.7169 epoch:165, train acc:0.9, test acc:0.7149 epoch:166, train acc:0.9, test acc:0.7175 epoch:167, train acc:0.903333333333, test acc:0.7171 epoch:168, train acc:0.916666666667, test acc:0.718 epoch:169, train acc:0.9, test acc:0.7137 epoch:170, train acc:0.903333333333, test acc:0.7149 epoch:171, train acc:0.9, test acc:0.7157 epoch:172, train acc:0.9, test acc:0.7109 epoch:173, train acc:0.903333333333, test acc:0.7148 epoch:174, train acc:0.913333333333, test acc:0.7166 epoch:175, train acc:0.903333333333, test acc:0.7184 epoch:176, train acc:0.903333333333, test acc:0.7204 epoch:177, train acc:0.903333333333, test acc:0.719 epoch:178, train acc:0.91, test acc:0.7156 epoch:179, train acc:0.903333333333, test acc:0.7194 epoch:180, train acc:0.923333333333, test acc:0.7168 epoch:181, train acc:0.923333333333, test acc:0.7216 epoch:182, train acc:0.906666666667, test acc:0.7221 epoch:183, train acc:0.906666666667, test acc:0.7187 epoch:184, train acc:0.903333333333, test acc:0.7199 epoch:185, train acc:0.893333333333, test acc:0.7112 epoch:186, train acc:0.886666666667, test acc:0.7073 epoch:187, train acc:0.9, test acc:0.7131 epoch:188, train acc:0.906666666667, test acc:0.7165 epoch:189, train acc:0.886666666667, test acc:0.7103 epoch:190, train acc:0.896666666667, test acc:0.712 epoch:191, train acc:0.9, test acc:0.7185 epoch:192, train acc:0.906666666667, test acc:0.7156 epoch:193, train acc:0.906666666667, test acc:0.7144 epoch:194, train acc:0.89, test acc:0.7088 epoch:195, train acc:0.903333333333, test acc:0.7187 epoch:196, train acc:0.926666666667, test acc:0.722 epoch:197, train acc:0.89, test acc:0.715 epoch:198, train acc:0.91, test acc:0.7232 epoch:199, train acc:0.906666666667, test acc:0.7255 epoch:200, train acc:0.916666666667, test acc:0.7238
class Dropout:
"""
http://arxiv.org/abs/1207.0580
"""
def __init__(self, dropout_ratio=0.5):
self.dropout_ratio = dropout_ratio
self.mask = None
def forward(self, x, train_flg=True):
if train_flg:
self.mask = np.random.rand(*x.shape) > self.dropout_ratio
return x * self.mask
else:
return x * (1.0 - self.dropout_ratio)
def backward(self, dout):
return dout * self.mask
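A toy run of the Dropout layer above makes its behaviour concrete: in training mode each activation survives with probability 1 - dropout_ratio, the backward pass lets gradients flow only through the surviving units, and at test time every activation is simply scaled by 1 - dropout_ratio. A minimal sketch, assuming only NumPy and the class as defined here:
import numpy as np

drop = Dropout(dropout_ratio=0.5)
x_demo = np.ones((2, 8))
print(drop.forward(x_demo, train_flg=True))    # roughly half the entries zeroed at random
print(drop.backward(np.ones((2, 8))))          # gradient flows only where the mask kept units
print(drop.forward(x_demo, train_flg=False))   # every entry scaled by 1 - dropout_ratio = 0.5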
class Trainer:
"""ニューラルネットの訓練を行うクラス
"""
def __init__(self, network, x_train, t_train, x_test, t_test,
epochs=20, mini_batch_size=100,
optimizer='SGD', optimizer_param={'lr':0.01},
evaluate_sample_num_per_epoch=None, verbose=True):
self.network = network
self.verbose = verbose
self.x_train = x_train
self.t_train = t_train
self.x_test = x_test
self.t_test = t_test
self.epochs = epochs
self.batch_size = mini_batch_size
self.evaluate_sample_num_per_epoch = evaluate_sample_num_per_epoch
# optimizer
optimizer_class_dict = {'sgd':SGD, 'momentum':Momentum, 'nesterov':Nesterov,
'adagrad':AdaGrad, 'rmsprop':RMSprop, 'adam':Adam}
self.optimizer = optimizer_class_dict[optimizer.lower()](**optimizer_param)
self.train_size = x_train.shape[0]
self.iter_per_epoch = max(self.train_size / mini_batch_size, 1)
self.max_iter = int(epochs * self.iter_per_epoch)
self.current_iter = 0
self.current_epoch = 0
self.train_loss_list = []
self.train_acc_list = []
self.test_acc_list = []
def train_step(self):
batch_mask = np.random.choice(self.train_size, self.batch_size)
x_batch = self.x_train[batch_mask]
t_batch = self.t_train[batch_mask]
grads = self.network.gradient(x_batch, t_batch)
self.optimizer.update(self.network.params, grads)
loss = self.network.loss(x_batch, t_batch)
self.train_loss_list.append(loss)
if self.verbose: print("train loss:" + str(loss))
if self.current_iter % self.iter_per_epoch == 0:
self.current_epoch += 1
x_train_sample, t_train_sample = self.x_train, self.t_train
x_test_sample, t_test_sample = self.x_test, self.t_test
if self.evaluate_sample_num_per_epoch is not None:
t = self.evaluate_sample_num_per_epoch
x_train_sample, t_train_sample = self.x_train[:t], self.t_train[:t]
x_test_sample, t_test_sample = self.x_test[:t], self.t_test[:t]
train_acc = self.network.accuracy(x_train_sample, t_train_sample)
test_acc = self.network.accuracy(x_test_sample, t_test_sample)
self.train_acc_list.append(train_acc)
self.test_acc_list.append(test_acc)
if self.verbose: print("=== epoch:" + str(self.current_epoch) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc) + " ===")
self.current_iter += 1
def train(self):
for i in range(self.max_iter):
self.train_step()
test_acc = self.network.accuracy(self.x_test, self.t_test)
if self.verbose:
print("=============== Final Test Accuracy ===============")
print("test acc:" + str(test_acc))
class Nesterov:
"""Nesterov's Accelerated Gradient (http://arxiv.org/abs/1212.0901)"""
def __init__(self, lr=0.01, momentum=0.9):
self.lr = lr
self.momentum = momentum
self.v = None
def update(self, params, grads):
if self.v is None:
self.v = {}
for key, val in params.items():
self.v[key] = np.zeros_like(val)
for key in params.keys():
self.v[key] *= self.momentum
self.v[key] -= self.lr * grads[key]
params[key] += self.momentum * self.momentum * self.v[key]
params[key] -= (1 + self.momentum) * self.lr * grads[key]
class RMSprop:
"""RMSprop"""
def __init__(self, lr=0.01, decay_rate = 0.99):
self.lr = lr
self.decay_rate = decay_rate
self.h = None
def update(self, params, grads):
if self.h is None:
self.h = {}
for key, val in params.items():
self.h[key] = np.zeros_like(val)
for key in params.keys():
self.h[key] *= self.decay_rate
self.h[key] += (1 - self.decay_rate) * grads[key] * grads[key]
params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)
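Both Nesterov and RMSprop above expose the same update(params, grads) interface as SGD, so they are drop-in replacements in the training loops of this notebook. A minimal sketch comparing them on the toy function f(x, y) = x**2 / 20 + y**2, assuming only NumPy and the classes as defined here:
import numpy as np

def f_grad(params):
    # gradient of f(x, y) = x**2 / 20 + y**2
    return {'x': params['x'] / 10.0, 'y': 2.0 * params['y']}

for opt in (Nesterov(lr=0.1), RMSprop(lr=0.1)):
    params = {'x': np.array(-7.0), 'y': np.array(2.0)}
    for _ in range(30):
        opt.update(params, f_grad(params))
    # both optimizers should have moved (x, y) towards the minimum at (0, 0)
    print(type(opt).__name__ + ": " + str(params['x']) + ", " + str(params['y']))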
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# reduce the training data to reproduce overfitting
x_train = x_train[:300]
t_train = t_train[:300]
# Dropout on/off and ratio settings ========================
use_dropout = True # set to False to train without Dropout
dropout_ratio = 0.2
# ====================================================
network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
output_size=10, use_dropout=use_dropout, dropout_ration=dropout_ratio)
trainer = Trainer(network, x_train, t_train, x_test, t_test,
epochs=301, mini_batch_size=100,
optimizer='sgd', optimizer_param={'lr': 0.01}, verbose=True)
trainer.train()
train_acc_list, test_acc_list = trainer.train_acc_list, trainer.test_acc_list
# Plot the results ==========
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, marker='o', label='train', markevery=10)
plt.plot(x, test_acc_list, marker='s', label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
train loss:2.312167959 === epoch:1, train acc:0.0466666666667, test acc:0.082 === train loss:2.32410263111 train loss:2.31945279883 train loss:2.30768186918 === epoch:2, train acc:0.0433333333333, test acc:0.0809 === train loss:2.30158195902 train loss:2.30012869297 train loss:2.31460195306 === epoch:3, train acc:0.04, test acc:0.0781 === train loss:2.30110057525 train loss:2.29054599374 train loss:2.31063401819 === epoch:4, train acc:0.04, test acc:0.0751 === train loss:2.30010960583 train loss:2.31345330727 train loss:2.30922172308 === epoch:5, train acc:0.0366666666667, test acc:0.0739 === train loss:2.30930434835 train loss:2.31743112559 train loss:2.30333256562 === epoch:6, train acc:0.0366666666667, test acc:0.0734 === train loss:2.29985084381 train loss:2.29617781671 train loss:2.32485383593 === epoch:7, train acc:0.0333333333333, test acc:0.0741 === train loss:2.30790939081 train loss:2.30808132487 train loss:2.31001609375 === epoch:8, train acc:0.03, test acc:0.075 === train loss:2.29978990466 train loss:2.29538433837 train loss:2.30077830861 === epoch:9, train acc:0.0333333333333, test acc:0.0759 === train loss:2.30963509201 train loss:2.29691032269 train loss:2.30358975153 === epoch:10, train acc:0.04, test acc:0.0749 === train loss:2.30434765416 train loss:2.3128000936 train loss:2.30277658166 === epoch:11, train acc:0.0466666666667, test acc:0.077 === train loss:2.29206419102 train loss:2.28810378376 train loss:2.30033367295 === epoch:12, train acc:0.0533333333333, test acc:0.0776 === train loss:2.29689363595 train loss:2.30424121048 train loss:2.29562983221 === epoch:13, train acc:0.0566666666667, test acc:0.0776 === train loss:2.29101804589 train loss:2.28985023201 train loss:2.29501393065 === epoch:14, train acc:0.0566666666667, test acc:0.0784 === train loss:2.28816362129 train loss:2.30045928072 train loss:2.29577304565 === epoch:15, train acc:0.0666666666667, test acc:0.0812 === train loss:2.29964370874 train loss:2.29033226094 train loss:2.28635095392 === epoch:16, train acc:0.0766666666667, test acc:0.0831 === train loss:2.28727678523 train loss:2.29586419644 train loss:2.30088717057 === epoch:17, train acc:0.0833333333333, test acc:0.088 === train loss:2.2988141827 train loss:2.29696462384 train loss:2.28724752933 === epoch:18, train acc:0.0833333333333, test acc:0.0905 === train loss:2.29755245766 train loss:2.29515182724 train loss:2.28917269369 === epoch:19, train acc:0.0866666666667, test acc:0.0919 === train loss:2.27519966605 train loss:2.28381344693 train loss:2.28936385708 === epoch:20, train acc:0.0866666666667, test acc:0.0946 === train loss:2.28837904218 train loss:2.29065970247 train loss:2.29310327736 === epoch:21, train acc:0.0933333333333, test acc:0.0932 === train loss:2.29260416428 train loss:2.28237772883 train loss:2.29664470107 === epoch:22, train acc:0.1, test acc:0.0946 === train loss:2.28323708882 train loss:2.29612699482 train loss:2.28922704849 === epoch:23, train acc:0.106666666667, test acc:0.0961 === train loss:2.28874770796 train loss:2.2929471227 train loss:2.29482883364 === epoch:24, train acc:0.11, test acc:0.1 === train loss:2.28987716681 train loss:2.28681436619 train loss:2.28140179361 === epoch:25, train acc:0.116666666667, test acc:0.1026 === train loss:2.28358910323 train loss:2.28645745307 train loss:2.28845842185 === epoch:26, train acc:0.116666666667, test acc:0.1036 === train loss:2.27049142398 train loss:2.2804288742 train loss:2.28580433461 === epoch:27, train acc:0.13, test acc:0.1044 === train loss:2.28247156603 train 
loss:2.28795865196 train loss:2.28995132661 === epoch:28, train acc:0.13, test acc:0.1051 === train loss:2.28058223068 train loss:2.28893087028 train loss:2.28507180489 === epoch:29, train acc:0.13, test acc:0.1062 === train loss:2.27639714051 train loss:2.27116190608 train loss:2.27468232512 === epoch:30, train acc:0.133333333333, test acc:0.1086 === train loss:2.27800541717 train loss:2.28294752901 train loss:2.28192221854 === epoch:31, train acc:0.126666666667, test acc:0.11 === train loss:2.28647541865 train loss:2.28207684684 train loss:2.28791054256 === epoch:32, train acc:0.133333333333, test acc:0.1119 === train loss:2.26781300553 train loss:2.28359501573 train loss:2.27646319625 === epoch:33, train acc:0.153333333333, test acc:0.1139 === train loss:2.27828895051 train loss:2.28177177524 train loss:2.28384538468 === epoch:34, train acc:0.153333333333, test acc:0.1165 === train loss:2.2657835914 train loss:2.26637431625 train loss:2.27678691879 === epoch:35, train acc:0.163333333333, test acc:0.12 === train loss:2.2782645555 train loss:2.28229648539 train loss:2.27396137364 === epoch:36, train acc:0.166666666667, test acc:0.1232 === train loss:2.27901080961 train loss:2.27923018631 train loss:2.27606612117 === epoch:37, train acc:0.16, test acc:0.1235 === train loss:2.25613843406 train loss:2.26643763221 train loss:2.28097935888 === epoch:38, train acc:0.16, test acc:0.1257 === train loss:2.2669801285 train loss:2.25926422731 train loss:2.26732479078 === epoch:39, train acc:0.16, test acc:0.1271 === train loss:2.27052487498 train loss:2.25975857724 train loss:2.2658442553 === epoch:40, train acc:0.17, test acc:0.129 === train loss:2.28373669359 train loss:2.26776105944 train loss:2.27474586068 === epoch:41, train acc:0.173333333333, test acc:0.1327 === train loss:2.2718925283 train loss:2.28169736292 train loss:2.27238953545 === epoch:42, train acc:0.173333333333, test acc:0.1359 === train loss:2.26589889621 train loss:2.25532579431 train loss:2.27857987384 === epoch:43, train acc:0.173333333333, test acc:0.1393 === train loss:2.26436948077 train loss:2.25999224104 train loss:2.25623163385 === epoch:44, train acc:0.186666666667, test acc:0.1452 === train loss:2.26412860442 train loss:2.26407981687 train loss:2.27161661493 === epoch:45, train acc:0.186666666667, test acc:0.1459 === train loss:2.28000099258 train loss:2.27119385748 train loss:2.26505254372 === epoch:46, train acc:0.19, test acc:0.1464 === train loss:2.26584194308 train loss:2.27492175854 train loss:2.26500423965 === epoch:47, train acc:0.21, test acc:0.1539 === train loss:2.26976683423 train loss:2.27309546922 train loss:2.2676794836 === epoch:48, train acc:0.21, test acc:0.1563 === train loss:2.24797801301 train loss:2.25752924762 train loss:2.26529459314 === epoch:49, train acc:0.213333333333, test acc:0.1567 === train loss:2.25980015435 train loss:2.26532612933 train loss:2.25353979323 === epoch:50, train acc:0.216666666667, test acc:0.159 === train loss:2.25207384295 train loss:2.2667681401 train loss:2.24477231116 === epoch:51, train acc:0.216666666667, test acc:0.1659 === train loss:2.25116844868 train loss:2.27087359828 train loss:2.26482534687 === epoch:52, train acc:0.223333333333, test acc:0.1685 === train loss:2.26661119075 train loss:2.26365160287 train loss:2.25958515388 === epoch:53, train acc:0.23, test acc:0.1693 === train loss:2.25067683382 train loss:2.258463674 train loss:2.26442090264 === epoch:54, train acc:0.23, test acc:0.1724 === train loss:2.25290681026 train loss:2.26338157438 train 
loss:2.25440360227 === epoch:55, train acc:0.24, test acc:0.1779 ===
train loss:2.26021509636
...
=== epoch:100, train acc:0.356666666667, test acc:0.2893 ===
...
=== epoch:150, train acc:0.473333333333, test acc:0.3649 ===
...
=== epoch:200, train acc:0.54, test acc:0.4414 ===
...
=== epoch:250, train acc:0.58, test acc:0.4974 ===
...
=== epoch:300, train acc:0.646666666667, test acc:0.5369 ===
train loss:1.24592349106
train loss:1.09555758987
train loss:1.19074464714
=== epoch:301, train acc:0.646666666667, test acc:0.54 ===
train loss:1.29322324211
train loss:1.23080112617
=============== Final Test Accuracy ===============
test acc:0.5402
def shuffle_dataset(x, t):
    """Shuffle the dataset.

    Parameters
    ----------
    x : training data
    t : labels (teacher data)

    Returns
    -------
    x, t : the shuffled training data and labels
    """
    permutation = np.random.permutation(x.shape[0])
    x = x[permutation, :] if x.ndim == 2 else x[permutation, :, :, :]
    t = t[permutation]
    return x, t
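As a quick sanity check, shuffle_dataset should permute samples and labels with the same permutation so every sample keeps its own label. A minimal sketch on made-up toy data (x_demo and t_demo are illustrative names, not part of the original code):

import numpy as np

x_demo = np.arange(12).reshape(6, 2)   # 6 samples, row i == [2*i, 2*i + 1]
t_demo = np.arange(6)                  # label i belongs to sample i
x_shuf, t_shuf = shuffle_dataset(x_demo, t_demo)
# after shuffling, every row should still carry its original label
assert np.all(x_shuf[:, 0] == 2 * t_shuf)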
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# Reduce the training data to speed up the search
x_train = x_train[:500]
t_train = t_train[:500]
# Split off a validation set
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)  # integer count so the slices below are valid
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]
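With 500 training samples and a 20% validation rate, the split should leave 400 samples for training and 100 for validation. Printing the shapes is a quick way to confirm (the expected values assume the default flattened, non-one-hot output of load_mnist):

print(x_train.shape, t_train.shape)   # expected: (400, 784) (400,)
print(x_val.shape, t_val.shape)       # expected: (100, 784) (100,)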
def __train(lr, weight_decay, epochs=50):
    network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
                            output_size=10, weight_decay_lambda=weight_decay)
    trainer = Trainer(network, x_train, t_train, x_val, t_val,
                      epochs=epochs, mini_batch_size=100,
                      optimizer='sgd', optimizer_param={'lr': lr}, verbose=False)
    trainer.train()
    return trainer.test_acc_list, trainer.train_acc_list
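Before launching the full random search, a single call to __train is a cheap way to confirm the pipeline runs end to end. The lr, weight_decay, and epochs values below are arbitrary placeholders for a smoke test, not recommendations:

val_acc_smoke, train_acc_smoke = __train(lr=0.01, weight_decay=1e-6, epochs=5)
print("smoke-test val acc:", val_acc_smoke[-1])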
# Random search over the hyper-parameters ======================================
optimization_trial = 100
results_val = {}
results_train = {}
for _ in range(optimization_trial):
    # Specify the search range for each hyper-parameter ===============
    weight_decay = 10 ** np.random.uniform(-8, -4)
    lr = 10 ** np.random.uniform(-6, -2)
    # ================================================
    val_acc_list, train_acc_list = __train(lr, weight_decay)
    print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))
    key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)
    results_val[key] = val_acc_list
    results_train[key] = train_acc_list
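Note that `10 ** np.random.uniform(-8, -4)` samples on a log scale, so every decade between 1e-8 and 1e-4 is equally likely; sampling uniformly on a linear scale would almost never propose the small values. A minimal sketch contrasting the two (the ranges here are illustrative only):

import numpy as np

log_uniform = 10 ** np.random.uniform(-6, -2, size=5)    # spread across several decades
linear_uniform = np.random.uniform(1e-6, 1e-2, size=5)   # clusters near the upper end
print(log_uniform)
print(linear_uniform)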
# Draw the graphs ========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0
for key, val_acc_list in sorted(results_val.items(), key=lambda x: x[1][-1], reverse=True):
    print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)
    plt.subplot(row_num, col_num, i+1)
    plt.title("Best-" + str(i+1))
    plt.ylim(0.0, 1.0)
    if i % 5: plt.yticks([])
    plt.xticks([])
    x = np.arange(len(val_acc_list))
    plt.plot(x, val_acc_list)
    plt.plot(x, results_train[key], "--")
    i += 1
    if i >= graph_draw_num:
        break
plt.show()
val acc:0.14 | lr:1.799831042641291e-05, weight decay:1.83778494041255e-07
val acc:0.05 | lr:0.0002347318914508403, weight decay:6.152274981702408e-08
val acc:0.72 | lr:0.006096954087946476, weight decay:2.8837250105419758e-08
val acc:0.15 | lr:0.000537954978468719, weight decay:5.878394641264319e-06
...
=========== Hyper-Parameter Optimization Result ===========
Best-1(val acc:0.83) | lr:0.007955421023163349, weight decay:1.0179821040366434e-06
Best-2(val acc:0.77) | lr:0.007314328741051734, weight decay:6.221610323334567e-05
Best-3(val acc:0.77) | lr:0.005546405062300536, weight decay:2.244491953842928e-07
Best-4(val acc:0.76) | lr:0.007954156922587819, weight decay:8.113842039183665e-05
Best-5(val acc:0.72) | lr:0.006096954087946476, weight decay:2.8837250105419758e-08
Best-6(val acc:0.69) | lr:0.007978656136557872, weight decay:2.1181845167084804e-08
Best-7(val acc:0.59) | lr:0.0034377009625794937, weight decay:3.711910367116388e-06
Best-8(val acc:0.56) | lr:0.005435330538585565, weight decay:1.11889668740693e-08
Best-9(val acc:0.55) | lr:0.004131630902309629, weight decay:1.0886928474624086e-08
Best-10(val acc:0.39) | lr:0.0013423928947151845, weight decay:2.279218062293639e-05
Best-11(val acc:0.36) | lr:0.0022067061737045318, weight decay:5.432257356005885e-08
Best-12(val acc:0.35) | lr:0.0017426894072650968, weight decay:2.0552950835899585e-06
Best-13(val acc:0.34) | lr:0.0014885112166808552, weight decay:2.1864035487015072e-07
Best-14(val acc:0.33) | lr:0.002340700358489932, weight decay:3.826747082331894e-05
Best-15(val acc:0.3) | lr:0.0011420214429134596, weight decay:9.834080999472783e-07
Best-16(val acc:0.28) | lr:0.002022949093719039, weight decay:8.610289954328187e-07
Best-17(val acc:0.28) | lr:0.0020373091497788015, weight decay:1.8208579179541355e-05
Best-18(val acc:0.27) | lr:0.0006082215111986305, weight decay:1.2330059021007899e-08
Best-19(val acc:0.26) | lr:0.0026125619375359215, weight decay:2.6524829769885797e-07
Best-20(val acc:0.23) | lr:0.001958960507923658, weight decay:2.5427429360304095e-05
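The best trials cluster around lr of roughly 0.005 to 0.008, while the weight-decay values of the top entries span almost the whole sampled range, suggesting the learning rate dominates here. A natural follow-up is a second, tighter search pass; the sketch below is only one way this might look, with bounds read off the Best-N table by eye rather than prescribed anywhere:

# Possible second pass: keep the weight-decay range, tighten the learning rate.
results_val2 = {}
for _ in range(optimization_trial):
    weight_decay = 10 ** np.random.uniform(-8, -4)
    lr = 10 ** np.random.uniform(-2.5, -2)   # roughly 0.003 to 0.01
    val_acc_list, _ = __train(lr, weight_decay)
    results_val2["lr:" + str(lr) + ", weight decay:" + str(weight_decay)] = val_acc_list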