ニューラルネットワークで、スパイラルデータセットのクラス分けを行う。隠れ層がひとつのニューラルネットワークを使用する。
class TwoLayerNet:
    """Fully-connected network with a single hidden layer.

    Architecture: Affine -> Sigmoid -> Affine, with a softmax-with-loss
    layer applied only during training. `params` and `grads` expose all
    weights/gradients as flat lists so an external optimizer (e.g. SGD)
    can update them uniformly.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Build the layers and collect their parameters.

        input_size / hidden_size / output_size: layer widths (ints).
        """
        I, H, O = input_size, hidden_size, output_size

        # Initialize weights with small Gaussian noise and biases with zeros.
        W1 = 0.01 * np.random.randn(I, H)
        b1 = np.zeros(H)
        W2 = 0.01 * np.random.randn(H, O)
        b2 = np.zeros(O)

        # Create the layer pipeline.
        self.layers = [
            Affine(W1, b1),
            Sigmoid(),
            Affine(W2, b2)
        ]
        self.loss_layer = SoftmaxWithLoss()

        # Gather every layer's parameters and gradients into flat lists.
        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads

    def predict(self, x):
        """Return the raw class scores (pre-softmax) for input batch x."""
        for layer in self.layers:
            x = layer.forward(x)
        return x

    def forward(self, x, t):
        """Return the loss for input batch x against targets t."""
        score = self.predict(x)
        loss = self.loss_layer.forward(score, t)
        return loss

    def backward(self, dout=1):
        """Backpropagate through all layers; fills self.grads in place.

        Returns the gradient with respect to the network input.
        """
        dout = self.loss_layer.backward(dout)
        for layer in reversed(self.layers):
            dout = layer.backward(dout)
        return dout
使用するデータはスパイラルデータセット。単純な関数では区分けが難しい。
import matplotlib.pyplot as plt

# Load the spiral dataset: x holds 2-D points, t the labels.
x, t = spiral.load_data()

# Fix: N was previously used without being defined (NameError).
N = 100        # samples per class — assumed from dataset size; TODO confirm against dataset/spiral.py
CLS_NUM = 3    # number of classes (spiral arms)
markers = ['o', 'x', '^']

# Samples are stored class-by-class, so each contiguous slice of N points
# belongs to one class; plot each class with its own marker.
for i in range(CLS_NUM):
    plt.scatter(x[i*N:(i+1)*N, 0], x[i*N:(i+1)*N, 1], s=40, marker=markers[i])
plt.show()
# coding: utf-8
import sys
sys.path.append('..')  # allow importing modules from the parent directory
import numpy as np
from common.optimizer import SGD
from dataset import spiral
import matplotlib.pyplot as plt
from two_layer_net import TwoLayerNet

# Hyperparameters
max_epoch = 300
batch_size = 30
hidden_size = 10
learning_rate = 1.0

x, t = spiral.load_data()
model = TwoLayerNet(input_size=2, hidden_size=hidden_size, output_size=3)
optimizer = SGD(lr=learning_rate)

# Bookkeeping for the training loop
data_size = len(x)
max_iters = data_size // batch_size  # full mini-batches per epoch
total_loss = 0
loss_count = 0
loss_list = []

for epoch in range(max_epoch):
    # Shuffle the data once per epoch so mini-batches differ across epochs.
    idx = np.random.permutation(data_size)
    x = x[idx]
    t = t[idx]

    for iters in range(max_iters):
        batch_x = x[iters*batch_size:(iters+1)*batch_size]
        batch_t = t[iters*batch_size:(iters+1)*batch_size]

        # Compute gradients and update parameters.
        loss = model.forward(batch_x, batch_t)
        model.backward()
        optimizer.update(model.params, model.grads)

        total_loss += loss
        loss_count += 1

        # Periodically report the running-average loss (every 10 iterations).
        if (iters + 1) % 10 == 0:
            avg_loss = total_loss / loss_count
            print('| epoch %d | iter %d / %d | loss %.2f'
                  % (epoch + 1, iters + 1, max_iters, avg_loss))
            loss_list.append(avg_loss)
            total_loss, loss_count = 0, 0

