ニューラルネットワークで問題を解く

ニューラルネットワークで、スパイラルデータセットのクラス分けを行う。隠れ層がひとつのニューラルネットワークを使用する。

class TwoLayerNet:
    """Fully-connected classifier: Affine -> Sigmoid -> Affine, trained
    with a SoftmaxWithLoss layer appended during learning."""

    def __init__(self, input_size, hidden_size, output_size):
        # Small random weights break symmetry; biases start at zero.
        W1 = np.random.randn(input_size, hidden_size) * 0.01
        b1 = np.zeros(hidden_size)
        W2 = np.random.randn(hidden_size, output_size) * 0.01
        b2 = np.zeros(output_size)

        # Layer pipeline (loss layer kept separate so predict() skips it).
        self.layers = [Affine(W1, b1), Sigmoid(), Affine(W2, b2)]
        self.loss_layer = SoftmaxWithLoss()

        # Flatten every layer's parameters and gradients into shared
        # lists so an optimizer can update them uniformly.
        self.params = []
        self.grads = []
        for layer in self.layers:
            self.params.extend(layer.params)
            self.grads.extend(layer.grads)

    def predict(self, x):
        """Forward pass through all layers; returns raw scores (no softmax)."""
        out = x
        for layer in self.layers:
            out = layer.forward(out)
        return out

    def forward(self, x, t):
        """Return the softmax cross-entropy loss of inputs x against targets t."""
        return self.loss_layer.forward(self.predict(x), t)

    def backward(self, dout=1):
        """Backpropagate from the loss layer through every layer in reverse;
        gradients are accumulated into self.grads via the layer objects."""
        grad = self.loss_layer.backward(dout)
        for layer in self.layers[::-1]:
            grad = layer.backward(grad)
        return grad

使用するデータはスパイラルデータセット。単純な関数では区分けが難しい。

In [10]:
import matplotlib.pyplot as plt

# NOTE(review): this cell assumes `spiral` was imported earlier in the
# notebook session (it is imported explicitly in the next cell).
x, t = spiral.load_data()
N = 100       # samples per class -- was undefined here, causing a NameError
CLS_NUM = 3   # number of spiral arms / classes
markers = ['o', 'x', '^']

# One scatter call per class: rows are stored class-by-class in blocks of N.
for i in range(CLS_NUM):
    plt.scatter(x[i*N:(i+1)*N, 0], x[i*N:(i+1)*N, 1], s=40, marker=markers[i])
plt.show()
In [11]:
# coding: utf-8
import sys
sys.path.append('..')  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
from common.optimizer import SGD
from dataset import spiral
import matplotlib.pyplot as plt
from two_layer_net import TwoLayerNet


# ハイパーパラメータの設定
max_epoch = 300
batch_size = 30
hidden_size = 10
learning_rate = 1.0

x, t = spiral.load_data()
model = TwoLayerNet(input_size=2, hidden_size=hidden_size, output_size=3)
optimizer = SGD(lr=learning_rate)

# 学習で使用する変数
data_size = len(x)
max_iters = data_size // batch_size
total_loss = 0
loss_count = 0
loss_list = []

for epoch in range(max_epoch):
    # データのシャッフル
    idx = np.random.permutation(data_size)
    x = x[idx]
    t = t[idx]

    for iters in range(max_iters):
        batch_x = x[iters*batch_size:(iters+1)*batch_size]
        batch_t = t[iters*batch_size:(iters+1)*batch_size]

        # 勾配を求め、パラメータを更新
        loss = model.forward(batch_x, batch_t)
        model.backward()
        optimizer.update(model.params, model.grads)

        total_loss += loss
        loss_count += 1

        # 定期的に学習経過を出力
        if (iters+1) % 10 == 0:
            avg_loss = total_loss / loss_count
            print('| epoch %d |  iter %d / %d | loss %.2f'
                  % (epoch + 1, iters + 1, max_iters, avg_loss))
            loss_list.append(avg_loss)
            total_loss, loss_count = 0, 0


# 学習結果のプロット
plt.plot(np.arange(len(loss_list)), loss_list, label='train')
plt.xlabel('iterations (x10)')
plt.ylabel('loss')
plt.show()

