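The RNN in Chapter 5 is a relatively simple one; in practice, "RNN" usually refers to an LSTM or a GRU.
The problem with the simple RNN is that vanishing or exploding gradients occur during BPTT (Backpropagation Through Time).
The vector $h_t$, which stores information from the past, is called the RNN's hidden state.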
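As a reminder of where $h_t$ comes from, here is a minimal sketch of a simple RNN's forward pass. It assumes the update rule $h_t = \tanh(h_{t-1} W_h + x_t W_x + b)$. The function name `rnn_forward`, the shapes and the random weights are illustrative only.

```python
import numpy as np

def rnn_forward(xs, Wx, Wh, b):
    """Run a simple RNN over xs of shape (N, T, D); return every hidden state, shape (N, T, H)."""
    N, T, _ = xs.shape
    H = Wh.shape[0]
    h = np.zeros((N, H))              # h_0: initial hidden state
    hs = np.empty((N, T, H))
    for t in range(T):
        # h_t = tanh(h_{t-1} Wh + x_t Wx + b): h carries information from earlier steps
        h = np.tanh(np.dot(h, Wh) + np.dot(xs[:, t, :], Wx) + b)
        hs[:, t, :] = h
    return hs

rng = np.random.default_rng(0)
N, T, D, H = 2, 5, 4, 3
xs = rng.standard_normal((N, T, D))
hs = rnn_forward(xs, rng.standard_normal((D, H)), rng.standard_normal((H, H)), np.zeros(H))
print(hs.shape)  # (2, 5, 3)
```

The same $W_h$ is reused at every step. This reuse is what makes the gradient behave badly over long sequences.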
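When an RNN learns, the gradient (the learning signal) has to be carried back to distant contexts. With a simple RNN, however, the further back in time the gradient travels, the more it tends to either shrink toward zero (vanishing gradients) or grow without bound (exploding gradients).
This happens for two reasons. First, backpropagation through the RNN passes through a $\tanh$ node at every time step. The derivative $1 - \tanh^2(x)$ is at most 1 (equal to 1 only at $x = 0$), so the gradient is weakened each time it passes through. Second, each pass through the MatMul node multiplies the gradient by $W_h^\top$. Repeated over many time steps, this makes its magnitude grow or shrink exponentially, causing exploding or vanishing gradients.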
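The $\tanh$ part of the argument can be checked numerically. This is a minimal sketch that ignores the MatMul factor. It assumes, purely for illustration, that the pre-activation equals 0.5 at every time step.

```python
import numpy as np

def dtanh(x):
    # derivative of tanh: 1 - tanh(x)^2, which lies in (0, 1]
    return 1.0 - np.tanh(x) ** 2

xs = np.linspace(-4.0, 4.0, 9)   # -4, -3, ..., 3, 4
print(dtanh(xs).max())           # 1.0, reached only at x = 0

# Backprop through T tanh nodes multiplies the gradient by T such factors.
T = 20
grad = 1.0
for _ in range(T):
    grad *= dtanh(0.5)           # assumed pre-activation of 0.5 at each step
print(grad)                      # far below 1: the gradient has shrunk
```

Each factor is below 1 except at exactly $x = 0$. A long chain of them therefore can only shrink the gradient.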
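The effect of the MatMul node can be observed with a small experiment. It ignores $\tanh$, repeatedly multiplies the upstream gradient `dh` by `Wh.T`, and records the norm at each step:

```python
import numpy as np
import matplotlib.pyplot as plt

N = 2   # mini-batch size
H = 3   # dimension of the hidden state vector
T = 20  # length of the time series

dh = np.ones((N, H))
np.random.seed(3)  # fix the random seed for reproducibility
Wh = np.random.randn(H, H)

norm_list = []
for t in range(T):
    # backprop through the MatMul node only (tanh is ignored): dh <- dh Wh^T
    dh = np.dot(dh, Wh.T)
    norm = np.sqrt(np.sum(dh**2)) / N
    norm_list.append(norm)

# plot the norm at every time step
plt.plot(norm_list)
plt.xlabel('time step')
plt.ylabel('norm')
plt.show()
```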
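With a random `Wh`, whether the norm explodes or vanishes depends on the draw. Roughly speaking, the norm tends to grow when the largest singular value of `Wh` exceeds 1 and shrink when it is below 1. A largest singular value above 1 is a necessary, not a sufficient, condition for explosion.

The sketch below makes the rate exact instead of random. It builds `Wh` as an orthogonal matrix scaled by `s`, so every singular value equals `s` and the norm changes by exactly a factor of `s` per step. The helper name `grad_norms` is hypothetical.

```python
import numpy as np

def grad_norms(Wh, T=20, N=2):
    """Backprop dh through the MatMul node T times (tanh ignored) and record ||dh|| / N."""
    H = Wh.shape[0]
    dh = np.ones((N, H))
    norms = []
    for _ in range(T):
        dh = np.dot(dh, Wh.T)
        norms.append(np.sqrt(np.sum(dh ** 2)) / N)
    return np.array(norms)

H = 3
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((H, H)))   # orthogonal: all singular values are 1

for s in (1.1, 0.9):
    Wh = s * Q                                       # every singular value of Wh is s
    sigma_max = np.linalg.svd(Wh, compute_uv=False).max()
    norms = grad_norms(Wh)
    print(f"s={s}: largest singular value {sigma_max:.2f}, "
          f"norm after 1 step {norms[0]:.3f}, after {len(norms)} steps {norms[-1]:.3f}")
```

Multiplying by an orthogonal matrix preserves the norm, so the only change per step comes from the scale `s`. After $T$ steps the norm is exactly $s^T$ times its starting value.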
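Exploding gradients can be handled with a technique called gradient clipping. Let $\hat{g}$ collect every parameter gradient. If its combined L2 norm $\|\hat{g}\|$ exceeds a threshold, the gradients are scaled down: $\hat{g} \leftarrow \frac{\text{threshold}}{\|\hat{g}\|}\hat{g}$.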
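```python
import numpy as np

def clip_grads(grads, max_norm):
    # L2 norm over all gradient arrays combined
    total_norm = 0
    for grad in grads:
        total_norm += np.sum(grad ** 2)
    total_norm = np.sqrt(total_norm)

    # scale factor; 1e-6 avoids division by zero
    rate = max_norm / (total_norm + 1e-6)
    if rate < 1:
        for grad in grads:
            grad *= rate  # modify each array in place
```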
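Here is a quick usage check of `clip_grads`, redefined so the snippet runs on its own. Two gradients with total norm 5 are clipped to a maximum norm of 1, while a pair whose norm is already below the threshold is left alone.

```python
import numpy as np

def clip_grads(grads, max_norm):
    total_norm = 0
    for grad in grads:
        total_norm += np.sum(grad ** 2)
    total_norm = np.sqrt(total_norm)     # L2 norm over all gradients combined
    rate = max_norm / (total_norm + 1e-6)
    if rate < 1:
        for grad in grads:
            grad *= rate                 # in place: the caller's arrays change

grads = [np.array([3.0]), np.array([4.0])]  # total norm sqrt(3**2 + 4**2) = 5
clip_grads(grads, max_norm=1.0)
print(grads)   # approximately [array([0.6]), array([0.8])]

small = [np.array([0.3]), np.array([0.4])]  # total norm 0.5 < 1: left untouched
clip_grads(small, max_norm=1.0)
print(small)   # [array([0.3]), array([0.4])]
```

The function modifies the arrays in place and returns nothing. The gradients must therefore be float arrays, because in-place multiplication of an integer array by a float raises a casting error in NumPy.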
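```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

class LSTM:
    ...  # __init__ etc. omitted; self.params holds [Wx, Wh, b]

    def forward(self, x, h_prev, c_prev):
        Wx, Wh, b = self.params
        N, H = h_prev.shape

        # affine transforms for all four parts in one product: A has shape (N, 4H)
        A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b

        # slice A into forget gate, new cell input, input gate and output gate
        f = A[:, :H]
        g = A[:, H:2*H]
        i = A[:, 2*H:3*H]
        o = A[:, 3*H:]

        f = sigmoid(f)
        g = np.tanh(g)
        i = sigmoid(i)
        o = sigmoid(o)

        c_next = f * c_prev + g * i        # memory cell
        h_next = o * np.tanh(c_next)       # hidden state

        self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)  # kept for backward
        return h_next, c_next
```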
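The slicing above corresponds to the following per-step equations. The four affine transforms are computed together as $A = x_t W_x + h_{t-1} W_h + b$ and split column-wise into blocks of width $H$. Here $\sigma$ is the sigmoid and $\odot$ is the element-wise product:

$$
\begin{aligned}
f &= \sigma\left(x_t W_x^{(f)} + h_{t-1} W_h^{(f)} + b^{(f)}\right) \\
g &= \tanh\left(x_t W_x^{(g)} + h_{t-1} W_h^{(g)} + b^{(g)}\right) \\
i &= \sigma\left(x_t W_x^{(i)} + h_{t-1} W_h^{(i)} + b^{(i)}\right) \\
o &= \sigma\left(x_t W_x^{(o)} + h_{t-1} W_h^{(o)} + b^{(o)}\right) \\
c_t &= f \odot c_{t-1} + g \odot i \\
h_t &= o \odot \tanh(c_t)
\end{aligned}
$$

The cell $c_t$ is updated only by element-wise products and an addition, with no MatMul by $W_h$. The gradient flowing back along $c$ is therefore multiplied only by the forget gate $f$ at each step, which is why LSTM is less prone to vanishing gradients than the simple RNN.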
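In the language model, TimeLSTM takes the place of the TimeRNN layer used so far. The model (`Rnnlm`, imported from `rnnlm`) is then trained with the following script.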
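This is a minimal, forward-only sketch of what a Time layer built on LSTM does. It applies one LSTM step per time index and, when `stateful=True`, keeps `h` and `c` for the next call. It assumes the same weight layout as `LSTM.forward` (columns ordered f, g, i, o). `lstm_step` and `TimeLSTMForward` are hypothetical names, not the book's `TimeLSTM` (which also implements backward).

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, Wx, Wh, b):
    """One LSTM step; columns of Wx, Wh, b are ordered [f | g | i | o], each H wide."""
    H = h_prev.shape[1]
    A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b    # (N, 4H)
    f = sigmoid(A[:, :H])          # forget gate
    g = np.tanh(A[:, H:2*H])       # new information for the cell
    i = sigmoid(A[:, 2*H:3*H])     # input gate
    o = sigmoid(A[:, 3*H:])        # output gate
    c_next = f * c_prev + g * i
    h_next = o * np.tanh(c_next)
    return h_next, c_next

class TimeLSTMForward:
    """Hypothetical forward-only Time layer: runs lstm_step over T time steps."""
    def __init__(self, Wx, Wh, b, stateful=False):
        self.Wx, self.Wh, self.b = Wx, Wh, b
        self.stateful = stateful
        self.h, self.c = None, None

    def forward(self, xs):
        N, T, _ = xs.shape
        H = self.Wh.shape[0]
        # start from zeros unless stateful and a state is already held
        if not self.stateful or self.h is None:
            self.h = np.zeros((N, H))
            self.c = np.zeros((N, H))
        hs = np.empty((N, T, H))
        for t in range(T):
            self.h, self.c = lstm_step(xs[:, t, :], self.h, self.c,
                                       self.Wx, self.Wh, self.b)
            hs[:, t, :] = self.h
        return hs

    def reset_state(self):
        self.h, self.c = None, None

rng = np.random.default_rng(0)
N, T, D, H = 2, 5, 4, 3
layer = TimeLSTMForward(rng.standard_normal((D, 4 * H)),
                        rng.standard_normal((H, 4 * H)),
                        np.zeros(4 * H), stateful=True)
hs = layer.forward(rng.standard_normal((N, T, D)))
print(hs.shape)  # (2, 5, 3)
```

Carrying the state across calls matters for a language model trained with truncated BPTT. Consecutive blocks of the corpus should continue from the previous hidden state, which is why the training script calls `model.reset_state()` before evaluation.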
```python
# coding: utf-8
import sys
sys.path.append('..')
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb
from rnnlm import Rnnlm

# hyperparameters
batch_size = 20
wordvec_size = 100
hidden_size = 100  # number of elements in the RNN hidden state vector
time_size = 35     # number of time steps the RNN is unrolled for
lr = 20.0
max_epoch = 4
max_grad = 0.25

# load the data
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)
xs = corpus[:-1]  # inputs
ts = corpus[1:]   # targets: the next word

# build the model
model = Rnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

# train with gradient clipping
trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad,
            eval_interval=20)
trainer.plot(ylim=(0, 500))

# evaluate on the test data
model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

# save the parameters
model.save_params()
```
```
| epoch 1 | iter 1 / 1327 | time 0[s] | perplexity 10001.53
| epoch 1 | iter 21 / 1327 | time 4[s] | perplexity 3369.27
| epoch 1 | iter 41 / 1327 | time 8[s] | perplexity 1270.10
| epoch 1 | iter 61 / 1327 | time 12[s] | perplexity 985.67
| epoch 1 | iter 81 / 1327 | time 16[s] | perplexity 794.73
| epoch 1 | iter 101 / 1327 | time 20[s] | perplexity 680.45
| epoch 1 | iter 121 / 1327 | time 24[s] | perplexity 658.45
| epoch 1 | iter 141 / 1327 | time 29[s] | perplexity 619.76
| epoch 1 | iter 161 / 1327 | time 32[s] | perplexity 606.06
| epoch 1 | iter 181 / 1327 | time 36[s] | perplexity 584.49
| epoch 1 | iter 201 / 1327 | time 40[s] | perplexity 521.27
| epoch 1 | iter 221 / 1327 | time 44[s] | perplexity 503.49
| epoch 1 | iter 241 / 1327 | time 49[s] | perplexity 450.23
| epoch 1 | iter 261 / 1327 | time 53[s] | perplexity 478.97
| epoch 1 | iter 281 / 1327 | time 57[s] | perplexity 448.23
| epoch 1 | iter 301 / 1327 | time 61[s] | perplexity 393.59
| epoch 1 | iter 321 / 1327 | time 65[s] | perplexity 358.08
| epoch 1 | iter 341 / 1327 | time 69[s] | perplexity 404.20
| epoch 1 | iter 361 / 1327 | time 73[s] | perplexity 413.65
| epoch 1 | iter 381 / 1327 | time 77[s] | perplexity 344.77
| epoch 1 | iter 401 / 1327 | time 81[s] | perplexity 360.71
| epoch 1 | iter 421 / 1327 | time 85[s] | perplexity 348.61
| epoch 1 | iter 441 / 1327 | time 89[s] | perplexity 333.02
| epoch 1 | iter 461 / 1327 | time 93[s] | perplexity 326.15
| epoch 1 | iter 481 / 1327 | time 97[s] | perplexity 309.48
| epoch 1 | iter 501 / 1327 | time 101[s] | perplexity 320.86
| epoch 1 | iter 521 / 1327 | time 105[s] | perplexity 304.07
| epoch 1 | iter 541 / 1327 | time 109[s] | perplexity 317.89
| epoch 1 | iter 561 / 1327 | time 113[s] | perplexity 289.84
| epoch 1 | iter 581 / 1327 | time 118[s] | perplexity 260.17
| epoch 1 | iter 601 / 1327 | time 125[s] | perplexity 336.76
| epoch 1 | iter 621 / 1327 | time 131[s] | perplexity 317.84
| epoch 1 | iter 641 / 1327 | time 136[s] | perplexity 283.07
| epoch 1 | iter 661 / 1327 | time 140[s] | perplexity 271.58
| epoch 1 | iter 681 / 1327 | time 146[s] | perplexity 230.88
| epoch 1 | iter 701 / 1327 | time 152[s] | perplexity 250.57
| epoch 1 | iter 721 / 1327 | time 157[s] | perplexity 259.89
| epoch 1 | iter 741 / 1327 | time 161[s] | perplexity 221.75
| epoch 1 | iter 761 / 1327 | time 165[s] | perplexity 234.70
| epoch 1 | iter 781 / 1327 | time 169[s] | perplexity 219.92
| epoch 1 | iter 801 / 1327 | time 173[s] | perplexity 240.31
| epoch 1 | iter 821 / 1327 | time 177[s] | perplexity 224.74
| epoch 1 | iter 841 / 1327 | time 181[s] | perplexity 229.62
| epoch 1 | iter 861 / 1327 | time 186[s] | perplexity 222.55
| epoch 1 | iter 881 / 1327 | time 191[s] | perplexity 205.47
| epoch 1 | iter 901 / 1327 | time 196[s] | perplexity 256.40
| epoch 1 | iter 921 / 1327 | time 201[s] | perplexity 229.65
| epoch 1 | iter 941 / 1327 | time 207[s] | perplexity 229.20
| epoch 1 | iter 961 / 1327 | time 212[s] | perplexity 245.76
| epoch 1 | iter 981 / 1327 | time 218[s] | perplexity 230.29
| epoch 1 | iter 1001 / 1327 | time 223[s] | perplexity 193.78
| epoch 1 | iter 1021 / 1327 | time 228[s] | perplexity 226.22
| epoch 1 | iter 1041 / 1327 | time 233[s] | perplexity 209.86
| epoch 1 | iter 1061 / 1327 | time 238[s] | perplexity 198.70
| epoch 1 | iter 1081 / 1327 | time 243[s] | perplexity 168.97
| epoch 1 | iter 1101 / 1327 | time 248[s] | perplexity 192.36
| epoch 1 | iter 1121 / 1327 | time 253[s] | perplexity 229.63
| epoch 1 | iter 1141 / 1327 | time 258[s] | perplexity 206.56
| epoch 1 | iter 1161 / 1327 | time 262[s] | perplexity 199.81
| epoch 1 | iter 1181 / 1327 | time 267[s] | perplexity 191.34
| epoch 1 | iter 1201 / 1327 | time 272[s] | perplexity 163.47
| epoch 1 | iter 1221 / 1327 | time 277[s] | perplexity 161.37
| epoch 1 | iter 1241 / 1327 | time 281[s] | perplexity 187.98
| epoch 1 | iter 1261 / 1327 | time 286[s] | perplexity 172.95
| epoch 1 | iter 1281 / 1327 | time 291[s] | perplexity 180.20
| epoch 1 | iter 1301 / 1327 | time 296[s] | perplexity 222.08
| epoch 1 | iter 1321 / 1327 | time 301[s] | perplexity 210.05
| epoch 2 | iter 1 / 1327 | time 302[s] | perplexity 223.75
| epoch 2 | iter 21 / 1327 | time 307[s] | perplexity 203.74
| epoch 2 | iter 41 / 1327 | time 312[s] | perplexity 189.99
| epoch 2 | iter 61 / 1327 | time 317[s] | perplexity 177.51
| epoch 2 | iter 81 / 1327 | time 322[s] | perplexity 160.12
| epoch 2 | iter 101 / 1327 | time 327[s] | perplexity 152.91
| epoch 2 | iter 121 / 1327 | time 332[s] | perplexity 160.44
| epoch 2 | iter 141 / 1327 | time 337[s] | perplexity 178.88
| epoch 2 | iter 161 / 1327 | time 341[s] | perplexity 193.38
| epoch 2 | iter 181 / 1327 | time 346[s] | perplexity 199.88
| epoch 2 | iter 201 / 1327 | time 350[s] | perplexity 184.36
| epoch 2 | iter 221 / 1327 | time 354[s] | perplexity 183.61
| epoch 2 | iter 241 / 1327 | time 359[s] | perplexity 177.71
| epoch 2 | iter 261 / 1327 | time 364[s] | perplexity 185.46
| epoch 2 | iter 281 / 1327 | time 368[s] | perplexity 185.07
| epoch 2 | iter 301 / 1327 | time 373[s] | perplexity 166.49
| epoch 2 | iter 321 / 1327 | time 377[s] | perplexity 139.45
| epoch 2 | iter 341 / 1327 | time 382[s] | perplexity 173.59
| epoch 2 | iter 361 / 1327 | time 386[s] | perplexity 196.56
| epoch 2 | iter 381 / 1327 | time 390[s] | perplexity 152.65
| epoch 2 | iter 401 / 1327 | time 395[s] | perplexity 167.52
| epoch 2 | iter 421 / 1327 | time 399[s] | perplexity 153.49
| epoch 2 | iter 441 / 1327 | time 404[s] | perplexity 162.79
| epoch 2 | iter 461 / 1327 | time 408[s] | perplexity 158.78
| epoch 2 | iter 481 / 1327 | time 412[s] | perplexity 156.82
| epoch 2 | iter 501 / 1327 | time 417[s] | perplexity 168.73
| epoch 2 | iter 521 / 1327 | time 421[s] | perplexity 174.60
| epoch 2 | iter 541 / 1327 | time 426[s] | perplexity 175.13
| epoch 2 | iter 561 / 1327 | time 430[s] | perplexity 154.33
| epoch 2 | iter 581 / 1327 | time 435[s] | perplexity 138.94
| epoch 2 | iter 601 / 1327 | time 439[s] | perplexity 190.57
| epoch 2 | iter 621 / 1327 | time 444[s] | perplexity 181.79
| epoch 2 | iter 641 / 1327 | time 448[s] | perplexity 164.16
| epoch 2 | iter 661 / 1327 | time 452[s] | perplexity 154.69
| epoch 2 | iter 681 / 1327 | time 457[s] | perplexity 129.25
| epoch 2 | iter 701 / 1327 | time 461[s] | perplexity 149.68
| epoch 2 | iter 721 / 1327 | time 466[s] | perplexity 160.68
| epoch 2 | iter 741 / 1327 | time 470[s] | perplexity 132.86
| epoch 2 | iter 761 / 1327 | time 475[s] | perplexity 130.31
| epoch 2 | iter 781 / 1327 | time 479[s] | perplexity 135.17
| epoch 2 | iter 801 / 1327 | time 484[s] | perplexity 147.12
| epoch 2 | iter 821 / 1327 | time 488[s] | perplexity 143.79
| epoch 2 | iter 841 / 1327 | time 493[s] | perplexity 143.31
| epoch 2 | iter 861 / 1327 | time 498[s] | perplexity 144.79
| epoch 2 | iter 881 / 1327 | time 503[s] | perplexity 131.04
| epoch 2 | iter 901 / 1327 | time 508[s] | perplexity 165.02
| epoch 2 | iter 921 / 1327 | time 513[s] | perplexity 148.06
| epoch 2 | iter 941 / 1327 | time 518[s] | perplexity 153.83
| epoch 2 | iter 961 / 1327 | time 523[s] | perplexity 165.04
| epoch 2 | iter 981 / 1327 | time 528[s] | perplexity 153.31
| epoch 2 | iter 1001 / 1327 | time 533[s] | perplexity 132.15
| epoch 2 | iter 1021 / 1327 | time 538[s] | perplexity 156.56
| epoch 2 | iter 1041 / 1327 | time 543[s] | perplexity 141.93
| epoch 2 | iter 1061 / 1327 | time 548[s] | perplexity 128.13
| epoch 2 | iter 1081 / 1327 | time 554[s] | perplexity 110.03
| epoch 2 | iter 1101 / 1327 | time 559[s] | perplexity 119.79
| epoch 2 | iter 1121 / 1327 | time 563[s] | perplexity 152.99
| epoch 2 | iter 1141 / 1327 | time 568[s] | perplexity 141.46
| epoch 2 | iter 1161 / 1327 | time 572[s] | perplexity 133.02
| epoch 2 | iter 1181 / 1327 | time 577[s] | perplexity 133.30
| epoch 2 | iter 1201 / 1327 | time 582[s] | perplexity 112.68
| epoch 2 | iter 1221 / 1327 | time 587[s] | perplexity 109.20
| epoch 2 | iter 1241 / 1327 | time 591[s] | perplexity 130.53
| epoch 2 | iter 1261 / 1327 | time 596[s] | perplexity 124.27
| epoch 2 | iter 1281 / 1327 | time 600[s] | perplexity 122.96
| epoch 2 | iter 1301 / 1327 | time 605[s] | perplexity 157.10
| epoch 2 | iter 1321 / 1327 | time 609[s] | perplexity 153.77
| epoch 3 | iter 1 / 1327 | time 611[s] | perplexity 159.04
| epoch 3 | iter 21 / 1327 | time 615[s] | perplexity 144.00
| epoch 3 | iter 41 / 1327 | time 620[s] | perplexity 135.04
| epoch 3 | iter 61 / 1327 | time 625[s] | perplexity 126.73
| epoch 3 | iter 81 / 1327 | time 629[s] | perplexity 116.52
| epoch 3 | iter 101 / 1327 | time 634[s] | perplexity 105.99
| epoch 3 | iter 121 / 1327 | time 638[s] | perplexity 116.02
| epoch 3 | iter 141 / 1327 | time 643[s] | perplexity 126.64
| epoch 3 | iter 161 / 1327 | time 647[s] | perplexity 142.06
| epoch 3 | iter 181 / 1327 | time 652[s] | perplexity 148.62
| epoch 3 | iter 201 / 1327 | time 656[s] | perplexity 141.37
| epoch 3 | iter 221 / 1327 | time 660[s] | perplexity 140.41
| epoch 3 | iter 241 / 1327 | time 665[s] | perplexity 135.36
| epoch 3 | iter 261 / 1327 | time 670[s] | perplexity 139.09
| epoch 3 | iter 281 / 1327 | time 674[s] | perplexity 141.31
| epoch 3 | iter 301 / 1327 | time 679[s] | perplexity 123.63
| epoch 3 | iter 321 / 1327 | time 683[s] | perplexity 101.11
| epoch 3 | iter 341 / 1327 | time 687[s] | perplexity 123.65
| epoch 3 | iter 361 / 1327 | time 692[s] | perplexity 151.25
| epoch 3 | iter 381 / 1327 | time 696[s] | perplexity 114.11
| epoch 3 | iter 401 / 1327 | time 701[s] | perplexity 129.45
| epoch 3 | iter 421 / 1327 | time 705[s] | perplexity 113.48
| epoch 3 | iter 441 / 1327 | time 710[s] | perplexity 123.47
| epoch 3 | iter 461 / 1327 | time 714[s] | perplexity 119.60
| epoch 3 | iter 481 / 1327 | time 719[s] | perplexity 118.64
| epoch 3 | iter 501 / 1327 | time 723[s] | perplexity 128.42
| epoch 3 | iter 521 / 1327 | time 728[s] | perplexity 138.72
| epoch 3 | iter 541 / 1327 | time 732[s] | perplexity 135.22
| epoch 3 | iter 561 / 1327 | time 737[s] | perplexity 117.48
| epoch 3 | iter 581 / 1327 | time 742[s] | perplexity 105.71
| epoch 3 | iter 601 / 1327 | time 746[s] | perplexity 147.57
| epoch 3 | iter 621 / 1327 | time 751[s] | perplexity 141.62
| epoch 3 | iter 641 / 1327 | time 755[s] | perplexity 129.66
| epoch 3 | iter 661 / 1327 | time 759[s] | perplexity 120.43
| epoch 3 | iter 681 / 1327 | time 764[s] | perplexity 99.94
| epoch 3 | iter 701 / 1327 | time 768[s] | perplexity 118.26
| epoch 3 | iter 721 / 1327 | time 773[s] | perplexity 126.15
| epoch 3 | iter 741 / 1327 | time 778[s] | perplexity 106.56
| epoch 3 | iter 761 / 1327 | time 782[s] | perplexity 103.61
| epoch 3 | iter 781 / 1327 | time 787[s] | perplexity 103.21
| epoch 3 | iter 801 / 1327 | time 791[s] | perplexity 114.00
| epoch 3 | iter 821 / 1327 | time 796[s] | perplexity 115.18
| epoch 3 | iter 841 / 1327 | time 800[s] | perplexity 114.29
| epoch 3 | iter 861 / 1327 | time 805[s] | perplexity 118.70
| epoch 3 | iter 881 / 1327 | time 809[s] | perplexity 106.59
| epoch 3 | iter 901 / 1327 | time 814[s] | perplexity 130.60
| epoch 3 | iter 921 / 1327 | time 818[s] | perplexity 119.56
| epoch 3 | iter 941 / 1327 | time 822[s] | perplexity 126.76
| epoch 3 | iter 961 / 1327 | time 827[s] | perplexity 132.62
| epoch 3 | iter 981 / 1327 | time 831[s] | perplexity 122.90
| epoch 3 | iter 1001 / 1327 | time 836[s] | perplexity 109.57
| epoch 3 | iter 1021 / 1327 | time 842[s] | perplexity 128.46
| epoch 3 | iter 1041 / 1327 | time 848[s] | perplexity 118.44
| epoch 3 | iter 1061 / 1327 | time 853[s] | perplexity 102.22
| epoch 3 | iter 1081 / 1327 | time 857[s] | perplexity 87.85
| epoch 3 | iter 1101 / 1327 | time 864[s] | perplexity 95.52
| epoch 3 | iter 1121 / 1327 | time 869[s] | perplexity 120.50
| epoch 3 | iter 1141 / 1327 | time 874[s] | perplexity 114.41
| epoch 3 | iter 1161 / 1327 | time 880[s] | perplexity 107.33
| epoch 3 | iter 1181 / 1327 | time 885[s] | perplexity 110.84
| epoch 3 | iter 1201 / 1327 | time 890[s] | perplexity 93.96
| epoch 3 | iter 1221 / 1327 | time 895[s] | perplexity 88.58
| epoch 3 | iter 1241 / 1327 | time 900[s] | perplexity 105.31
| epoch 3 | iter 1261 / 1327 | time 904[s] | perplexity 105.53
| epoch 3 | iter 1281 / 1327 | time 909[s] | perplexity 100.72
| epoch 3 | iter 1301 / 1327 | time 914[s] | perplexity 130.49
| epoch 3 | iter 1321 / 1327 | time 919[s] | perplexity 127.11
| epoch 4 | iter 1 / 1327 | time 920[s] | perplexity 132.73
| epoch 4 | iter 21 / 1327 | time 925[s] | perplexity 121.07
| epoch 4 | iter 41 / 1327 | time 929[s] | perplexity 106.88
| epoch 4 | iter 61 / 1327 | time 934[s] | perplexity 106.00
| epoch 4 | iter 81 / 1327 | time 939[s] | perplexity 95.54
| epoch 4 | iter 101 / 1327 | time 944[s] | perplexity 86.23
| epoch 4 | iter 121 / 1327 | time 949[s] | perplexity 94.56
| epoch 4 | iter 141 / 1327 | time 953[s] | perplexity 103.08
| epoch 4 | iter 161 / 1327 | time 958[s] | perplexity 118.39
| epoch 4 | iter 181 / 1327 | time 963[s] | perplexity 127.76
| epoch 4 | iter 201 / 1327 | time 967[s] | perplexity 119.95
| epoch 4 | iter 221 / 1327 | time 971[s] | perplexity 121.68
| epoch 4 | iter 241 / 1327 | time 976[s] | perplexity 114.74
| epoch 4 | iter 261 / 1327 | time 981[s] | perplexity 114.72
| epoch 4 | iter 281 / 1327 | time 985[s] | perplexity 120.67
| epoch 4 | iter 301 / 1327 | time 990[s] | perplexity 103.55
| epoch 4 | iter 321 / 1327 | time 995[s] | perplexity 83.56
| epoch 4 | iter 341 / 1327 | time 999[s] | perplexity 100.04
| epoch 4 | iter 361 / 1327 | time 1004[s] | perplexity 127.70
| epoch 4 | iter 381 / 1327 | time 1009[s] | perplexity 96.72
| epoch 4 | iter 401 / 1327 | time 1013[s] | perplexity 109.85
| epoch 4 | iter 421 / 1327 | time 1018[s] | perplexity 94.10
| epoch 4 | iter 441 / 1327 | time 1022[s] | perplexity 102.51
| epoch 4 | iter 461 / 1327 | time 1027[s] | perplexity 99.95
| epoch 4 | iter 481 / 1327 | time 1032[s] | perplexity 101.84
| epoch 4 | iter 501 / 1327 | time 1036[s] | perplexity 108.18
| epoch 4 | iter 521 / 1327 | time 1041[s] | perplexity 117.74
| epoch 4 | iter 541 / 1327 | time 1045[s] | perplexity 111.75
| epoch 4 | iter 561 / 1327 | time 1050[s] | perplexity 101.40
| epoch 4 | iter 581 / 1327 | time 1055[s] | perplexity 89.76
| epoch 4 | iter 601 / 1327 | time 1060[s] | perplexity 126.00
| epoch 4 | iter 621 / 1327 | time 1064[s] | perplexity 120.89
| epoch 4 | iter 641 / 1327 | time 1069[s] | perplexity 109.96
| epoch 4 | iter 661 / 1327 | time 1074[s] | perplexity 102.95
| epoch 4 | iter 681 / 1327 | time 1079[s] | perplexity 84.85
| epoch 4 | iter 701 / 1327 | time 1084[s] | perplexity 101.62
| epoch 4 | iter 721 / 1327 | time 1089[s] | perplexity 107.78
| epoch 4 | iter 741 / 1327 | time 1094[s] | perplexity 95.20
| epoch 4 | iter 761 / 1327 | time 1098[s] | perplexity 88.72
| epoch 4 | iter 781 / 1327 | time 1103[s] | perplexity 87.57
| epoch 4 | iter 801 / 1327 | time 1108[s] | perplexity 97.64
| epoch 4 | iter 821 / 1327 | time 1112[s] | perplexity 102.00
| epoch 4 | iter 841 / 1327 | time 1116[s] | perplexity 98.01
| epoch 4 | iter 861 / 1327 | time 1121[s] | perplexity 103.25
| epoch 4 | iter 881 / 1327 | time 1125[s] | perplexity 92.42
| epoch 4 | iter 901 / 1327 | time 1130[s] | perplexity 114.35
| epoch 4 | iter 921 / 1327 | time 1134[s] | perplexity 104.32
| epoch 4 | iter 941 / 1327 | time 1139[s] | perplexity 112.19
| epoch 4 | iter 961 / 1327 | time 1143[s] | perplexity 112.16
| epoch 4 | iter 981 / 1327 | time 1148[s] | perplexity 106.37
| epoch 4 | iter 1001 / 1327 | time 1152[s] | perplexity 97.07
| epoch 4 | iter 1021 / 1327 | time 1157[s] | perplexity 112.89
| epoch 4 | iter 1041 / 1327 | time 1161[s] | perplexity 103.65
| epoch 4 | iter 1061 / 1327 | time 1166[s] | perplexity 88.32
| epoch 4 | iter 1081 / 1327 | time 1170[s] | perplexity 77.75
| epoch 4 | iter 1101 / 1327 | time 1175[s] | perplexity 79.79
| epoch 4 | iter 1121 / 1327 | time 1179[s] | perplexity 102.84
| epoch 4 | iter 1141 / 1327 | time 1183[s] | perplexity 99.15
| epoch 4 | iter 1161 / 1327 | time 1188[s] | perplexity 91.90
| epoch 4 | iter 1181 / 1327 | time 1192[s] | perplexity 95.42
| epoch 4 | iter 1201 / 1327 | time 1197[s] | perplexity 83.07
| epoch 4 | iter 1221 / 1327 | time 1201[s] | perplexity 76.08
| epoch 4 | iter 1241 / 1327 | time 1206[s] | perplexity 91.78
| epoch 4 | iter 1261 / 1327 | time 1210[s] | perplexity 94.18
| epoch 4 | iter 1281 / 1327 | time 1215[s] | perplexity 88.94
| epoch 4 | iter 1301 / 1327 | time 1219[s] | perplexity 111.88
| epoch 4 | iter 1321 / 1327 | time 1223[s] | perplexity 110.80
evaluating perplexity ... 234 / 235
test perplexity: 135.81750561235523
```
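The perplexity reported above is the exponential of the average cross-entropy loss. The following minimal sketch assumes the model outputs a probability distribution over the vocabulary for each prediction. `perplexity` is a hypothetical helper computing $\exp\left(-\frac{1}{N}\sum_n \log p(t_n)\right)$, where $t_n$ is the correct word.

```python
import numpy as np

def perplexity(probs, targets):
    """probs: (N, V) predicted distributions; targets: (N,) ids of the correct words."""
    p_correct = probs[np.arange(len(targets)), targets]
    cross_entropy = -np.mean(np.log(p_correct))   # average loss per prediction
    return np.exp(cross_entropy)

V = 10000
N = 8
targets = np.arange(N)

uniform = np.full((N, V), 1.0 / V)        # an untrained model: every word equally likely
print(perplexity(uniform, targets))       # about 10000

confident = np.full((N, V), 0.5 / (V - 1))
confident[np.arange(N), targets] = 0.5    # the correct word gets probability 0.5
print(perplexity(confident, targets))     # about 2
```

A uniform prediction over $V$ words has perplexity exactly $V$. This is why the very first logged value (10001.53) sits near 10000, the size of the PTB vocabulary. Perplexity can be read as the effective number of candidate words the model is still choosing between.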
An improved model, `BetterRnnlm`, is trained with the script below. Compared with the previous run, it uses 650-dimensional word vectors and hidden states, dropout of 0.5 and up to 40 epochs. It also checks validation perplexity after every epoch.

```python
# coding: utf-8
import sys
sys.path.append('..')
from common import config
# To run on a GPU, uncomment the line below (requires cupy)
# ==============================================
# config.GPU = True
# ==============================================
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity, to_gpu
from dataset import ptb
from better_rnnlm import BetterRnnlm

# hyperparameters
batch_size = 20
wordvec_size = 650
hidden_size = 650
time_size = 35
lr = 20.0
max_epoch = 40
max_grad = 0.25
dropout = 0.5

# load the training, validation and test data
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_val, _, _ = ptb.load_data('val')
corpus_test, _, _ = ptb.load_data('test')

if config.GPU:
    corpus = to_gpu(corpus)
    corpus_val = to_gpu(corpus_val)
    corpus_test = to_gpu(corpus_test)

vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

model = BetterRnnlm(vocab_size, wordvec_size, hidden_size, dropout)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

best_ppl = float('inf')
for epoch in range(max_epoch):
    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size,
                time_size=time_size, max_grad=max_grad)

    # evaluate on the validation data after every epoch
    model.reset_state()
    ppl = eval_perplexity(model, corpus_val)
    print('valid perplexity: ', ppl)

    if best_ppl > ppl:
        # improvement: remember the score and save these parameters
        best_ppl = ppl
        model.save_params()
    else:
        # no improvement: lower the learning rate
        lr /= 4.0
        optimizer.lr = lr

    model.reset_state()
    print('-' * 50)

# evaluate on the test data
# (this uses the final parameters, not necessarily the best saved ones)
model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)
```
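The loop above anneals the learning rate based on validation perplexity. Below, the same rule is isolated as a pure function so its behaviour is easy to check. This is a sketch: the name `next_lr` is hypothetical, and the perplexity sequence is made up for illustration.

```python
def next_lr(lr, best_ppl, ppl, decay=4.0):
    """Return (new_lr, new_best_ppl, improved) after one epoch's validation perplexity."""
    if ppl < best_ppl:
        return lr, ppl, True            # improvement: keep lr (the script saves params here)
    return lr / decay, best_ppl, False  # no improvement: divide lr by 4

lr, best = 20.0, float('inf')
history = []
for ppl in [200.0, 150.0, 160.0, 140.0]:   # hypothetical validation perplexities
    lr, best, improved = next_lr(lr, best, ppl)
    history.append(lr)
print(history)  # [20.0, 20.0, 5.0, 5.0]
```

The learning rate only drops when an epoch fails to beat the best score so far. A single bad epoch permanently lowers `lr`, and later improvements keep it at that lower value.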
```
| epoch 1 | iter 1 / 1327 | time 8[s] | perplexity 10000.15
| epoch 1 | iter 21 / 1327 | time 93[s] | perplexity 4234.82
| epoch 1 | iter 41 / 1327 | time 178[s] | perplexity 1896.71
| epoch 1 | iter 61 / 1327 | time 272[s] | perplexity 1280.45
| epoch 1 | iter 81 / 1327 | time 332[s] | perplexity 1023.16
| epoch 1 | iter 101 / 1327 | time 394[s] | perplexity 831.45
| epoch 1 | iter 121 / 1327 | time 458[s] | perplexity 807.89
| epoch 1 | iter 141 / 1327 | time 522[s] | perplexity 720.43
| epoch 1 | iter 161 / 1327 | time 587[s] | perplexity 689.12
| epoch 1 | iter 181 / 1327 | time 651[s] | perplexity 679.70
| epoch 1 | iter 201 / 1327 | time 716[s] | perplexity 602.49
| epoch 1 | iter 221 / 1327 | time 781[s] | perplexity 567.63
| epoch 1 | iter 241 / 1327 | time 844[s] | perplexity 528.17
| epoch 1 | iter 261 / 1327 | time 909[s] | perplexity 538.42
| epoch 1 | iter 281 / 1327 | time 975[s] | perplexity 521.44
| epoch 1 | iter 301 / 1327 | time 1043[s] | perplexity 449.47
| epoch 1 | iter 321 / 1327 | time 1112[s] | perplexity 399.01
| epoch 1 | iter 341 / 1327 | time 1180[s] | perplexity 452.80
| epoch 1 | iter 361 / 1327 | time 1247[s] | perplexity 460.58
| epoch 1 | iter 381 / 1327 | time 1312[s] | perplexity 383.68
| epoch 1 | iter 401 / 1327 | time 1379[s] | perplexity 404.59
| epoch 1 | iter 421 / 1327 | time 1445[s] | perplexity 394.29
| epoch 1 | iter 441 / 1327 | time 1509[s] | perplexity 375.40
| epoch 1 | iter 461 / 1327 | time 1574[s] | perplexity 373.20
| epoch 1 | iter 481 / 1327 | time 1639[s] | perplexity 344.29
| epoch 1 | iter 501 / 1327 | time 1682[s] | perplexity 355.09
| epoch 1 | iter 521 / 1327 | time 1722[s] | perplexity 345.81
| epoch 1 | iter 541 / 1327 | time 1761[s] | perplexity 364.67
| epoch 1 | iter 561 / 1327 | time 1798[s] | perplexity 323.76
| epoch 1 | iter 581 / 1327 | time 1836[s] | perplexity 293.38
| epoch 1 | iter 601 / 1327 | time 1873[s] | perplexity 377.00
| epoch 1 | iter 621 / 1327 | time 1911[s] | perplexity 346.04
| epoch 1 | iter 641 / 1327 | time 1948[s] | perplexity 315.85
| epoch 1 | iter 661 / 1327 | time 1985[s] | perplexity 307.43
| epoch 1 | iter 681 / 1327 | time 2023[s] | perplexity 257.17
| epoch 1 | iter 701 / 1327 | time 2062[s] | perplexity 281.79
| epoch 1 | iter 721 / 1327 | time 2101[s] | perplexity 289.17
| epoch 1 | iter 741 / 1327 | time 2140[s] | perplexity 249.45
| epoch 1 | iter 761 / 1327 | time 2179[s] | perplexity 258.53
| epoch 1 | iter 781 / 1327 | time 2216[s] | perplexity 244.86
| epoch 1 | iter 801 / 1327 | time 2254[s] | perplexity 269.02
| epoch 1 | iter 821 / 1327 | time 2291[s] | perplexity 248.95
| epoch 1 | iter 841 / 1327 | time 2328[s] | perplexity 254.58
| epoch 1 | iter 861 / 1327 | time 2367[s] | perplexity 249.46
| epoch 1 | iter 881 / 1327 | time 2405[s] | perplexity 230.01
| epoch 1 | iter 901 / 1327 | time 2442[s] | perplexity 280.32
| epoch 1 | iter 921 / 1327 | time 2480[s] | perplexity 253.94
| epoch 1 | iter 941 / 1327 | time 2516[s] | perplexity 257.12
| epoch 1 | iter 961 / 1327 | time 2553[s] | perplexity 275.56
| epoch 1 | iter 981 / 1327 | time 2590[s] | perplexity 256.14
| epoch 1 | iter 1001 / 1327 | time 2627[s] | perplexity 215.38
| epoch 1 | iter 1021 / 1327 | time 2664[s] | perplexity 251.41
| epoch 1 | iter 1041 / 1327 | time 2702[s] | perplexity 228.57
| epoch 1 | iter 1061 / 1327 | time 2739[s] | perplexity 218.85
| epoch 1 | iter 1081 / 1327 | time 2776[s] | perplexity 188.14
| epoch 1 | iter 1101 / 1327 | time 2813[s] | perplexity 215.89
| epoch 1 | iter 1121 / 1327 | time 2854[s] | perplexity 255.60
| epoch 1 | iter 1141 / 1327 | time 2897[s] | perplexity 229.44
| epoch 1 | iter 1161 / 1327 | time 2941[s] | perplexity 221.68
| epoch 1 | iter 1181 / 1327 | time 2981[s] | perplexity 210.42
| epoch 1 | iter 1201 / 1327 | time 3020[s] | perplexity 181.10
| epoch 1 | iter 1221 / 1327 | time 3065[s] | perplexity 177.75
| epoch 1 | iter 1241 / 1327 | time 3110[s] | perplexity 209.31
| epoch 1 | iter 1261 / 1327 | time 3154[s] | perplexity 191.01
| epoch 1 | iter 1281 / 1327 | time 3203[s] | perplexity 199.19
| epoch 1 | iter 1301 / 1327 | time 3251[s] | perplexity 246.91
| epoch 1 | iter 1321 / 1327 | time 3299[s] | perplexity 234.73
evaluating perplexity ... 209 / 210
valid perplexity: 196.80691468962846
--------------------------------------------------
| epoch 2 | iter 1 / 1327 | time 2[s] | perplexity 291.16
| epoch 2 | iter 21 / 1327 | time 46[s] | perplexity 230.42
| epoch 2 | iter 41 / 1327 | time 90[s] | perplexity 210.98
| epoch 2 | iter 61 / 1327 | time 137[s] | perplexity 195.20
| epoch 2 | iter 81 / 1327 | time 182[s] | perplexity 179.33
| epoch 2 | iter 101 / 1327 | time 234[s] | perplexity 168.96
| epoch 2 | iter 121 / 1327 | time 318[s] | perplexity 179.30
| epoch 2 | iter 141 / 1327 | time 366[s] | perplexity 198.83
| epoch 2 | iter 161 / 1327 | time 411[s] | perplexity 215.86
| epoch 2 | iter 181 / 1327 | time 456[s] | perplexity 222.62
| epoch 2 | iter 201 / 1327 | time 501[s] | perplexity 206.98
| epoch 2 | iter 221 / 1327 | time 547[s] | perplexity 204.71
| epoch 2 | iter 241 / 1327 | time 625[s] | perplexity 197.85
| epoch 2 | iter 261 / 1327 | time 702[s] | perplexity 213.96
| epoch 2 | iter 281 / 1327 | time 755[s] | perplexity 205.79
| epoch 2 | iter 301 / 1327 | time 808[s] | perplexity 186.50
| epoch 2 | iter 321 / 1327 | time 859[s] | perplexity 152.48
| epoch 2 | iter 341 / 1327 | time 907[s] | perplexity 198.20
| epoch 2 | iter 361 / 1327 | time 957[s] | perplexity 215.05
| epoch 2 | iter 381 / 1327 | time 1006[s] | perplexity 170.60
| epoch 2 | iter 401 / 1327 | time 1067[s] | perplexity 193.21
| epoch 2 | iter 421 / 1327 | time 1135[s] | perplexity 176.80
| epoch 2 | iter 441 / 1327 | time 1201[s] | perplexity 180.86
| epoch 2 | iter 461 / 1327 | time 1267[s] | perplexity 182.34
| epoch 2 | iter 481 / 1327 | time 1332[s] | perplexity 174.90
| epoch 2 | iter 501 / 1327 | time 1381[s] | perplexity 192.26
| epoch 2 | iter 521 / 1327 | time 1449[s] | perplexity 189.70
| epoch 2 | iter 541 / 1327 | time 1541[s] | perplexity 202.00
| epoch 2 | iter 561 / 1327 | time 1641[s] | perplexity 171.94
| epoch 2 | iter 581 / 1327 | time 1726[s] | perplexity 158.15
| epoch 2 | iter 601 / 1327 | time 1818[s] | perplexity 215.57
| epoch 2 | iter 621 / 1327 | time 1908[s] | perplexity 201.54
| epoch 2 | iter 641 / 1327 | time 2002[s] | perplexity 185.20
| epoch 2 | iter 661 / 1327 | time 2090[s] | perplexity 174.48
| epoch 2 | iter 681 / 1327 | time 2174[s] | perplexity 146.61
| epoch 2 | iter 701 / 1327 | time 2228[s] | perplexity 171.08
| epoch 2 | iter 721 / 1327 | time 2284[s] | perplexity 176.13
| epoch 2 | iter 741 / 1327 | time 2349[s] | perplexity 150.93
| epoch 2 | iter 761 / 1327 | time 2402[s] | perplexity 149.54
| epoch 2 | iter 781 / 1327 | time 2482[s] | perplexity 149.39
| epoch 2 | iter 801 / 1327 | time 2536[s] | perplexity 168.56
| epoch 2 | iter 821 / 1327 | time 2626[s] | perplexity 161.91
| epoch 2 | iter 841 / 1327 | time 2680[s] | perplexity 165.24
| epoch 2 | iter 861 / 1327 | time 2735[s] | perplexity 161.05
| epoch 2 | iter 881 / 1327 | time 2787[s] | perplexity 149.73
| epoch 2 | iter 901 / 1327 | time 2843[s] | perplexity 188.35
| epoch 2 | iter 921 / 1327 | time 2922[s] | perplexity 165.84
| epoch 2 | iter 941 / 1327 | time 2982[s] | perplexity 168.55
| epoch 2 | iter 961 / 1327 | time 3047[s] | perplexity 186.29
| epoch 2 | iter 981 / 1327 | time 3108[s] | perplexity 174.66
| epoch 2 | iter 1001 / 1327 | time 3162[s] | perplexity 149.82
| epoch 2 | iter 1021 / 1327 | time 3223[s] | perplexity 174.38
| epoch 2 | iter 1041 / 1327 | time 3286[s] | perplexity 157.34
| epoch 2 | iter 1061 / 1327 | time 3348[s] | perplexity 148.87
| epoch 2 | iter 1081 / 1327 | time 3407[s] | perplexity 123.84
| epoch 2 | iter 1101 / 1327 | time 3458[s] | perplexity 137.86
| epoch 2 | iter 1121 / 1327 | time 3511[s] | perplexity 172.72
| epoch 2 | iter 1141 / 1327 | time 3566[s] | perplexity 164.43
| epoch 2 | iter 1161 / 1327 | time 3616[s] | perplexity 147.04
| epoch 2 | iter 1181 / 1327 | time 3666[s] | perplexity 148.45
| epoch 2 | iter 1201 / 1327 | time 3715[s] | perplexity 127.46
| epoch 2 | iter 1221 / 1327 | time 3769[s] | perplexity 125.83
| epoch 2 | iter 1241 / 1327 | time 3822[s] | perplexity 146.89
| epoch 2 | iter 1261 / 1327 | time 3871[s] | perplexity 137.78
| epoch 2 | iter 1281 / 1327 | time 3923[s] | perplexity 140.56
| epoch 2 | iter 1301 / 1327 | time 3973[s] | perplexity 178.40
| epoch 2 | iter 1321 / 1327 | time 4023[s] | perplexity 171.81
evaluating perplexity ... 209 / 210
valid perplexity: 145.6046890226078
--------------------------------------------------
| epoch 3 | iter 1 / 1327 | time 6[s] | perplexity 222.54
| epoch 3 | iter 21 / 1327 | time 124[s] | perplexity 161.39
| epoch 3 | iter 41 / 1327 | time 243[s] | perplexity 152.21
| epoch 3 | iter 61 / 1327 | time 295[s] | perplexity 142.97
| epoch 3 | iter 81 / 1327 | time 344[s] | perplexity 128.21
| epoch 3 | iter 101 / 1327 | time 393[s] | perplexity 122.36
| epoch 3 | iter 121 / 1327 | time 441[s] | perplexity 133.44
| epoch 3 | iter 141 / 1327 | time 492[s] | perplexity 146.84
| epoch 3 | iter 161 / 1327 | time 541[s] | perplexity 162.45
| epoch 3 | iter 181 / 1327 | time 596[s] | perplexity 169.60
| epoch 3 | iter 201 / 1327 | time 647[s] | perplexity 157.57
| epoch 3 | iter 221 / 1327 | time 695[s] | perplexity 155.70
| epoch 3 | iter 241 / 1327 | time 742[s] | perplexity 152.59
| epoch 3 | iter 261 / 1327 | time 797[s] | perplexity 163.07
| epoch 3 | iter 281 / 1327 | time 846[s] | perplexity 156.43
| epoch 3 | iter 301 / 1327 | time 906[s] | perplexity 138.56
| epoch 3 | iter 321 / 1327 | time 970[s] | perplexity 112.04
| epoch 3 | iter 341 / 1327 | time 1036[s] | perplexity 151.57
| epoch 3 | iter 361 / 1327 | time 1098[s] | perplexity 165.21
| epoch 3 | iter 381 / 1327 | time 1159[s] | perplexity 131.16
| epoch 3 | iter 401 / 1327 | time 1215[s] | perplexity 149.15
| epoch 3 | iter 421 / 1327 | time 1264[s] | perplexity 130.95
| epoch 3 | iter 441 / 1327 | time 1322[s] | perplexity 140.28
| epoch 3 | iter 461 / 1327 | time 1375[s] | perplexity 138.78
| epoch 3 | iter 481 / 1327 | time 1429[s] | perplexity 134.81
| epoch 3 | iter 501 / 1327 | time 1480[s] | perplexity 149.98
| epoch 3 | iter 521 / 1327 | time 1534[s] | perplexity 152.35
| epoch 3 | iter 541 / 1327 | time 1583[s] | perplexity 157.25
| epoch 3 | iter 561 / 1327 | time 1633[s] | perplexity 134.19
| epoch 3 | iter 581 / 1327 | time 1682[s] | perplexity 123.06
| epoch 3 | iter 601 / 1327 | time 1735[s] | perplexity 169.27
| epoch 3 | iter 621 / 1327 | time 1788[s] | perplexity 160.02
| epoch 3 | iter 641 / 1327 | time 1838[s] | perplexity 146.95
| epoch 3 | iter 661 / 1327 | time 1888[s] | perplexity 136.56
| epoch 3 | iter 681 / 1327 | time 1939[s] | perplexity 117.16
| epoch 3 | iter 701 / 1327 | time 1991[s] | perplexity 136.33
| epoch 3 | iter 721 / 1327 | time 2043[s] | perplexity 139.97
| epoch 3 | iter 741 / 1327 | time 2094[s] | perplexity 121.00
| epoch 3 | iter 761 / 1327 | time 2154[s] | perplexity 115.03
| epoch 3 | iter 781 / 1327 | time 2216[s] | perplexity 122.38
| epoch 3 | iter 801 / 1327 | time 2270[s] | perplexity 135.23
| epoch 3 | iter 821 / 1327 | time 2326[s] | perplexity 133.33
| epoch 3 | iter 841 / 1327 | time 2405[s] | perplexity 134.35
| epoch 3 | iter 861 / 1327 | time 2457[s] | perplexity 130.02
```