import numpy as np
def mean_squared_error(y, t):
    return 0.5 * np.sum((y - t)**2)
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
mean_squared_error(np.array(y), np.array(t))
0.097500000000000031
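The same loss should come out larger when the highest score lands on a wrong class. A quick check of that (my addition, not an original cell), reusing the same t but moving the 0.6 peak to index 7:

# Hypothetical second example: the network now favors class 7, but the answer is class 2
y_wrong = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
mean_squared_error(np.array(y_wrong), np.array(t))
# => 0.5975, roughly six times the loss of the correct guess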
import numpy as np
def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta))
cross_entropy_error(np.array(y), np.array(t))
0.51082545709933802
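The delta term exists because np.log(0) evaluates to -inf, which would make the loss unusable. A minimal illustration (my addition), assuming the cross_entropy_error just defined:

print(np.log(0))  # -inf, with a RuntimeWarning
# A probability of exactly 0 at the correct label would blow the loss up;
# the 1e-7 offset keeps it finite:
y_zero = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
t_bad = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
print(cross_entropy_error(y_zero, t_bad))  # ~16.118 = -log(1e-7), finite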
from mymodule.mnist import load_mnist
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
print(x_train.shape)
print(t_train.shape)
(60000, 784)
(60000, 10)
train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
print(batch_mask)
[12725 15614 4447 22004 41049 44449 58596 21660 47599 7767]
import numpy as np

def cross_entropy_error(y, t):  # batch-capable version
    if y.ndim == 1:
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)
    # If the targets are one-hot vectors, convert them to class-label indices
    if t.size == y.size:
        t = t.argmax(axis=1)
    batch_size = y.shape[0]
    # 1e-7 guards against log(0), as before
    return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size
cross_entropy_error(np.array(y), np.array(t))
0.51082545709933802
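Thanks to the argmax conversion, this version accepts either one-hot targets or plain label indices, and it averages the loss over the batch. A quick check (my addition):

y_batch = np.array([[0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0],
                    [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]])
t_onehot = np.array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]])
t_labels = np.array([2, 7])
print(cross_entropy_error(y_batch, t_onehot))  # ~0.5108, the mean of -log(0.6) per example
print(cross_entropy_error(y_batch, t_labels))  # same value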
import numpy as np
def numerical_diff(f, x):
    h = 1e-4  # 0.0001
    return (f(x+h) - f(x-h)) / (2*h)
# coding: utf-8
import matplotlib.pylab as plt
%matplotlib inline
def function_1(x):
    return 0.01*x**2 + 0.1*x

def tangent_line(f, x):
    d = numerical_diff(f, x)
    print(d)
    y = f(x) - d*x
    return lambda t: d*t + y
x = np.arange(0.0, 20.0, 0.1)
y = function_1(x)
plt.xlabel("x")
plt.ylabel("f(x)")
tf = tangent_line(function_1, 5)
y2 = tf(x)
plt.plot(x, y)
plt.plot(x, y2)
plt.show()
0.1999999999990898
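As a cross-check (my addition): the analytic derivative of function_1 is 0.02x + 0.1, so the true slope at x = 5 is exactly 0.2, and the centered difference above agrees to about ten significant digits:

for x0 in (5, 10):
    approx = numerical_diff(function_1, x0)
    exact = 0.02*x0 + 0.1          # analytic derivative
    print(x0, approx, exact, abs(approx - exact))
# at x=5 the error is on the order of 1e-12; at x=10 it is similarly tiny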
import numpy as np
def _numerical_gradient_no_batch(f, x):
    h = 1e-4  # 0.0001
    grad = np.zeros_like(x)
    for idx in range(x.size):
        tmp_val = x[idx]
        x[idx] = float(tmp_val) + h
        fxh1 = f(x)  # f(x+h)
        x[idx] = tmp_val - h
        fxh2 = f(x)  # f(x-h)
        grad[idx] = (fxh1 - fxh2) / (2*h)
        x[idx] = tmp_val  # restore the original value
    return grad
import numpy as np
def function_2(x):
    if x.ndim == 1:
        return np.sum(x**2)
    else:
        return np.sum(x**2, axis=1)
_numerical_gradient_no_batch(function_2, np.array([3.0, 4.0]))
array([ 6., 8.])
import numpy as np
def numerical_gradient(f, X):
    if X.ndim == 1:
        return _numerical_gradient_no_batch(f, X)
    else:
        grad = np.zeros_like(X)
        for idx, x in enumerate(X):
            grad[idx] = _numerical_gradient_no_batch(f, x)
        return grad
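For f(x0, x1) = x0**2 + x1**2 the gradient is (2*x0, 2*x1) at every point, so the batched path should return each row doubled. A quick check (my addition):

pts = np.array([[3.0, 4.0], [0.0, 2.0], [3.0, 0.0]])
print(numerical_gradient(function_2, pts))
# [[ 6.  8.]
#  [ 0.  4.]
#  [ 6.  0.]]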
import numpy as np
def tangent_line(f, x):
    d = numerical_gradient(f, x)
    print(d)
    y = f(x) - d*x
    return lambda t: d*t + y
# coding: utf-8
# cf.http://d.hatena.ne.jp/white_wheels/20100327/p3
import numpy as np
import matplotlib.pylab as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline
x0 = np.arange(-2, 2.5, 0.25)
x1 = np.arange(-2, 2.5, 0.25)
X, Y = np.meshgrid(x0, x1)
X = X.flatten()
Y = Y.flatten()
grad = numerical_gradient(function_2, np.array([X, Y]))
plt.figure()
plt.quiver(X, Y, -grad[0], -grad[1], angles="xy", color="#666666")
plt.xlim([-2, 2])
plt.ylim([-2, 2])
plt.xlabel('x0')
plt.ylabel('x1')
plt.grid()
plt.draw()
plt.show()
import numpy as np
def gradient_descent(f, init_x, lr=0.01, step_num=100):
    x = init_x
    x_history = []
    for i in range(step_num):
        x_history.append(x.copy())
        grad = numerical_gradient(f, x)
        x -= lr * grad
    return x, np.array(x_history)
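The learning rate matters a great deal here. A small sketch (my addition; f2 is a hypothetical helper equivalent to function_2 for 1-D input) of the usual too-large / too-small experiment:

def f2(x):
    return np.sum(x**2)   # same bowl-shaped function

# lr far too large: the iterates overshoot and diverge instead of reaching (0, 0)
x_big, _ = gradient_descent(f2, init_x=np.array([-3.0, 4.0]), lr=10.0, step_num=100)
print(x_big)     # astronomically large values

# lr far too small: after 100 steps the point has barely moved from (-3, 4)
x_small, _ = gradient_descent(f2, init_x=np.array([-3.0, 4.0]), lr=1e-10, step_num=100)
print(x_small)   # still essentially [-3.  4.]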
# coding: utf-8
import numpy as np
import matplotlib.pylab as plt
#from gradient_2d import numerical_gradient
%matplotlib inline
def function_2(x):
    return x[0]**2 + x[1]**2
init_x = np.array([-3.0, 4.0])
lr = 0.1
step_num = 20
x, x_history = gradient_descent(function_2, init_x, lr=lr, step_num=step_num)
plt.plot([-5, 5], [0, 0], '--b')
plt.plot([0, 0], [-5, 5], '--b')
plt.plot(x_history[:, 0], x_history[:, 1], 'o')
plt.xlim(-3.5, 3.5)
plt.ylim(-4.5, 4.5)
plt.xlabel("X0")
plt.ylabel("X1")
plt.show()
import numpy as np
class simpleNet:
    def __init__(self):
        self.W = np.random.randn(2, 3)

    def predict(self, x):
        return np.dot(x, self.W)

    def loss(self, x, t):
        z = self.predict(x)
        y = softmax(z)
        loss = cross_entropy_error(y, t)
        return loss
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # make modules in the parent directory importable
import numpy as np
from mymodule.functions import softmax, cross_entropy_error
from mymodule.gradient import numerical_gradient
x = np.array([0.6, 0.9])
t = np.array([0, 0, 1])
net = simpleNet()
f = lambda w: net.loss(x, t)
dW = numerical_gradient(f, net.W)
print(dW)
[[ 0.31623088  0.04891163 -0.36514251]
 [ 0.47434632  0.07336744 -0.54771376]]
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # make modules in the parent directory importable
from mymodule.functions import *
from mymodule.gradient import numerical_gradient

class TwoLayerNet:
    def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):
        # Initialize the weights
        self.params = {}
        self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params['b1'] = np.zeros(hidden_size)
        self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)
        self.params['b2'] = np.zeros(output_size)

    def predict(self, x):
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)
        return y

    # x: input data, t: target labels
    def loss(self, x, t):
        y = self.predict(x)
        return cross_entropy_error(y, t)

    def accuracy(self, x, t):
        y = self.predict(x)
        y = np.argmax(y, axis=1)
        t = np.argmax(t, axis=1)
        accuracy = np.sum(y == t) / float(x.shape[0])
        return accuracy

    # x: input data, t: target labels
    def numerical_gradient(self, x, t):
        loss_W = lambda W: self.loss(x, t)
        grads = {}
        grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
        grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
        grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
        grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
        return grads

    def gradient(self, x, t):
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
        grads = {}
        batch_num = x.shape[0]

        # forward
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)

        # backward
        dy = (y - t) / batch_num
        grads['W2'] = np.dot(z1.T, dy)
        grads['b2'] = np.sum(dy, axis=0)
        da1 = np.dot(dy, W2.T)
        dz1 = sigmoid_grad(a1) * da1
        grads['W1'] = np.dot(x.T, dz1)
        grads['b1'] = np.sum(dz1, axis=0)
        return grads
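Since the class offers both a numerical and a backprop gradient, a natural sanity check (my addition, a sketch assuming sigmoid, softmax, sigmoid_grad, and cross_entropy_error from mymodule.functions behave as defined earlier) is to compare them on a tiny batch; the two results should agree up to numerical-precision noise:

# Hypothetical gradient check on a tiny random batch (slow: numerical_gradient
# re-evaluates the loss twice per parameter)
net_check = TwoLayerNet(input_size=784, hidden_size=10, output_size=10)
x_small = np.random.rand(3, 784)
t_small = np.zeros((3, 10))
t_small[np.arange(3), [1, 4, 7]] = 1          # one-hot targets
g_num = net_check.numerical_gradient(x_small, t_small)
g_bp = net_check.gradient(x_small, t_small)
for key in ('W1', 'b1', 'W2', 'b2'):
    print(key, np.mean(np.abs(g_num[key] - g_bp[key])))  # expect very small values, e.g. < 1e-8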
net = TwoLayerNet(input_size=784, hidden_size=100, output_size=10)
x = np.random.rand(100, 784)
y = net.predict(x)
print(y)
[[ 0.10367082  0.09479027  0.09646175  0.0976099   0.09643014  0.10534564  0.11351447  0.09714606  0.09772518  0.09730575]
 [ 0.10370514  0.09551623  0.09629093  0.09727947  0.09620149  0.10544123  0.11376555  0.09732914  0.09752802  0.09694279]
 ...
 [ 0.10338018  0.09507883  0.09649727  0.09723977  0.0963523   0.10577775  0.1135283   0.09720346  0.09767714  0.097265  ]]
(100 rows of 10 softmax scores each; every row is close to uniform 0.1 because the weights are small and random. Remaining rows omitted.)
x = np.random.rand(100, 784)
t = np.random.rand(100, 10)
grads = net.numerical_gradient(x, t)
print(grads)
{'W1': array([[  1.61360778e-04,   1.20043919e-04,   6.89253987e-05, ...,   3.23179750e-05,  -3.51339979e-05,   1.25530821e-04],
       ...,
       [  1.76003505e-04,   5.01910957e-05,   1.11402585e-04, ...,  -4.23339319e-05,  -5.50190982e-05,   1.33508946e-04]]),
 'b1': array([  2.24106820e-04,   3.18442770e-04,   8.50670867e-05, ...,  -5.56448909e-05,   2.72366791e-04]),
 'W2': array([[ 0.00737691, -0.00234876, -0.00712771, ...,  -0.02264042,  0.00969058,  0.01337871],
       ...,
       [ 0.00682311, -0.00234542, -0.00762326, ...,  -0.02010301,  0.00792418,  0.01373957]]),
 'b2': array([ 0.01346497, -0.00513228, -0.01366098, -0.02247379, -0.03373059,  0.0258257 ,  0.0335714 , -0.04273316,  0.01776345,  0.02710529])}
(a gradient for every parameter; the long W1, b1, and W2 arrays are truncated here)
Slow version using numerical_gradient
!date
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # make modules in the parent directory importable
import numpy as np
import matplotlib.pyplot as plt
from mymodule.mnist import load_mnist
#from two_layer_net import TwoLayerNet
%matplotlib inline
from datetime import datetime

# Load the data
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
train_loss_list = []

# Hyperparameters
iters_num = 30
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)

for i in range(iters_num):
    # Sample a mini-batch
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    # Compute the gradients
    grad = network.numerical_gradient(x_batch, t_batch)
    #grad = network.gradient(x_batch, t_batch) # fast version!
    # Update the parameters
    for key in ('W1', 'b1', 'W2', 'b2'):
        network.params[key] -= learning_rate * grad[key]
    # Record the training progress
    loss = network.loss(x_batch, t_batch)
    if i % 10 == 0: print(i, loss, datetime.now().strftime("%Y/%m/%d %H:%M:%S"))
    train_loss_list.append(loss)
!date
Fri Feb 3 13:39:00 JST 2017
0 2.29433450806 2017/02/03 13:39:38
10 2.30151019526 2017/02/03 13:47:42
20 2.29443076567 2017/02/03 13:58:31
Fri Feb 3 14:07:49 JST 2017
%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(train_loss_list)
plt.show()
Fast version using gradient (backpropagation)
!date
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # make modules in the parent directory importable
import numpy as np
import matplotlib.pyplot as plt
from mymodule.mnist import load_mnist
#from two_layer_net import TwoLayerNet
%matplotlib inline
from datetime import datetime

# Load the data
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
train_loss_list = []

# Hyperparameters
iters_num = 10000
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)

for i in range(iters_num):
    # Sample a mini-batch
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    # Compute the gradients
    #grad = network.numerical_gradient(x_batch, t_batch)
    grad = network.gradient(x_batch, t_batch)  # fast version!
    # Update the parameters
    for key in ('W1', 'b1', 'W2', 'b2'):
        network.params[key] -= learning_rate * grad[key]
    # Record the training progress
    loss = network.loss(x_batch, t_batch)
    if i % 1000 == 0: print(i, loss, datetime.now().strftime("%Y/%m/%d %H:%M:%S"))
    train_loss_list.append(loss)
!date
Fri Feb 3 14:11:27 JST 2017
0 2.29596964637 2017/02/03 14:11:28
1000 0.590789722964 2017/02/03 14:11:30
2000 0.286607804041 2017/02/03 14:11:32
3000 0.298551722558 2017/02/03 14:11:35
4000 0.289130595618 2017/02/03 14:11:39
5000 0.226502528465 2017/02/03 14:11:42
6000 0.340328112376 2017/02/03 14:11:44
7000 0.248002214471 2017/02/03 14:11:47
8000 0.158586735178 2017/02/03 14:11:50
9000 0.212563041588 2017/02/03 14:11:53
Fri Feb 3 14:11:55 JST 2017
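Reading the timestamps: the numerical-gradient run above took roughly 29 minutes for only 30 updates (about a minute per iteration), while this backprop version finished all 10,000 updates in about 28 seconds. That gap is why the analytic gradient is the only practical option for real training.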
%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(train_loss_list)
plt.show()
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # make modules in the parent directory importable
import numpy as np
import matplotlib.pyplot as plt
from mymodule.mnist import load_mnist
#from two_layer_net import TwoLayerNet
%matplotlib inline

# Load the data
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)

iters_num = 10000  # adjust the number of iterations as needed
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1
train_loss_list = []
train_acc_list = []
test_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)

for i in range(iters_num):
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    # Compute the gradients
    #grad = network.numerical_gradient(x_batch, t_batch)
    grad = network.gradient(x_batch, t_batch)
    # Update the parameters
    for key in ('W1', 'b1', 'W2', 'b2'):
        network.params[key] -= learning_rate * grad[key]
    loss = network.loss(x_batch, t_batch)
    train_loss_list.append(loss)
    # Evaluate once per epoch
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))

# Draw the graph
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()
train acc, test acc | 0.112366666667, 0.1135
train acc, test acc | 0.778966666667, 0.7822
train acc, test acc | 0.874666666667, 0.8781
train acc, test acc | 0.8979, 0.9006
train acc, test acc | 0.90875, 0.912
train acc, test acc | 0.91475, 0.9173
train acc, test acc | 0.919766666667, 0.9215
train acc, test acc | 0.923966666667, 0.9252
train acc, test acc | 0.92745, 0.929
train acc, test acc | 0.931683333333, 0.9332
train acc, test acc | 0.934216666667, 0.9342
train acc, test acc | 0.936783333333, 0.9356
train acc, test acc | 0.93955, 0.9385
train acc, test acc | 0.9422, 0.9388
train acc, test acc | 0.944066666667, 0.9412
train acc, test acc | 0.945466666667, 0.9421
train acc, test acc | 0.947266666667, 0.9446