# Plot the learned decision regions over a dense grid covering the data.
h = 0.001  # grid step
x_min, x_max = x[:, 0].min() - .1, x[:, 0].max() + .1
y_min, y_max = x[:, 1].min() - .1, x[:, 1].max() + .1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
X = np.c_[xx.ravel(), yy.ravel()]
# Classify every grid point and color the plane by predicted class.
Z = np.argmax(model.predict(X), axis=1).reshape(xx.shape)
plt.contourf(xx, yy, Z)
plt.axis('off')

# Overlay the (freshly reloaded, unshuffled) data points.
x, t = spiral.load_data()
N = 100       # samples per class
CLS_NUM = 3   # number of classes
markers = ['o', 'x', '^']
for i in range(CLS_NUM):
    plt.scatter(x[i*N:(i+1)*N, 0], x[i*N:(i+1)*N, 1], s=40, marker=markers[i])
plt.show()
| epoch 1 |  iter 10 / 10 | loss 1.13
| epoch 2 |  iter 10 / 10 | loss 1.13
| epoch 3 |  iter 10 / 10 | loss 1.12
| epoch 4 |  iter 10 / 10 | loss 1.12
| epoch 5 |  iter 10 / 10 | loss 1.11
| epoch 6 |  iter 10 / 10 | loss 1.14
| epoch 7 |  iter 10 / 10 | loss 1.16
| epoch 8 |  iter 10 / 10 | loss 1.11
| epoch 9 |  iter 10 / 10 | loss 1.12
| epoch 10 |  iter 10 / 10 | loss 1.13
| epoch 11 |  iter 10 / 10 | loss 1.12
| epoch 12 |  iter 10 / 10 | loss 1.11
| epoch 13 |  iter 10 / 10 | loss 1.09
| epoch 14 |  iter 10 / 10 | loss 1.08
| epoch 15 |  iter 10 / 10 | loss 1.04
| epoch 16 |  iter 10 / 10 | loss 1.03
| epoch 17 |  iter 10 / 10 | loss 0.96
| epoch 18 |  iter 10 / 10 | loss 0.92
| epoch 19 |  iter 10 / 10 | loss 0.92
| epoch 20 |  iter 10 / 10 | loss 0.87
| epoch 21 |  iter 10 / 10 | loss 0.85
| epoch 22 |  iter 10 / 10 | loss 0.82
| epoch 23 |  iter 10 / 10 | loss 0.79
| epoch 24 |  iter 10 / 10 | loss 0.78
| epoch 25 |  iter 10 / 10 | loss 0.82
| epoch 26 |  iter 10 / 10 | loss 0.78
| epoch 27 |  iter 10 / 10 | loss 0.76
| epoch 28 |  iter 10 / 10 | loss 0.76
| epoch 29 |  iter 10 / 10 | loss 0.78
| epoch 30 |  iter 10 / 10 | loss 0.75
| epoch 31 |  iter 10 / 10 | loss 0.78
| epoch 32 |  iter 10 / 10 | loss 0.77
| epoch 33 |  iter 10 / 10 | loss 0.77
| epoch 34 |  iter 10 / 10 | loss 0.78
| epoch 35 |  iter 10 / 10 | loss 0.75
| epoch 36 |  iter 10 / 10 | loss 0.74
| epoch 37 |  iter 10 / 10 | loss 0.76
| epoch 38 |  iter 10 / 10 | loss 0.76
| epoch 39 |  iter 10 / 10 | loss 0.73
| epoch 40 |  iter 10 / 10 | loss 0.75
| epoch 41 |  iter 10 / 10 | loss 0.76
| epoch 42 |  iter 10 / 10 | loss 0.76
| epoch 43 |  iter 10 / 10 | loss 0.76
| epoch 44 |  iter 10 / 10 | loss 0.74
| epoch 45 |  iter 10 / 10 | loss 0.75
| epoch 46 |  iter 10 / 10 | loss 0.73
| epoch 47 |  iter 10 / 10 | loss 0.72
| epoch 48 |  iter 10 / 10 | loss 0.73
| epoch 49 |  iter 10 / 10 | loss 0.72
| epoch 50 |  iter 10 / 10 | loss 0.72
| epoch 51 |  iter 10 / 10 | loss 0.72
| epoch 52 |  iter 10 / 10 | loss 0.72
| epoch 53 |  iter 10 / 10 | loss 0.74
| epoch 54 |  iter 10 / 10 | loss 0.74
| epoch 55 |  iter 10 / 10 | loss 0.72
| epoch 56 |  iter 10 / 10 | loss 0.72
| epoch 57 |  iter 10 / 10 | loss 0.71
| epoch 58 |  iter 10 / 10 | loss 0.70
| epoch 59 |  iter 10 / 10 | loss 0.72
| epoch 60 |  iter 10 / 10 | loss 0.70
| epoch 61 |  iter 10 / 10 | loss 0.71
| epoch 62 |  iter 10 / 10 | loss 0.72
| epoch 63 |  iter 10 / 10 | loss 0.70
| epoch 64 |  iter 10 / 10 | loss 0.71
| epoch 65 |  iter 10 / 10 | loss 0.73
| epoch 66 |  iter 10 / 10 | loss 0.70
| epoch 67 |  iter 10 / 10 | loss 0.71
| epoch 68 |  iter 10 / 10 | loss 0.69
| epoch 69 |  iter 10 / 10 | loss 0.70
| epoch 70 |  iter 10 / 10 | loss 0.71
| epoch 71 |  iter 10 / 10 | loss 0.68
| epoch 72 |  iter 10 / 10 | loss 0.69
| epoch 73 |  iter 10 / 10 | loss 0.67
| epoch 74 |  iter 10 / 10 | loss 0.68
| epoch 75 |  iter 10 / 10 | loss 0.67
| epoch 76 |  iter 10 / 10 | loss 0.66
| epoch 77 |  iter 10 / 10 | loss 0.69
| epoch 78 |  iter 10 / 10 | loss 0.64
| epoch 79 |  iter 10 / 10 | loss 0.68
| epoch 80 |  iter 10 / 10 | loss 0.64
| epoch 81 |  iter 10 / 10 | loss 0.64
| epoch 82 |  iter 10 / 10 | loss 0.66
| epoch 83 |  iter 10 / 10 | loss 0.62
| epoch 84 |  iter 10 / 10 | loss 0.62
| epoch 85 |  iter 10 / 10 | loss 0.61
| epoch 86 |  iter 10 / 10 | loss 0.60
| epoch 87 |  iter 10 / 10 | loss 0.60
| epoch 88 |  iter 10 / 10 | loss 0.61
| epoch 89 |  iter 10 / 10 | loss 0.59
| epoch 90 |  iter 10 / 10 | loss 0.58
| epoch 91 |  iter 10 / 10 | loss 0.56
| epoch 92 |  iter 10 / 10 | loss 0.56
| epoch 93 |  iter 10 / 10 | loss 0.54
| epoch 94 |  iter 10 / 10 | loss 0.53
| epoch 95 |  iter 10 / 10 | loss 0.53
| epoch 96 |  iter 10 / 10 | loss 0.52
| epoch 97 |  iter 10 / 10 | loss 0.51
| epoch 98 |  iter 10 / 10 | loss 0.50
| epoch 99 |  iter 10 / 10 | loss 0.48
| epoch 100 |  iter 10 / 10 | loss 0.48
| epoch 101 |  iter 10 / 10 | loss 0.46
| epoch 102 |  iter 10 / 10 | loss 0.45
| epoch 103 |  iter 10 / 10 | loss 0.45
| epoch 104 |  iter 10 / 10 | loss 0.44
| epoch 105 |  iter 10 / 10 | loss 0.44
| epoch 106 |  iter 10 / 10 | loss 0.41
| epoch 107 |  iter 10 / 10 | loss 0.40
| epoch 108 |  iter 10 / 10 | loss 0.41
| epoch 109 |  iter 10 / 10 | loss 0.40
| epoch 110 |  iter 10 / 10 | loss 0.40
| epoch 111 |  iter 10 / 10 | loss 0.38
| epoch 112 |  iter 10 / 10 | loss 0.38
| epoch 113 |  iter 10 / 10 | loss 0.36
| epoch 114 |  iter 10 / 10 | loss 0.37
| epoch 115 |  iter 10 / 10 | loss 0.35
| epoch 116 |  iter 10 / 10 | loss 0.34
| epoch 117 |  iter 10 / 10 | loss 0.34
| epoch 118 |  iter 10 / 10 | loss 0.34
| epoch 119 |  iter 10 / 10 | loss 0.33
| epoch 120 |  iter 10 / 10 | loss 0.34
| epoch 121 |  iter 10 / 10 | loss 0.32
| epoch 122 |  iter 10 / 10 | loss 0.32
| epoch 123 |  iter 10 / 10 | loss 0.31
| epoch 124 |  iter 10 / 10 | loss 0.31
| epoch 125 |  iter 10 / 10 | loss 0.30
| epoch 126 |  iter 10 / 10 | loss 0.30
| epoch 127 |  iter 10 / 10 | loss 0.28
| epoch 128 |  iter 10 / 10 | loss 0.28
| epoch 129 |  iter 10 / 10 | loss 0.28
| epoch 130 |  iter 10 / 10 | loss 0.28
| epoch 131 |  iter 10 / 10 | loss 0.27
| epoch 132 |  iter 10 / 10 | loss 0.27
| epoch 133 |  iter 10 / 10 | loss 0.27
| epoch 134 |  iter 10 / 10 | loss 0.27
| epoch 135 |  iter 10 / 10 | loss 0.27
| epoch 136 |  iter 10 / 10 | loss 0.26
| epoch 137 |  iter 10 / 10 | loss 0.26
| epoch 138 |  iter 10 / 10 | loss 0.26
| epoch 139 |  iter 10 / 10 | loss 0.25
| epoch 140 |  iter 10 / 10 | loss 0.24
| epoch 141 |  iter 10 / 10 | loss 0.24
| epoch 142 |  iter 10 / 10 | loss 0.25
| epoch 143 |  iter 10 / 10 | loss 0.24
| epoch 144 |  iter 10 / 10 | loss 0.24
| epoch 145 |  iter 10 / 10 | loss 0.23
| epoch 146 |  iter 10 / 10 | loss 0.24
| epoch 147 |  iter 10 / 10 | loss 0.23
| epoch 148 |  iter 10 / 10 | loss 0.23
| epoch 149 |  iter 10 / 10 | loss 0.22
| epoch 150 |  iter 10 / 10 | loss 0.22
| epoch 151 |  iter 10 / 10 | loss 0.22
| epoch 152 |  iter 10 / 10 | loss 0.22
| epoch 153 |  iter 10 / 10 | loss 0.22
| epoch 154 |  iter 10 / 10 | loss 0.22
| epoch 155 |  iter 10 / 10 | loss 0.22
| epoch 156 |  iter 10 / 10 | loss 0.21
| epoch 157 |  iter 10 / 10 | loss 0.21
| epoch 158 |  iter 10 / 10 | loss 0.20
| epoch 159 |  iter 10 / 10 | loss 0.21
| epoch 160 |  iter 10 / 10 | loss 0.20
| epoch 161 |  iter 10 / 10 | loss 0.20
| epoch 162 |  iter 10 / 10 | loss 0.20
| epoch 163 |  iter 10 / 10 | loss 0.21
| epoch 164 |  iter 10 / 10 | loss 0.20
| epoch 165 |  iter 10 / 10 | loss 0.20
| epoch 166 |  iter 10 / 10 | loss 0.19
| epoch 167 |  iter 10 / 10 | loss 0.19
| epoch 168 |  iter 10 / 10 | loss 0.19
| epoch 169 |  iter 10 / 10 | loss 0.19
| epoch 170 |  iter 10 / 10 | loss 0.19
| epoch 171 |  iter 10 / 10 | loss 0.19
| epoch 172 |  iter 10 / 10 | loss 0.18
| epoch 173 |  iter 10 / 10 | loss 0.18
| epoch 174 |  iter 10 / 10 | loss 0.18
| epoch 175 |  iter 10 / 10 | loss 0.18
| epoch 176 |  iter 10 / 10 | loss 0.18
| epoch 177 |  iter 10 / 10 | loss 0.18
| epoch 178 |  iter 10 / 10 | loss 0.18
| epoch 179 |  iter 10 / 10 | loss 0.17
| epoch 180 |  iter 10 / 10 | loss 0.17
| epoch 181 |  iter 10 / 10 | loss 0.18
| epoch 182 |  iter 10 / 10 | loss 0.17
| epoch 183 |  iter 10 / 10 | loss 0.18
| epoch 184 |  iter 10 / 10 | loss 0.17
| epoch 185 |  iter 10 / 10 | loss 0.17
| epoch 186 |  iter 10 / 10 | loss 0.18
| epoch 187 |  iter 10 / 10 | loss 0.17
| epoch 188 |  iter 10 / 10 | loss 0.17
| epoch 189 |  iter 10 / 10 | loss 0.17
| epoch 190 |  iter 10 / 10 | loss 0.17
| epoch 191 |  iter 10 / 10 | loss 0.16
| epoch 192 |  iter 10 / 10 | loss 0.17
| epoch 193 |  iter 10 / 10 | loss 0.16
| epoch 194 |  iter 10 / 10 | loss 0.16
| epoch 195 |  iter 10 / 10 | loss 0.16
| epoch 196 |  iter 10 / 10 | loss 0.16
| epoch 197 |  iter 10 / 10 | loss 0.16
| epoch 198 |  iter 10 / 10 | loss 0.15
| epoch 199 |  iter 10 / 10 | loss 0.16
| epoch 200 |  iter 10 / 10 | loss 0.16
| epoch 201 |  iter 10 / 10 | loss 0.15
| epoch 202 |  iter 10 / 10 | loss 0.16
| epoch 203 |  iter 10 / 10 | loss 0.16
| epoch 204 |  iter 10 / 10 | loss 0.15
| epoch 205 |  iter 10 / 10 | loss 0.16
| epoch 206 |  iter 10 / 10 | loss 0.15
| epoch 207 |  iter 10 / 10 | loss 0.15
| epoch 208 |  iter 10 / 10 | loss 0.15
| epoch 209 |  iter 10 / 10 | loss 0.15
| epoch 210 |  iter 10 / 10 | loss 0.15
| epoch 211 |  iter 10 / 10 | loss 0.15
| epoch 212 |  iter 10 / 10 | loss 0.15
| epoch 213 |  iter 10 / 10 | loss 0.15
| epoch 214 |  iter 10 / 10 | loss 0.15
| epoch 215 |  iter 10 / 10 | loss 0.15
| epoch 216 |  iter 10 / 10 | loss 0.14
| epoch 217 |  iter 10 / 10 | loss 0.14
| epoch 218 |  iter 10 / 10 | loss 0.15
| epoch 219 |  iter 10 / 10 | loss 0.14
| epoch 220 |  iter 10 / 10 | loss 0.14
| epoch 221 |  iter 10 / 10 | loss 0.14
| epoch 222 |  iter 10 / 10 | loss 0.14
| epoch 223 |  iter 10 / 10 | loss 0.14
| epoch 224 |  iter 10 / 10 | loss 0.14
| epoch 225 |  iter 10 / 10 | loss 0.14
| epoch 226 |  iter 10 / 10 | loss 0.14
| epoch 227 |  iter 10 / 10 | loss 0.14
| epoch 228 |  iter 10 / 10 | loss 0.14
| epoch 229 |  iter 10 / 10 | loss 0.13
| epoch 230 |  iter 10 / 10 | loss 0.14
| epoch 231 |  iter 10 / 10 | loss 0.13
| epoch 232 |  iter 10 / 10 | loss 0.14
| epoch 233 |  iter 10 / 10 | loss 0.13
| epoch 234 |  iter 10 / 10 | loss 0.13
| epoch 235 |  iter 10 / 10 | loss 0.13
| epoch 236 |  iter 10 / 10 | loss 0.13
| epoch 237 |  iter 10 / 10 | loss 0.14
| epoch 238 |  iter 10 / 10 | loss 0.13
| epoch 239 |  iter 10 / 10 | loss 0.13
| epoch 240 |  iter 10 / 10 | loss 0.14
| epoch 241 |  iter 10 / 10 | loss 0.13
| epoch 242 |  iter 10 / 10 | loss 0.13
| epoch 243 |  iter 10 / 10 | loss 0.13
| epoch 244 |  iter 10 / 10 | loss 0.13
| epoch 245 |  iter 10 / 10 | loss 0.13
| epoch 246 |  iter 10 / 10 | loss 0.13
| epoch 247 |  iter 10 / 10 | loss 0.13
| epoch 248 |  iter 10 / 10 | loss 0.13
| epoch 249 |  iter 10 / 10 | loss 0.13
| epoch 250 |  iter 10 / 10 | loss 0.13
| epoch 251 |  iter 10 / 10 | loss 0.13
| epoch 252 |  iter 10 / 10 | loss 0.12
| epoch 253 |  iter 10 / 10 | loss 0.12
| epoch 254 |  iter 10 / 10 | loss 0.12
| epoch 255 |  iter 10 / 10 | loss 0.12
| epoch 256 |  iter 10 / 10 | loss 0.12
| epoch 257 |  iter 10 / 10 | loss 0.12
| epoch 258 |  iter 10 / 10 | loss 0.12
| epoch 259 |  iter 10 / 10 | loss 0.13
| epoch 260 |  iter 10 / 10 | loss 0.12
| epoch 261 |  iter 10 / 10 | loss 0.13
| epoch 262 |  iter 10 / 10 | loss 0.12
| epoch 263 |  iter 10 / 10 | loss 0.12
| epoch 264 |  iter 10 / 10 | loss 0.13
| epoch 265 |  iter 10 / 10 | loss 0.12
| epoch 266 |  iter 10 / 10 | loss 0.12
| epoch 267 |  iter 10 / 10 | loss 0.12
| epoch 268 |  iter 10 / 10 | loss 0.12
| epoch 269 |  iter 10 / 10 | loss 0.11
| epoch 270 |  iter 10 / 10 | loss 0.12
| epoch 271 |  iter 10 / 10 | loss 0.12
| epoch 272 |  iter 10 / 10 | loss 0.12
| epoch 273 |  iter 10 / 10 | loss 0.12
| epoch 274 |  iter 10 / 10 | loss 0.12
| epoch 275 |  iter 10 / 10 | loss 0.11
| epoch 276 |  iter 10 / 10 | loss 0.12
| epoch 277 |  iter 10 / 10 | loss 0.12
| epoch 278 |  iter 10 / 10 | loss 0.11
| epoch 279 |  iter 10 / 10 | loss 0.11
| epoch 280 |  iter 10 / 10 | loss 0.11
| epoch 281 |  iter 10 / 10 | loss 0.11
| epoch 282 |  iter 10 / 10 | loss 0.12
| epoch 283 |  iter 10 / 10 | loss 0.11
| epoch 284 |  iter 10 / 10 | loss 0.11
| epoch 285 |  iter 10 / 10 | loss 0.11
| epoch 286 |  iter 10 / 10 | loss 0.11
| epoch 287 |  iter 10 / 10 | loss 0.11
| epoch 288 |  iter 10 / 10 | loss 0.12
| epoch 289 |  iter 10 / 10 | loss 0.11
| epoch 290 |  iter 10 / 10 | loss 0.11
| epoch 291 |  iter 10 / 10 | loss 0.11
| epoch 292 |  iter 10 / 10 | loss 0.11
| epoch 293 |  iter 10 / 10 | loss 0.11
| epoch 294 |  iter 10 / 10 | loss 0.11
| epoch 295 |  iter 10 / 10 | loss 0.12
| epoch 296 |  iter 10 / 10 | loss 0.11
| epoch 297 |  iter 10 / 10 | loss 0.12
| epoch 298 |  iter 10 / 10 | loss 0.11
| epoch 299 |  iter 10 / 10 | loss 0.11
| epoch 300 |  iter 10 / 10 | loss 0.11