import numpy as np

# Activation functions for the hidden layer

# Sigmoid (logistic) function
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# ReLU function
def relu(x):
    return np.maximum(0, x)

# Step function (threshold 0)
def step_function(x):
    return np.where(x > 0, 1, 0)
# Activation functions for the output layer

# Softmax function
def softmax(x):
    if x.ndim == 2:
        x = x.T
        x = x - np.max(x, axis=0)
        y = np.exp(x) / np.sum(np.exp(x), axis=0)
        return y.T
    x = x - np.max(x)  # guard against overflow
    return np.exp(x) / np.sum(np.exp(x))
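# Quick sanity check (illustrative, not part of the original notebook):
# the max shift keeps np.exp from overflowing, and each row sums to 1.
# >>> softmax(np.array([[1000.0, 1000.0], [0.0, 1.0]]))
# array([[0.5       , 0.5       ],
#        [0.26894142, 0.73105858]])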
# Softmax combined with cross-entropy
def softmax_with_loss(d, x):
    y = softmax(x)
    return cross_entropy_error(d, y)
# Error functions

# Mean squared error
def mean_squared_error(d, y):
    return np.mean(np.square(d - y)) / 2

# Cross-entropy error
def cross_entropy_error(d, y):
    if y.ndim == 1:
        d = d.reshape(1, d.size)
        y = y.reshape(1, y.size)
    # If the targets are one-hot vectors, convert them to label indices
    if d.size == y.size:
        d = d.argmax(axis=1)
    batch_size = y.shape[0]
    return -np.sum(np.log(y[np.arange(batch_size), d] + 1e-7)) / batch_size
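# Illustrative check (not in the original notebook): one-hot targets and
# integer label indices give the same loss, -log(0.8 + 1e-7) ~= 0.2231.
# >>> y = np.array([[0.1, 0.8, 0.1]])
# >>> cross_entropy_error(np.array([[0, 1, 0]]), y)  # one-hot target
# >>> cross_entropy_error(np.array([1]), y)          # label index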
# Derivatives of the activation functions

# Derivative of the sigmoid (logistic) function
def d_sigmoid(x):
    dx = (1.0 - sigmoid(x)) * sigmoid(x)
    return dx

# Derivative of the ReLU function
def d_relu(x):
    return np.where(x > 0, 1, 0)

# Derivative of the step function
def d_step_function(x):
    return 0

# Derivative of the mean squared error
def d_mean_squared_error(d, y):
    if isinstance(d, np.ndarray):
        batch_size = d.shape[0]
        dx = (y - d) / batch_size
    else:
        dx = y - d
    return dx
# Combined derivative of softmax and cross-entropy
def d_softmax_with_loss(d, y):
    batch_size = d.shape[0]
    if d.size == y.size:  # targets are one-hot vectors
        dx = (y - d) / batch_size
    else:
        dx = y.copy()
        dx[np.arange(batch_size), d] -= 1
        dx = dx / batch_size
    return dx

# Combined derivative of sigmoid and cross-entropy
def d_sigmoid_with_loss(d, y):
    return y - d
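# Worked check (illustrative) for d_softmax_with_loss: with a single
# one-hot sample the combined gradient is simply y - d.
# >>> y = softmax(np.array([[1.0, 2.0, 3.0]]))
# >>> d_softmax_with_loss(np.array([[0, 0, 1]]), y)
# array([[ 0.09003057,  0.24472847, -0.33475904]])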
# Numerical differentiation (central difference; .flat is used so this
# also works for multi-dimensional arrays)
def numerical_gradient(f, x):
    h = 1e-4
    grad = np.zeros_like(x)
    for idx in range(x.size):
        tmp_val = x.flat[idx]
        # compute f(x + h)
        x.flat[idx] = tmp_val + h
        fxh1 = f(x)
        # compute f(x - h)
        x.flat[idx] = tmp_val - h
        fxh2 = f(x)
        grad.flat[idx] = (fxh1 - fxh2) / (2 * h)
        # restore the original value
        x.flat[idx] = tmp_val
    return grad
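# Gradient check (illustrative): the analytic sigmoid derivative agrees
# with the numerical gradient of a summed sigmoid to within ~1e-7.
# >>> x = np.array([0.5, -1.0, 2.0])
# >>> np.allclose(d_sigmoid(x), numerical_gradient(lambda v: np.sum(sigmoid(v)), x))
# True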
# Expand image patches into rows so convolution becomes a matrix product
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col

# Inverse of im2col: fold the column matrix back into an image,
# accumulating values where patches overlap
def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    N, C, H, W = input_shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1
    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)

    img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]

    return img[:, :, pad:H + pad, pad:W + pad]
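# Shape check (illustrative): a 1x3x4x4 input with a 2x2 filter, stride 1
# and no padding gives out_h = out_w = 3, so im2col returns
# (N*out_h*out_w, C*filter_h*filter_w) = (9, 12).
# >>> im2col(np.random.randn(1, 3, 4, 4), 2, 2).shape
# (9, 12)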
import numpy as np
# from common import functions
import matplotlib.pyplot as plt

def d_tanh(x):
    return 1 / (np.cosh(x) ** 2)

# Prepare the data
# Number of binary digits
binary_dim = 8
# Maximum value + 1
largest_number = pow(2, binary_dim)
# Binary representations of 0 .. largest_number - 1
binary = np.unpackbits(np.array([range(largest_number)], dtype=np.uint8).T, axis=1)
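# e.g. binary[5] is the 8-bit, most-significant-bit-first encoding of 5:
# array([0, 0, 0, 0, 0, 1, 0, 1], dtype=uint8)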
input_layer_size = 2
hidden_layer_size = 16
output_layer_size = 1
weight_init_std = 1
learning_rate = 0.1
iters_num = 10000
plot_interval = 100
# Weight initialization (biases omitted for simplicity)
# Plain random initialization
# W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)
# W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)
# W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)

# Xavier initialization
W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size))
W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size))
W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size))

# He initialization
# W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size)) * np.sqrt(2)
# W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)
# W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)
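# For reference: Xavier scales the standard normal by 1/sqrt(fan_in)
# (std 1/4 = 0.25 for hidden_layer_size = 16), and He multiplies that by
# an extra sqrt(2), which suits ReLU activations.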
# Gradients
W_in_grad = np.zeros_like(W_in)
W_out_grad = np.zeros_like(W_out)
W_grad = np.zeros_like(W)

# Pre-activations, hidden states, outputs and deltas over the sequence
u = np.zeros((hidden_layer_size, binary_dim + 1))
z = np.zeros((hidden_layer_size, binary_dim + 1))
y = np.zeros((output_layer_size, binary_dim))

delta_out = np.zeros((output_layer_size, binary_dim))
delta = np.zeros((hidden_layer_size, binary_dim + 1))

all_losses = []
for i in range(iters_num):
    # Draw A and B at random (a + b = d)
    a_int = np.random.randint(largest_number // 2)
    a_bin = binary[a_int]  # binary encoding
    b_int = np.random.randint(largest_number // 2)
    b_bin = binary[b_int]  # binary encoding

    # Correct answer
    d_int = a_int + b_int
    d_bin = binary[d_int]

    # Predicted binary output
    out_bin = np.zeros_like(d_bin)

    # Total error over the sequence
    all_loss = 0

    # Loop over time steps, least significant bit first
    for t in range(binary_dim):
        # Input values
        X = np.array([a_bin[-t-1], b_bin[-t-1]]).reshape(1, -1)
        # Correct data at time t
        dd = np.array([d_bin[binary_dim - t - 1]])

        u[:,t+1] = np.dot(X, W_in) + np.dot(z[:,t].reshape(1, -1), W)
        # Pick one activation; the output derivative below must match it.
        # z[:,t+1] = sigmoid(u[:,t+1])
        # y[:,t] = sigmoid(np.dot(z[:,t+1].reshape(1, -1), W_out))
        # z[:,t+1] = relu(u[:,t+1])
        # y[:,t] = relu(np.dot(z[:,t+1].reshape(1, -1), W_out))
        z[:,t+1] = np.tanh(u[:,t+1])
        u_out = np.dot(z[:,t+1].reshape(1, -1), W_out)
        y[:,t] = np.tanh(u_out)

        # Error
        loss = mean_squared_error(dd, y[:,t])
        delta_out[:,t] = d_mean_squared_error(dd, y[:,t]) * d_tanh(u_out)
        all_loss += loss

        out_bin[binary_dim - t - 1] = np.round(y[:,t])
    # Backpropagation through time, most significant bit first
    for t in reversed(range(binary_dim)):
        X = np.array([a_bin[-t-1], b_bin[-t-1]]).reshape(1, -1)

        # The derivative here must also match the chosen activation.
        # delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_sigmoid(u[:,t+1])
        # delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_relu(u[:,t+1])
        delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_tanh(u[:,t+1])

        # Accumulate gradients
        W_out_grad += np.dot(z[:,t+1].reshape(-1, 1), delta_out[:,t].reshape(-1, 1))
        W_grad += np.dot(z[:,t].reshape(-1, 1), delta[:,t].reshape(1, -1))
        W_in_grad += np.dot(X.T, delta[:,t].reshape(1, -1))

    # Apply gradients
    W_in -= learning_rate * W_in_grad
    W_out -= learning_rate * W_out_grad
    W -= learning_rate * W_grad

    W_in_grad *= 0
    W_out_grad *= 0
    W_grad *= 0
    if i % plot_interval == 0:
        all_losses.append(all_loss)
        print("iters:" + str(i))
        print("Loss:" + str(all_loss))
        print("Pred:" + str(out_bin))
        print("True:" + str(d_bin))
        out_int = 0
        for index, x in enumerate(reversed(out_bin)):
            out_int += x * pow(2, index)
        print(str(a_int) + " + " + str(b_int) + " = " + str(out_int))
        print("------------")

lists = range(0, iters_num, plot_interval)
plt.plot(lists, all_losses, label="loss")
plt.show()
iters:0
Loss:2.153564894745493
Pred:[0 0 1 0 0 0 0 0]
True:[1 1 0 1 0 1 0 0]
121 + 91 = 32
------------
iters:100
Loss:0.5659978109753718
Pred:[0 0 1 1 1 1 1 0]
True:[0 0 1 1 0 0 1 1]
24 + 27 = 62
------------
... (intermediate output omitted; the loss falls to roughly the 0.01-0.1 range by around iteration 4000, and the predicted bits start matching the true bits) ...
------------
iters:9800
Loss:0.026674913935543214
Pred:[1 0 0 0 1 1 1 0]
True:[1 0 0 0 1 1 1 0]
102 + 40 = 142
------------
iters:9900
Loss:0.03234901207213073
Pred:[0 1 1 0 1 0 0 0]
True:[0 1 1 0 1 0 0 0]
97 + 7 = 104
------------
Result with sigmoid (random initialization)
Result with sigmoid (Xavier initialization)
Result with ReLU (random initialization)
Result with ReLU (He initialization)
Result with tanh (random initialization)
Result with tanh (Xavier initialization)

The results vary somewhat from run to run, but all of these converge reasonably well. Only tanh with random initialization performs poorly; the other ReLU and tanh runs appear largely insensitive to the choice of initialization.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

np.random.seed(0)

# Sine curve
round_num = 10
div_num = 500
ts = np.linspace(0, round_num * np.pi, div_num)
f = np.sin(ts)

def d_tanh(x):
    return 1 / (np.cosh(x)**2 + 1e-4)

# Length of one time-series window
maxlen = 5

# Seed input for the sine-wave prediction
test_head = [[f[k]] for k in range(0, maxlen)]

# Slide a window of length maxlen over the curve; the target is the next value
data = []
target = []
for i in range(div_num - maxlen):
    data.append(f[i: i + maxlen])
    target.append(f[i + maxlen])

X = np.array(data).reshape(len(data), maxlen, 1)
D = np.array(target).reshape(len(data), 1)
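# e.g. with maxlen = 5 the first sample pairs f[0:5] with the target f[5]:
# X[0].ravel() equals f[0:5] and D[0, 0] equals f[5].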
# Data split
N_train = int(len(data) * 0.8)
N_validation = len(data) - N_train
x_train, x_test, d_train, d_test = train_test_split(X, D, test_size=N_validation)

input_layer_size = 1
hidden_layer_size = 5
output_layer_size = 1
weight_init_std = 0.01
learning_rate = 0.1
iters_num = 3000

# Weight initialization (biases omitted for simplicity)
W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)
W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)
W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)
# Gradients
W_in_grad = np.zeros_like(W_in)
W_out_grad = np.zeros_like(W_out)
W_grad = np.zeros_like(W)

# Pre-activations and hidden states kept for backpropagation through time
us = []
zs = []
u = np.zeros(hidden_layer_size)
z = np.zeros(hidden_layer_size)
y = np.zeros(output_layer_size)

delta_out = np.zeros(output_layer_size)
delta = np.zeros(hidden_layer_size)

losses = []
# Training
for i in range(iters_num):
    for s in range(x_train.shape[0]):
        us.clear()
        zs.clear()
        z *= 0

        # Correct data for sample s
        d = d_train[s]
        xs = x_train[s]

        # Loop over time steps
        for t in range(maxlen):
            # Input value
            x = xs[t]

            u = np.dot(x, W_in) + np.dot(z, W)
            us.append(u)
            z = np.tanh(u)
            zs.append(z)

        y = np.dot(z, W_out)

        # Error
        loss = mean_squared_error(d, y)
        delta_out = d_mean_squared_error(d, y)

        delta *= 0
        for t in reversed(range(maxlen)):
            delta = (np.dot(delta, W.T) + np.dot(delta_out, W_out.T)) * d_tanh(us[t])
            # Accumulate gradients
            W_grad += np.dot(zs[t].reshape(-1, 1), delta.reshape(1, -1))
            W_in_grad += np.dot(xs[t], delta.reshape(1, -1))
        W_out_grad = np.dot(z.reshape(-1, 1), delta_out)

        # Apply gradients
        W -= learning_rate * W_grad
        W_in -= learning_rate * W_in_grad
        W_out -= learning_rate * W_out_grad.reshape(-1, 1)

        W_in_grad *= 0
        W_out_grad *= 0
        W_grad *= 0
# Test
for s in range(x_test.shape[0]):
    z *= 0

    # Correct data for sample s
    d = d_test[s]
    xs = x_test[s]

    # Loop over time steps
    for t in range(maxlen):
        # Input value
        x = xs[t]

        u = np.dot(x, W_in) + np.dot(z, W)
        z = np.tanh(u)
    y = np.dot(z, W_out)

    # Error
    loss = mean_squared_error(d, y)
    print('loss:', loss, ' d:', d, ' y:', y)
original = np.full(maxlen, None)
pred_num = 200
xs = test_head

# Predict the sine wave, feeding each prediction back in as input
for s in range(0, pred_num):
    z *= 0
    for t in range(maxlen):
        # Input value
        x = xs[t]
        u = np.dot(x, W_in) + np.dot(z, W)
        z = np.tanh(u)
    y = np.dot(z, W_out)

    original = np.append(original, y)

    # Drop the oldest input and append the new prediction
    xs = np.delete(xs, 0)
    xs = np.append(xs, y)

plt.figure()
plt.ylim([-1.5, 1.5])
plt.plot(np.sin(np.linspace(0, round_num * pred_num / div_num * np.pi, pred_num)), linestyle='dotted', color='#aaaaaa')
plt.plot(original, linestyle='dashed', color='black')
plt.show()
loss: 1.0231756688222737e-07  d: [-0.29761864]  y: [-0.29716628]
loss: 1.2201090156041725e-08  d: [-0.56307233]  y: [-0.56322854]
loss: 4.038245320901024e-11  d: [-0.65766776]  y: [-0.65765877]
loss: 1.012613171470193e-08  d: [0.13182648]  y: [0.13168417]
... (remaining test samples omitted; every loss is on the order of 1e-12 to 1e-7) ...
loss: 1.0539082479572855e-07  d: [0.43226238]  y: [0.43180327]
loss: 3.26643400999611e-08  d: [-0.81380058]  y: [-0.81405617]
maxlen:2 iters_num:100
maxlen:2 iters_num:500
maxlen:2 iters_num:3000
maxlen:5 iters_num:100
maxlen:5 iters_num:500
maxlen:5 iters_num:3000
With maxlen:2 training does not progress well, and increasing iters_num does not help. With maxlen:5, the network predicts an almost exact sine wave at iters_num:3000.