# Plot the learning curve.
plt.plot(np.arange(len(loss_list)), loss_list, label='train')
plt.xlabel('iterations (x10)')
plt.ylabel('loss')
plt.show()

# Plot the decision boundary: classify a dense grid covering the data range.
h = 0.001
x_min, x_max = x[:, 0].min() - .1, x[:, 0].max() + .1
y_min, y_max = x[:, 1].min() - .1, x[:, 1].max() + .1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
X = np.c_[xx.ravel(), yy.ravel()]
score = model.predict(X)
predict_cls = np.argmax(score, axis=1)
Z = predict_cls.reshape(xx.shape)
plt.contourf(xx, yy, Z)
plt.axis('off')

# Overlay the original (unshuffled) data points on the boundary plot.
x, t = spiral.load_data()
N = 100  # samples per class — TODO confirm against dataset/spiral.py
CLS_NUM = 3
markers = ['o', 'x', '^']
for i in range(CLS_NUM):
    plt.scatter(x[i*N:(i+1)*N, 0], x[i*N:(i+1)*N, 1], s=40, marker=markers[i])
plt.show()
| epoch 1 | iter 10 / 10 | loss 1.13 | epoch 2 | iter 10 / 10 | loss 1.13 | epoch 3 | iter 10 / 10 | loss 1.12 | epoch 4 | iter 10 / 10 | loss 1.12 | epoch 5 | iter 10 / 10 | loss 1.11 | epoch 6 | iter 10 / 10 | loss 1.14 | epoch 7 | iter 10 / 10 | loss 1.16 | epoch 8 | iter 10 / 10 | loss 1.11 | epoch 9 | iter 10 / 10 | loss 1.12 | epoch 10 | iter 10 / 10 | loss 1.13 | epoch 11 | iter 10 / 10 | loss 1.12 | epoch 12 | iter 10 / 10 | loss 1.11 | epoch 13 | iter 10 / 10 | loss 1.09 | epoch 14 | iter 10 / 10 | loss 1.08 | epoch 15 | iter 10 / 10 | loss 1.04 | epoch 16 | iter 10 / 10 | loss 1.03 | epoch 17 | iter 10 / 10 | loss 0.96 | epoch 18 | iter 10 / 10 | loss 0.92 | epoch 19 | iter 10 / 10 | loss 0.92 | epoch 20 | iter 10 / 10 | loss 0.87 | epoch 21 | iter 10 / 10 | loss 0.85 | epoch 22 | iter 10 / 10 | loss 0.82 | epoch 23 | iter 10 / 10 | loss 0.79 | epoch 24 | iter 10 / 10 | loss 0.78 | epoch 25 | iter 10 / 10 | loss 0.82 | epoch 26 | iter 10 / 10 | loss 0.78 | epoch 27 | iter 10 / 10 | loss 0.76 | epoch 28 | iter 10 / 10 | loss 0.76 | epoch 29 | iter 10 / 10 | loss 0.78 | epoch 30 | iter 10 / 10 | loss 0.75 | epoch 31 | iter 10 / 10 | loss 0.78 | epoch 32 | iter 10 / 10 | loss 0.77 | epoch 33 | iter 10 / 10 | loss 0.77 | epoch 34 | iter 10 / 10 | loss 0.78 | epoch 35 | iter 10 / 10 | loss 0.75 | epoch 36 | iter 10 / 10 | loss 0.74 | epoch 37 | iter 10 / 10 | loss 0.76 | epoch 38 | iter 10 / 10 | loss 0.76 | epoch 39 | iter 10 / 10 | loss 0.73 | epoch 40 | iter 10 / 10 | loss 0.75 | epoch 41 | iter 10 / 10 | loss 0.76 | epoch 42 | iter 10 / 10 | loss 0.76 | epoch 43 | iter 10 / 10 | loss 0.76 | epoch 44 | iter 10 / 10 | loss 0.74 | epoch 45 | iter 10 / 10 | loss 0.75 | epoch 46 | iter 10 / 10 | loss 0.73 | epoch 47 | iter 10 / 10 | loss 0.72 | epoch 48 | iter 10 / 10 | loss 0.73 | epoch 49 | iter 10 / 10 | loss 0.72 | epoch 50 | iter 10 / 10 | loss 0.72 | epoch 51 | iter 10 / 10 | loss 0.72 | epoch 52 | iter 10 / 10 | loss 0.72 | epoch 53 | iter 10 / 10 | loss 
0.74 | epoch 54 | iter 10 / 10 | loss 0.74 | epoch 55 | iter 10 / 10 | loss 0.72 | epoch 56 | iter 10 / 10 | loss 0.72 | epoch 57 | iter 10 / 10 | loss 0.71 | epoch 58 | iter 10 / 10 | loss 0.70 | epoch 59 | iter 10 / 10 | loss 0.72 | epoch 60 | iter 10 / 10 | loss 0.70 | epoch 61 | iter 10 / 10 | loss 0.71 | epoch 62 | iter 10 / 10 | loss 0.72 | epoch 63 | iter 10 / 10 | loss 0.70 | epoch 64 | iter 10 / 10 | loss 0.71 | epoch 65 | iter 10 / 10 | loss 0.73 | epoch 66 | iter 10 / 10 | loss 0.70 | epoch 67 | iter 10 / 10 | loss 0.71 | epoch 68 | iter 10 / 10 | loss 0.69 | epoch 69 | iter 10 / 10 | loss 0.70 | epoch 70 | iter 10 / 10 | loss 0.71 | epoch 71 | iter 10 / 10 | loss 0.68 | epoch 72 | iter 10 / 10 | loss 0.69 | epoch 73 | iter 10 / 10 | loss 0.67 | epoch 74 | iter 10 / 10 | loss 0.68 | epoch 75 | iter 10 / 10 | loss 0.67 | epoch 76 | iter 10 / 10 | loss 0.66 | epoch 77 | iter 10 / 10 | loss 0.69 | epoch 78 | iter 10 / 10 | loss 0.64 | epoch 79 | iter 10 / 10 | loss 0.68 | epoch 80 | iter 10 / 10 | loss 0.64 | epoch 81 | iter 10 / 10 | loss 0.64 | epoch 82 | iter 10 / 10 | loss 0.66 | epoch 83 | iter 10 / 10 | loss 0.62 | epoch 84 | iter 10 / 10 | loss 0.62 | epoch 85 | iter 10 / 10 | loss 0.61 | epoch 86 | iter 10 / 10 | loss 0.60 | epoch 87 | iter 10 / 10 | loss 0.60 | epoch 88 | iter 10 / 10 | loss 0.61 | epoch 89 | iter 10 / 10 | loss 0.59 | epoch 90 | iter 10 / 10 | loss 0.58 | epoch 91 | iter 10 / 10 | loss 0.56 | epoch 92 | iter 10 / 10 | loss 0.56 | epoch 93 | iter 10 / 10 | loss 0.54 | epoch 94 | iter 10 / 10 | loss 0.53 | epoch 95 | iter 10 / 10 | loss 0.53 | epoch 96 | iter 10 / 10 | loss 0.52 | epoch 97 | iter 10 / 10 | loss 0.51 | epoch 98 | iter 10 / 10 | loss 0.50 | epoch 99 | iter 10 / 10 | loss 0.48 | epoch 100 | iter 10 / 10 | loss 0.48 | epoch 101 | iter 10 / 10 | loss 0.46 | epoch 102 | iter 10 / 10 | loss 0.45 | epoch 103 | iter 10 / 10 | loss 0.45 | epoch 104 | iter 10 / 10 | loss 0.44 | epoch 105 | iter 10 / 10 | loss 0.44 | epoch 106 
| iter 10 / 10 | loss 0.41 | epoch 107 | iter 10 / 10 | loss 0.40 | epoch 108 | iter 10 / 10 | loss 0.41 | epoch 109 | iter 10 / 10 | loss 0.40 | epoch 110 | iter 10 / 10 | loss 0.40 | epoch 111 | iter 10 / 10 | loss 0.38 | epoch 112 | iter 10 / 10 | loss 0.38 | epoch 113 | iter 10 / 10 | loss 0.36 | epoch 114 | iter 10 / 10 | loss 0.37 | epoch 115 | iter 10 / 10 | loss 0.35 | epoch 116 | iter 10 / 10 | loss 0.34 | epoch 117 | iter 10 / 10 | loss 0.34 | epoch 118 | iter 10 / 10 | loss 0.34 | epoch 119 | iter 10 / 10 | loss 0.33 | epoch 120 | iter 10 / 10 | loss 0.34 | epoch 121 | iter 10 / 10 | loss 0.32 | epoch 122 | iter 10 / 10 | loss 0.32 | epoch 123 | iter 10 / 10 | loss 0.31 | epoch 124 | iter 10 / 10 | loss 0.31 | epoch 125 | iter 10 / 10 | loss 0.30 | epoch 126 | iter 10 / 10 | loss 0.30 | epoch 127 | iter 10 / 10 | loss 0.28 | epoch 128 | iter 10 / 10 | loss 0.28 | epoch 129 | iter 10 / 10 | loss 0.28 | epoch 130 | iter 10 / 10 | loss 0.28 | epoch 131 | iter 10 / 10 | loss 0.27 | epoch 132 | iter 10 / 10 | loss 0.27 | epoch 133 | iter 10 / 10 | loss 0.27 | epoch 134 | iter 10 / 10 | loss 0.27 | epoch 135 | iter 10 / 10 | loss 0.27 | epoch 136 | iter 10 / 10 | loss 0.26 | epoch 137 | iter 10 / 10 | loss 0.26 | epoch 138 | iter 10 / 10 | loss 0.26 | epoch 139 | iter 10 / 10 | loss 0.25 | epoch 140 | iter 10 / 10 | loss 0.24 | epoch 141 | iter 10 / 10 | loss 0.24 | epoch 142 | iter 10 / 10 | loss 0.25 | epoch 143 | iter 10 / 10 | loss 0.24 | epoch 144 | iter 10 / 10 | loss 0.24 | epoch 145 | iter 10 / 10 | loss 0.23 | epoch 146 | iter 10 / 10 | loss 0.24 | epoch 147 | iter 10 / 10 | loss 0.23 | epoch 148 | iter 10 / 10 | loss 0.23 | epoch 149 | iter 10 / 10 | loss 0.22 | epoch 150 | iter 10 / 10 | loss 0.22 | epoch 151 | iter 10 / 10 | loss 0.22 | epoch 152 | iter 10 / 10 | loss 0.22 | epoch 153 | iter 10 / 10 | loss 0.22 | epoch 154 | iter 10 / 10 | loss 0.22 | epoch 155 | iter 10 / 10 | loss 0.22 | epoch 156 | iter 10 / 10 | loss 0.21 | epoch 157 | iter 10 
/ 10 | loss 0.21 | epoch 158 | iter 10 / 10 | loss 0.20 | epoch 159 | iter 10 / 10 | loss 0.21 | epoch 160 | iter 10 / 10 | loss 0.20 | epoch 161 | iter 10 / 10 | loss 0.20 | epoch 162 | iter 10 / 10 | loss 0.20 | epoch 163 | iter 10 / 10 | loss 0.21 | epoch 164 | iter 10 / 10 | loss 0.20 | epoch 165 | iter 10 / 10 | loss 0.20 | epoch 166 | iter 10 / 10 | loss 0.19 | epoch 167 | iter 10 / 10 | loss 0.19 | epoch 168 | iter 10 / 10 | loss 0.19 | epoch 169 | iter 10 / 10 | loss 0.19 | epoch 170 | iter 10 / 10 | loss 0.19 | epoch 171 | iter 10 / 10 | loss 0.19 | epoch 172 | iter 10 / 10 | loss 0.18 | epoch 173 | iter 10 / 10 | loss 0.18 | epoch 174 | iter 10 / 10 | loss 0.18 | epoch 175 | iter 10 / 10 | loss 0.18 | epoch 176 | iter 10 / 10 | loss 0.18 | epoch 177 | iter 10 / 10 | loss 0.18 | epoch 178 | iter 10 / 10 | loss 0.18 | epoch 179 | iter 10 / 10 | loss 0.17 | epoch 180 | iter 10 / 10 | loss 0.17 | epoch 181 | iter 10 / 10 | loss 0.18 | epoch 182 | iter 10 / 10 | loss 0.17 | epoch 183 | iter 10 / 10 | loss 0.18 | epoch 184 | iter 10 / 10 | loss 0.17 | epoch 185 | iter 10 / 10 | loss 0.17 | epoch 186 | iter 10 / 10 | loss 0.18 | epoch 187 | iter 10 / 10 | loss 0.17 | epoch 188 | iter 10 / 10 | loss 0.17 | epoch 189 | iter 10 / 10 | loss 0.17 | epoch 190 | iter 10 / 10 | loss 0.17 | epoch 191 | iter 10 / 10 | loss 0.16 | epoch 192 | iter 10 / 10 | loss 0.17 | epoch 193 | iter 10 / 10 | loss 0.16 | epoch 194 | iter 10 / 10 | loss 0.16 | epoch 195 | iter 10 / 10 | loss 0.16 | epoch 196 | iter 10 / 10 | loss 0.16 | epoch 197 | iter 10 / 10 | loss 0.16 | epoch 198 | iter 10 / 10 | loss 0.15 | epoch 199 | iter 10 / 10 | loss 0.16 | epoch 200 | iter 10 / 10 | loss 0.16 | epoch 201 | iter 10 / 10 | loss 0.15 | epoch 202 | iter 10 / 10 | loss 0.16 | epoch 203 | iter 10 / 10 | loss 0.16 | epoch 204 | iter 10 / 10 | loss 0.15 | epoch 205 | iter 10 / 10 | loss 0.16 | epoch 206 | iter 10 / 10 | loss 0.15 | epoch 207 | iter 10 / 10 | loss 0.15 | epoch 208 | iter 10 / 10 | 
loss 0.15 | epoch 209 | iter 10 / 10 | loss 0.15 | epoch 210 | iter 10 / 10 | loss 0.15 | epoch 211 | iter 10 / 10 | loss 0.15 | epoch 212 | iter 10 / 10 | loss 0.15 | epoch 213 | iter 10 / 10 | loss 0.15 | epoch 214 | iter 10 / 10 | loss 0.15 | epoch 215 | iter 10 / 10 | loss 0.15 | epoch 216 | iter 10 / 10 | loss 0.14 | epoch 217 | iter 10 / 10 | loss 0.14 | epoch 218 | iter 10 / 10 | loss 0.15 | epoch 219 | iter 10 / 10 | loss 0.14 | epoch 220 | iter 10 / 10 | loss 0.14 | epoch 221 | iter 10 / 10 | loss 0.14 | epoch 222 | iter 10 / 10 | loss 0.14 | epoch 223 | iter 10 / 10 | loss 0.14 | epoch 224 | iter 10 / 10 | loss 0.14 | epoch 225 | iter 10 / 10 | loss 0.14 | epoch 226 | iter 10 / 10 | loss 0.14 | epoch 227 | iter 10 / 10 | loss 0.14 | epoch 228 | iter 10 / 10 | loss 0.14 | epoch 229 | iter 10 / 10 | loss 0.13 | epoch 230 | iter 10 / 10 | loss 0.14 | epoch 231 | iter 10 / 10 | loss 0.13 | epoch 232 | iter 10 / 10 | loss 0.14 | epoch 233 | iter 10 / 10 | loss 0.13 | epoch 234 | iter 10 / 10 | loss 0.13 | epoch 235 | iter 10 / 10 | loss 0.13 | epoch 236 | iter 10 / 10 | loss 0.13 | epoch 237 | iter 10 / 10 | loss 0.14 | epoch 238 | iter 10 / 10 | loss 0.13 | epoch 239 | iter 10 / 10 | loss 0.13 | epoch 240 | iter 10 / 10 | loss 0.14 | epoch 241 | iter 10 / 10 | loss 0.13 | epoch 242 | iter 10 / 10 | loss 0.13 | epoch 243 | iter 10 / 10 | loss 0.13 | epoch 244 | iter 10 / 10 | loss 0.13 | epoch 245 | iter 10 / 10 | loss 0.13 | epoch 246 | iter 10 / 10 | loss 0.13 | epoch 247 | iter 10 / 10 | loss 0.13 | epoch 248 | iter 10 / 10 | loss 0.13 | epoch 249 | iter 10 / 10 | loss 0.13 | epoch 250 | iter 10 / 10 | loss 0.13 | epoch 251 | iter 10 / 10 | loss 0.13 | epoch 252 | iter 10 / 10 | loss 0.12 | epoch 253 | iter 10 / 10 | loss 0.12 | epoch 254 | iter 10 / 10 | loss 0.12 | epoch 255 | iter 10 / 10 | loss 0.12 | epoch 256 | iter 10 / 10 | loss 0.12 | epoch 257 | iter 10 / 10 | loss 0.12 | epoch 258 | iter 10 / 10 | loss 0.12 | epoch 259 | iter 10 / 10 | loss 0.13 
| epoch 260 | iter 10 / 10 | loss 0.12 | epoch 261 | iter 10 / 10 | loss 0.13 | epoch 262 | iter 10 / 10 | loss 0.12 | epoch 263 | iter 10 / 10 | loss 0.12 | epoch 264 | iter 10 / 10 | loss 0.13 | epoch 265 | iter 10 / 10 | loss 0.12 | epoch 266 | iter 10 / 10 | loss 0.12 | epoch 267 | iter 10 / 10 | loss 0.12 | epoch 268 | iter 10 / 10 | loss 0.12 | epoch 269 | iter 10 / 10 | loss 0.11 | epoch 270 | iter 10 / 10 | loss 0.12 | epoch 271 | iter 10 / 10 | loss 0.12 | epoch 272 | iter 10 / 10 | loss 0.12 | epoch 273 | iter 10 / 10 | loss 0.12 | epoch 274 | iter 10 / 10 | loss 0.12 | epoch 275 | iter 10 / 10 | loss 0.11 | epoch 276 | iter 10 / 10 | loss 0.12 | epoch 277 | iter 10 / 10 | loss 0.12 | epoch 278 | iter 10 / 10 | loss 0.11 | epoch 279 | iter 10 / 10 | loss 0.11 | epoch 280 | iter 10 / 10 | loss 0.11 | epoch 281 | iter 10 / 10 | loss 0.11 | epoch 282 | iter 10 / 10 | loss 0.12 | epoch 283 | iter 10 / 10 | loss 0.11 | epoch 284 | iter 10 / 10 | loss 0.11 | epoch 285 | iter 10 / 10 | loss 0.11 | epoch 286 | iter 10 / 10 | loss 0.11 | epoch 287 | iter 10 / 10 | loss 0.11 | epoch 288 | iter 10 / 10 | loss 0.12 | epoch 289 | iter 10 / 10 | loss 0.11 | epoch 290 | iter 10 / 10 | loss 0.11 | epoch 291 | iter 10 / 10 | loss 0.11 | epoch 292 | iter 10 / 10 | loss 0.11 | epoch 293 | iter 10 / 10 | loss 0.11 | epoch 294 | iter 10 / 10 | loss 0.11 | epoch 295 | iter 10 / 10 | loss 0.12 | epoch 296 | iter 10 / 10 | loss 0.11 | epoch 297 | iter 10 / 10 | loss 0.12 | epoch 298 | iter 10 / 10 | loss 0.11 | epoch 299 | iter 10 / 10 | loss 0.11 | epoch 300 | iter 10 / 10 | loss 0.11