Sequence-to-Sequence (Seq2Seq) モデルは、系列を入力として系列を出力するモデルです。
入力系列をRNNで固定長のベクトルに変換(= Encode)し、そのベクトルを用いて系列を出力(= Decode)することから、Encoder-Decoder モデルとも呼ばれます。
RNNの代わりにLSTMやGRUでも可能です。
機械翻訳のほか、文書要約や対話生成にも使われます。
今回は機械翻訳を例にとって解説していきます。
# このページの情報で、特定のバージョンじゃないとだめと書いてあった。
# https://github.com/buildout/buildout.wheel/issues/18
!pip3 install wheel==0.34.1
Collecting wheel==0.34.1
Downloading https://files.pythonhosted.org/packages/81/44/db78754a73d9a88c5bd1bb692b40004410970e88aa0c5dff20b57f231505/wheel-0.34.1-py2.py3-none-any.whl
ERROR: tensorflow 2.5.0 has requirement wheel~=0.35, but you'll have wheel 0.34.1 which is incompatible.
Installing collected packages: wheel
Found existing installation: wheel 0.36.2
Uninstalling wheel-0.36.2:
Successfully uninstalled wheel-0.36.2
Successfully installed wheel-0.34.1
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.whl torchvision
import torch
print(torch.__version__)
print(torch.cuda.is_available())
ERROR: HTTP error 403 while getting http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl ERROR: Could not install requirement torch==0.4.0 from http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl because of error 403 Client Error: Forbidden for url: http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl ERROR: Could not install requirement torch==0.4.0 from http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl because of HTTP error 403 Client Error: Forbidden for url: http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl for URL http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl 1.9.0+cu102 True
! wget https://www.dropbox.com/s/9narw5x4uizmehh/utils.py
! mkdir images data
# data取得
! wget https://www.dropbox.com/s/o4kyc52a8we25wy/dev.en -P data/
! wget https://www.dropbox.com/s/kdgskm5hzg6znuc/dev.ja -P data/
! wget https://www.dropbox.com/s/gyyx4gohv9v65uh/test.en -P data/
! wget https://www.dropbox.com/s/hotxwbgoe2n013k/test.ja -P data/
! wget https://www.dropbox.com/s/5lsftkmb20ay9e1/train.en -P data/
! wget https://www.dropbox.com/s/ak53qirssci6f1j/train.ja -P data/
--2021-07-18 06:34:36-- https://www.dropbox.com/s/9narw5x4uizmehh/utils.py Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112 Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/9narw5x4uizmehh/utils.py [following] --2021-07-18 06:34:36-- https://www.dropbox.com/s/raw/9narw5x4uizmehh/utils.py Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com/cd/0/inline/BSgUmGGCHiSqU-Aw1pqfKOzPawKACdyROQJrqEsxM4-f2H4Gzrz7YDwXXaAliCXeVNoddmtHd5QRGb2HA7V4GS8dHSdUao3EEMxm_3A-SsbUgwbId4NORcApbJpJ8upvnIsWwa2PJRnNoZumkX1ssAF-/file# [following] --2021-07-18 06:34:36-- https://uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com/cd/0/inline/BSgUmGGCHiSqU-Aw1pqfKOzPawKACdyROQJrqEsxM4-f2H4Gzrz7YDwXXaAliCXeVNoddmtHd5QRGb2HA7V4GS8dHSdUao3EEMxm_3A-SsbUgwbId4NORcApbJpJ8upvnIsWwa2PJRnNoZumkX1ssAF-/file Resolving uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com (uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f Connecting to uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com (uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com)|162.125.1.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 949 [text/plain] Saving to: ‘utils.py’ utils.py 100%[===================>] 949 --.-KB/s in 0s 2021-07-18 06:34:36 (107 MB/s) - ‘utils.py’ saved [949/949] --2021-07-18 06:34:36-- https://www.dropbox.com/s/o4kyc52a8we25wy/dev.en Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112 Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/o4kyc52a8we25wy/dev.en [following] --2021-07-18 06:34:37-- https://www.dropbox.com/s/raw/o4kyc52a8we25wy/dev.en Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com/cd/0/inline/BSiwAsgw6yNys2vumtK-8EobSB5qkLzkuaxKBAcWUIop_KIUtGQZJVrBojvaz8LTWNj7MiLuBJxzMXg_6tQCs8KOTuBQXIi7ulv0hXbSwh2Vn_hSEBn_XmWBfK6ZBnrRhBkhbY0EyhNAEz3UlKPFwuRL/file# [following] --2021-07-18 06:34:37-- https://ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com/cd/0/inline/BSiwAsgw6yNys2vumtK-8EobSB5qkLzkuaxKBAcWUIop_KIUtGQZJVrBojvaz8LTWNj7MiLuBJxzMXg_6tQCs8KOTuBQXIi7ulv0hXbSwh2Vn_hSEBn_XmWBfK6ZBnrRhBkhbY0EyhNAEz3UlKPFwuRL/file Resolving ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com (ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6031:15::a27d:510f Connecting to ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com (ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com)|162.125.1.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 17054 (17K) [text/plain] Saving to: ‘data/dev.en’ dev.en 100%[===================>] 16.65K --.-KB/s in 0s 2021-07-18 06:34:37 (52.7 MB/s) - ‘data/dev.en’ saved [17054/17054] --2021-07-18 06:34:37-- https://www.dropbox.com/s/kdgskm5hzg6znuc/dev.ja Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112 Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/kdgskm5hzg6znuc/dev.ja [following] --2021-07-18 06:34:37-- https://www.dropbox.com/s/raw/kdgskm5hzg6znuc/dev.ja Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com/cd/0/inline/BSjTOog1jPfDXFmgmhf_rdryp5iNiwzfoV9hSJFKtlrfA3Z81FYzrjoaSSqzvLXjAPGVzSVBiLyW_8OikG6gE37sIrowFV-JBT3ve4tbkc_U8n94BJSGQReAq7kafyd3qw-OUeh5JP1ZEKgt_96CdCQb/file# [following] --2021-07-18 06:34:38-- https://ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com/cd/0/inline/BSjTOog1jPfDXFmgmhf_rdryp5iNiwzfoV9hSJFKtlrfA3Z81FYzrjoaSSqzvLXjAPGVzSVBiLyW_8OikG6gE37sIrowFV-JBT3ve4tbkc_U8n94BJSGQReAq7kafyd3qw-OUeh5JP1ZEKgt_96CdCQb/file Resolving ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com (ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f Connecting to ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com (ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com)|162.125.1.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 27781 (27K) [text/plain] Saving to: ‘data/dev.ja’ dev.ja 100%[===================>] 27.13K --.-KB/s in 0.02s 2021-07-18 06:34:38 (1.48 MB/s) - ‘data/dev.ja’ saved [27781/27781] --2021-07-18 06:34:38-- https://www.dropbox.com/s/gyyx4gohv9v65uh/test.en Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112 Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/gyyx4gohv9v65uh/test.en [following] --2021-07-18 06:34:38-- https://www.dropbox.com/s/raw/gyyx4gohv9v65uh/test.en Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com/cd/0/inline/BSiuGhflXLPPjGibDJfLwUmgmpl547WyWBIaaJUuma-2AwZQ9T4Ds3IRYWyXxXWO1n01MQTthDyWKwAbtzi_q6n6vug-T8cIN_VaxvaHhvbPUJkLWVSGWdSSJCYwfzvLPm7pj1fRkYGh2mPyIbpQpH-m/file# [following] --2021-07-18 06:34:38-- https://uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com/cd/0/inline/BSiuGhflXLPPjGibDJfLwUmgmpl547WyWBIaaJUuma-2AwZQ9T4Ds3IRYWyXxXWO1n01MQTthDyWKwAbtzi_q6n6vug-T8cIN_VaxvaHhvbPUJkLWVSGWdSSJCYwfzvLPm7pj1fRkYGh2mPyIbpQpH-m/file Resolving uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com (uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6031:15::a27d:510f Connecting to uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com (uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com)|162.125.1.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 17301 (17K) [text/plain] Saving to: ‘data/test.en’ test.en 100%[===================>] 16.90K --.-KB/s in 0.002s 2021-07-18 06:34:39 (10.9 MB/s) - ‘data/test.en’ saved [17301/17301] --2021-07-18 06:34:39-- https://www.dropbox.com/s/hotxwbgoe2n013k/test.ja Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112 Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/hotxwbgoe2n013k/test.ja [following] --2021-07-18 06:34:40-- https://www.dropbox.com/s/raw/hotxwbgoe2n013k/test.ja Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com/cd/0/inline/BSgRc-0Hj7l_EK3_qOouZd86iiU710RlBc0VQjdSxB6PCfvYWaqv9XTbSd8LaxNRJE5qksADXFKoRUKxoeImMOldFqyCqyoMv4AWrgwsIBEIMZFobTY_YP6lu1fu44bkyxrRWz80kwF2-eEgHbQlz08s/file# [following] --2021-07-18 06:34:40-- https://uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com/cd/0/inline/BSgRc-0Hj7l_EK3_qOouZd86iiU710RlBc0VQjdSxB6PCfvYWaqv9XTbSd8LaxNRJE5qksADXFKoRUKxoeImMOldFqyCqyoMv4AWrgwsIBEIMZFobTY_YP6lu1fu44bkyxrRWz80kwF2-eEgHbQlz08s/file Resolving uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com (uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f Connecting to uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com (uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com)|162.125.1.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 27793 (27K) [text/plain] Saving to: ‘data/test.ja’ test.ja 100%[===================>] 27.14K --.-KB/s in 0.02s 2021-07-18 06:34:40 (1.67 MB/s) - ‘data/test.ja’ saved [27793/27793] --2021-07-18 06:34:40-- https://www.dropbox.com/s/5lsftkmb20ay9e1/train.en Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112 Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/5lsftkmb20ay9e1/train.en [following] --2021-07-18 06:34:40-- https://www.dropbox.com/s/raw/5lsftkmb20ay9e1/train.en Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com/cd/0/inline/BSjkSPW1zz5m31_fTvKaS5-65LtHWcp93afPXPX2Ez78RpFhVnw8WNq8EhlDRReZ9h8JNdAdgL0snDFXldEwqqe_L05AXuAOLcWuZek9Tlh7ajuCtSzKLOZwmzhpWDyhVX_zS9tcIyq2U_SAo2loDhDy/file# [following] --2021-07-18 06:34:41-- https://ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com/cd/0/inline/BSjkSPW1zz5m31_fTvKaS5-65LtHWcp93afPXPX2Ez78RpFhVnw8WNq8EhlDRReZ9h8JNdAdgL0snDFXldEwqqe_L05AXuAOLcWuZek9Tlh7ajuCtSzKLOZwmzhpWDyhVX_zS9tcIyq2U_SAo2loDhDy/file Resolving ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com (ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6031:15::a27d:510f Connecting to ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com (ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com)|162.125.1.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 1701356 (1.6M) [text/plain] Saving to: ‘data/train.en’ train.en 100%[===================>] 1.62M --.-KB/s in 0.1s 2021-07-18 06:34:41 (13.6 MB/s) - ‘data/train.en’ saved [1701356/1701356] --2021-07-18 06:34:41-- https://www.dropbox.com/s/ak53qirssci6f1j/train.ja Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112 Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/ak53qirssci6f1j/train.ja [following] --2021-07-18 06:34:41-- https://www.dropbox.com/s/raw/ak53qirssci6f1j/train.ja Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com/cd/0/inline/BSgHUx4UNd_T9p8lYt2u9EB0Xyf5HHEGfnn3hfLGNhtb3bNM7pZtQwlYVRgHnBsx-lk_k0gV85uYgwlBDU2_06mQz43BzAEDOnD_CYw4XdlsMAfI_7FKWE7MXIRNRK6v-O1GdySSK2J9acJLT53yduod/file# [following] --2021-07-18 06:34:41-- https://uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com/cd/0/inline/BSgHUx4UNd_T9p8lYt2u9EB0Xyf5HHEGfnn3hfLGNhtb3bNM7pZtQwlYVRgHnBsx-lk_k0gV85uYgwlBDU2_06mQz43BzAEDOnD_CYw4XdlsMAfI_7FKWE7MXIRNRK6v-O1GdySSK2J9acJLT53yduod/file Resolving uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com (uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f Connecting to uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com (uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com)|162.125.1.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 2784447 (2.7M) [text/plain] Saving to: ‘data/train.ja’ train.ja 100%[===================>] 2.66M 6.84MB/s in 0.4s 2021-07-18 06:34:42 (6.84 MB/s) - ‘data/train.ja’ saved [2784447/2784447]
! ls data
dev.en dev.ja test.en test.ja train.en train.ja
import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from nltk import bleu_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from utils import Vocab
# デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1)
random_state = 42
print(torch.__version__)
1.9.0+cu102
英語-日本語の対訳コーパスである、Tanaka Corpus ( http://www.edrdg.org/wiki/index.php/Tanaka_Corpus )を使います。
今回はそのうちの一部分を取り出したsmall_parallel_enja: 50k En/Ja Parallel Corpus for Testing SMT Methods ( https://github.com/odashi/small_parallel_enja )を使用します。
train.enとtrain.jaの中身を見てみましょう。
! head -10 data/train.en
i can 't tell who will arrive first . many animals have been destroyed by men . i 'm in the tennis club . emi looks happy . please bear this fact in mind . she takes care of my children . we want to be international . you ought not to break your promise . when you cross the street , watch out for cars . i have nothing to live for .
! head -10 ./data/train.ja
誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。 多く の 動物 が 人間 に よ っ て 滅ぼ さ れ た 。 私 は テニス 部員 で す 。 エミ は 幸せ そう に 見え ま す 。 この 事実 を 心 に 留め て お い て 下さ い 。 彼女 は 私 たち の 世話 を し て くれ る 。 私 達 は 国際 人 に な り た い と 思 い ま す 。 約束 を 破 る べ き で は あ り ま せ ん 。 道路 を 横切 る とき は 車 に 注意 し なさ い 。 私 に は 生き 甲斐 が な い 。
それぞれの文章が英語-日本語で対応しているのがわかります。
def load_data(file_path):
# テキストファイルからデータを読み込むメソッド
data = []
for line in open(file_path, encoding='utf-8'):
words = line.strip().split() # スペースで単語を分割
data.append(words)
return data
train_X = load_data('./data/train.en')
train_Y = load_data('./data/train.ja')
# 訓練データと検証データに分割
train_X, valid_X, train_Y, valid_Y = train_test_split(train_X, train_Y, test_size=0.2, random_state=random_state)
この時点で入力と教師データは以下のようになっています
print('train data', train_X[0])
print('valid data', valid_X[0])
train data ['where', 'shall', 'we', 'eat', 'tonight', '?'] valid data ['you', 'may', 'extend', 'your', 'stay', 'in', 'tokyo', '.']
データセットに登場する各単語にIDを割り振る
# まず特殊トークンを定義しておく
PAD_TOKEN = '<PAD>' # バッチ処理の際に、短い系列の末尾を埋めるために使う (Padding)
BOS_TOKEN = '<S>' # 系列の始まりを表す (Beggining of sentence)
EOS_TOKEN = '</S>' # 系列の終わりを表す (End of sentence)
UNK_TOKEN = '<UNK>' # 語彙に存在しない単語を表す (Unknown)
PAD = 0
BOS = 1
EOS = 2
UNK = 3
MIN_COUNT = 2 # 語彙に含める単語の最低出現回数 再提出現回数に満たない単語はUNKに置き換えられる
# 単語をIDに変換する辞書の初期値を設定
word2id = {
PAD_TOKEN: PAD,
BOS_TOKEN: BOS,
EOS_TOKEN: EOS,
UNK_TOKEN: UNK,
}
# 単語辞書を作成
vocab_X = Vocab(word2id=word2id)
vocab_Y = Vocab(word2id=word2id)
vocab_X.build_vocab(train_X, min_count=MIN_COUNT)
vocab_Y.build_vocab(train_Y, min_count=MIN_COUNT)
vocab_size_X = len(vocab_X.id2word)
vocab_size_Y = len(vocab_Y.id2word)
print('入力言語の語彙数:', vocab_size_X)
print('出力言語の語彙数:', vocab_size_Y)
入力言語の語彙数: 3725 出力言語の語彙数: 4405
まずはモデルが文章を認識できるように、文章を単語IDのリストに変換します
def sentence_to_ids(vocab, sentence):
# 単語(str)のリストをID(int)のリストに変換する関数
ids = [vocab.word2id.get(word, UNK) for word in sentence]
ids += [EOS] # EOSを加える
return ids
train_X = [sentence_to_ids(vocab_X, sentence) for sentence in train_X]
train_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in train_Y]
valid_X = [sentence_to_ids(vocab_X, sentence) for sentence in valid_X]
valid_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in valid_Y]
この時点で入力と教師データは以下のようになっている
print('train data', train_X[0])
print('valid data', valid_X[0])
train data [132, 321, 28, 290, 367, 12, 2] valid data [8, 93, 3532, 36, 236, 13, 284, 4, 2]
データセットからバッチを取得するデータローダーを定義します
<PAD>
など)でパディングし、バッチ内の系列の長さを最長のものに合わせるdef pad_seq(seq, max_length):
# 系列(seq)が指定の文長(max_length)になるように末尾をパディングする
res = seq + [PAD for i in range(max_length - len(seq))]
return res
class DataLoader(object):
def __init__(self, X, Y, batch_size, shuffle=False):
"""
:param X: list, 入力言語の文章(単語IDのリスト)のリスト
:param Y: list, 出力言語の文章(単語IDのリスト)のリスト
:param batch_size: int, バッチサイズ
:param shuffle: bool, サンプルの順番をシャッフルするか否か
"""
self.data = list(zip(X, Y))
self.batch_size = batch_size
self.shuffle = shuffle
self.start_index = 0
self.reset()
def reset(self):
if self.shuffle: # サンプルの順番をシャッフルする
self.data = shuffle(self.data, random_state=random_state)
self.start_index = 0 # ポインタの位置を初期化する
def __iter__(self):
return self
def __next__(self):
# ポインタが最後まで到達したら初期化する
if self.start_index >= len(self.data):
self.reset()
raise StopIteration()
# バッチを取得
seqs_X, seqs_Y = zip(*self.data[self.start_index:self.start_index+self.batch_size])
# 入力系列seqs_Xの文章の長さ順(降順)に系列ペアをソートする
seq_pairs = sorted(zip(seqs_X, seqs_Y), key=lambda p: len(p[0]), reverse=True)
seqs_X, seqs_Y = zip(*seq_pairs)
# 短い系列の末尾をパディングする
lengths_X = [len(s) for s in seqs_X] # 後述のEncoderのpack_padded_sequenceでも用いる
lengths_Y = [len(s) for s in seqs_Y]
max_length_X = max(lengths_X)
max_length_Y = max(lengths_Y)
padded_X = [pad_seq(s, max_length_X) for s in seqs_X]
padded_Y = [pad_seq(s, max_length_Y) for s in seqs_Y]
# tensorに変換し、転置する
batch_X = torch.tensor(padded_X, dtype=torch.long, device=device).transpose(0, 1)
batch_Y = torch.tensor(padded_Y, dtype=torch.long, device=device).transpose(0, 1)
# ポインタを更新する
self.start_index += self.batch_size
return batch_X, batch_Y, lengths_X
EncoderとDecoderのRNNを定義します。
PyTorchのRNNでは、可変長の系列のバッチを効率よく計算できるように系列を表現するPackedSequence
というクラスを用いることができます。
入力バッチのテンソルをこのPackedSequence
のインスタンスに変換してからRNNに入力することで、パディング部分の計算を省略することができるため、効率的な計算が可能になります。
PackedSequence
を作成するには、まず、系列長の異なるバッチに対してパディングを行なってください。
ここで、パディングを行う前に各サンプルの系列長(lengths
)を保存しておきます。
# 系列長がそれぞれ4,3,2の3つのサンプルからなるバッチを作成
batch = [[1,2,3,4], [5,6,7], [8,9]]
lengths = [len(sample) for sample in batch]
print('各サンプルの系列長:', lengths)
print()
# 最大系列長に合うように各サンプルをpadding
_max_length = max(lengths)
padded = torch.tensor([pad_seq(sample, _max_length) for sample in batch])
print('paddingされたテンソル:\n', padded)
padded = padded.transpose(0,1) # (max_length, batch_size)に転置
print('padding & 転置されたテンソル:\n', padded)
print('padding & 転置されたテンソルのサイズ:\n', padded.size())
print()
各サンプルの系列長: [4, 3, 2] paddingされたテンソル: tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]) padding & 転置されたテンソル: tensor([[1, 5, 8], [2, 6, 9], [3, 7, 0], [4, 0, 0]]) padding & 転置されたテンソルのサイズ: torch.Size([4, 3])
次に、パディングを行ったテンソル(padded
)と各サンプルの元々の系列長(lengths
)をtorch.nn.utils.rnn.pack_padded_sequence
という関数に与えると、
data
とbatch_sizes
という要素を持ったPackedSequence
のインスタンス(packed
)が作成できます。
data
: テンソルのPAD
以外の値のみを保有するベクトルbatch_sizes
: 各時刻で計算が必要な(=PAD
に到達していない)バッチの数を表すベクトル# PackedSequenceに変換(テンソルをRNNに入力する前に適用する)
packed = pack_padded_sequence(padded, lengths=lengths) # 各サンプルの系列長も与える
print('PackedSequenceのインスタンス:\n', packed) # テンソルのPAD以外の値(data)と各時刻で計算が必要な(=PADに到達していない)バッチの数(batch_sizes)を有するインスタンス
print()
PackedSequenceのインスタンス: PackedSequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 4]), batch_sizes=tensor([3, 3, 2, 1]), sorted_indices=None, unsorted_indices=None)
こうして得られたPackedSequence
のインスタンスをRNNに入力します。(ここでは省略)
RNNから出力されたテンソルはPackedSeauence
のインスタンスのままなので、後段の計算につなぐためにtorch.nn.utils.rnn.pad_packed_sequence
の関数によって通常のテンソルに戻します。
# PackedSequenceのインスタンスをRNNに入力する(ここでは省略)
output = packed
# テンソルに戻す(RNNの出力に対して適用する)
output, _length = pad_packed_sequence(output) # PADを含む元のテンソルと各サンプルの系列長を返す
print('PADを含む元のテンソル:\n', output)
print('各サンプルの系列長:', _length)
PADを含む元のテンソル: tensor([[1, 5, 8], [2, 6, 9], [3, 7, 0], [4, 0, 0]]) 各サンプルの系列長: tensor([4, 3, 2])
今回はEncoder側でバッチを処理する際に、pack_padded_sequence
関数によってtensorをPackedSequence
に変換し、処理を終えた後にpad_packed_sequence
関数によってtensorに戻すという処理を行います。
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size):
"""
:param input_size: int, 入力言語の語彙数
:param hidden_size: int, 隠れ層のユニット数
"""
super(Encoder, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD)
self.gru = nn.GRU(hidden_size, hidden_size)
def forward(self, seqs, input_lengths, hidden=None):
"""
:param seqs: tensor, 入力のバッチ, size=(max_length, batch_size)
:param input_lengths: 入力のバッチの各サンプルの文長
:param hidden: tensor, 隠れ状態の初期値, Noneの場合は0で初期化される
:return output: tensor, Encoderの出力, size=(max_length, batch_size, hidden_size)
:return hidden: tensor, Encoderの隠れ状態, size=(1, batch_size, hidden_size)
"""
emb = self.embedding(seqs) # seqsはパディング済み
packed = pack_padded_sequence(emb, input_lengths) # PackedSequenceオブジェクトに変換
output, hidden = self.gru(packed, hidden)
output, _ = pad_packed_sequence(output)
return output, hidden
今回はDecoder側ではパディング等行わないので、通常のtensorのままRNNに入力して問題ありません。
class Decoder(nn.Module):
def __init__(self, hidden_size, output_size):
"""
:param hidden_size: int, 隠れ層のユニット数
:param output_size: int, 出力言語の語彙数
:param dropout: float, ドロップアウト率
"""
super(Decoder, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD)
self.gru = nn.GRU(hidden_size, hidden_size)
self.out = nn.Linear(hidden_size, output_size)
def forward(self, seqs, hidden):
"""
:param seqs: tensor, 入力のバッチ, size=(1, batch_size)
:param hidden: tensor, 隠れ状態の初期値, Noneの場合は0で初期化される
:return output: tensor, Decoderの出力, size=(1, batch_size, output_size)
:return hidden: tensor, Decoderの隠れ状態, size=(1, batch_size, hidden_size)
"""
emb = self.embedding(seqs)
output, hidden = self.gru(emb, hidden)
output = self.out(output)
return output, hidden
上で定義したEncoderとDecoderを用いた、一連の処理をまとめるEncoderDecoderのクラスを定義します。
ここで、Decoder側の処理で注意する点があります。
RNNでは、時刻tの出力を時刻t+1の入力とすることができるが、この方法でDecoderを学習させると連鎖的に誤差が大きくなっていき、学習が不安定になったり収束が遅くなったりする問題が発生します。
この問題への対策としてTeacher Forcingというテクニックがあります。 これは、訓練時にはDecoder側の入力に、ターゲット系列(参照訳)をそのまま使うというものです。 これにより学習が安定し、収束が早くなるというメリットがありますが、逆に評価時は前の時刻にDecoderが生成したものが使われるため、学習時と分布が異なってしまうというデメリットもあります。
Teacher Forcingの拡張として、ターゲット系列を入力とするか生成された結果を入力とするかを確率的にサンプリングするScheduled Samplingという手法があります。
ここではScheduled Samplingを採用し、一定の確率に基づいてターゲット系列を入力とするか生成された結果を入力とするかを切り替えられるようにクラスを定義しておきます。
class EncoderDecoder(nn.Module):
"""EncoderとDecoderの処理をまとめる"""
def __init__(self, input_size, output_size, hidden_size):
"""
:param input_size: int, 入力言語の語彙数
:param output_size: int, 出力言語の語彙数
:param hidden_size: int, 隠れ層のユニット数
"""
super(EncoderDecoder, self).__init__()
self.encoder = Encoder(input_size, hidden_size)
self.decoder = Decoder(hidden_size, output_size)
def forward(self, batch_X, lengths_X, max_length, batch_Y=None, use_teacher_forcing=False):
"""
:param batch_X: tensor, 入力系列のバッチ, size=(max_length, batch_size)
:param lengths_X: list, 入力系列のバッチ内の各サンプルの文長
:param max_length: int, Decoderの最大文長
:param batch_Y: tensor, Decoderで用いるターゲット系列
:param use_teacher_forcing: Decoderでターゲット系列を入力とするフラグ
:return decoder_outputs: tensor, Decoderの出力,
size=(max_length, batch_size, self.decoder.output_size)
"""
# encoderに系列を入力(複数時刻をまとめて処理)
_, encoder_hidden = self.encoder(batch_X, lengths_X)
_batch_size = batch_X.size(1)
# decoderの入力と隠れ層の初期状態を定義
decoder_input = torch.tensor([BOS] * _batch_size, dtype=torch.long, device=device) # 最初の入力にはBOSを使用する
decoder_input = decoder_input.unsqueeze(0) # (1, batch_size)
decoder_hidden = encoder_hidden # Encoderの最終隠れ状態を取得
# decoderの出力のホルダーを定義
decoder_outputs = torch.zeros(max_length, _batch_size, self.decoder.output_size, device=device) # max_length分の固定長
# 各時刻ごとに処理
for t in range(max_length):
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
decoder_outputs[t] = decoder_output
# 次の時刻のdecoderの入力を決定
if use_teacher_forcing and batch_Y is not None: # teacher forceの場合、ターゲット系列を用いる
decoder_input = batch_Y[t].unsqueeze(0)
else: # teacher forceでない場合、自身の出力を用いる
decoder_input = decoder_output.max(-1)[1]
return decoder_outputs
mce = nn.CrossEntropyLoss(size_average=False, ignore_index=PAD) # PADを無視する
def masked_cross_entropy(logits, target):
logits_flat = logits.view(-1, logits.size(-1)) # (max_seq_len * batch_size, output_size)
target_flat = target.view(-1) # (max_seq_len * batch_size, 1)
return mce(logits_flat, target_flat)
/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead. warnings.warn(warning.format(ret))
# ハイパーパラメータの設定
num_epochs = 10
batch_size = 64
lr = 1e-3 # 学習率
teacher_forcing_rate = 0.2 # Teacher Forcingを行う確率
ckpt_path = 'model.pth' # 学習済みのモデルを保存するパス
model_args = {
'input_size': vocab_size_X,
'output_size': vocab_size_Y,
'hidden_size': 256,
}
# データローダを定義
train_dataloader = DataLoader(train_X, train_Y, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_X, valid_Y, batch_size=batch_size, shuffle=False)
# モデルとOptimizerを定義
model = EncoderDecoder(**model_args).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
実際に損失関数を計算する関数を定義します。
def compute_loss(batch_X, batch_Y, lengths_X, model, optimizer=None, is_train=True):
# 損失を計算する関数
model.train(is_train) # train/evalモードの切替え
# 一定確率でTeacher Forcingを行う
use_teacher_forcing = is_train and (random.random() < teacher_forcing_rate)
max_length = batch_Y.size(0)
# 推論
pred_Y = model(batch_X, lengths_X, max_length, batch_Y, use_teacher_forcing)
# 損失関数を計算
loss = masked_cross_entropy(pred_Y.contiguous(), batch_Y.contiguous())
if is_train: # 訓練時はパラメータを更新
optimizer.zero_grad()
loss.backward()
optimizer.step()
batch_Y = batch_Y.transpose(0, 1).contiguous().data.cpu().tolist()
pred = pred_Y.max(dim=-1)[1].data.cpu().numpy().T.tolist()
return loss.item(), batch_Y, pred
ここで、Loss以外に、学習の進捗を確認するためにモデルの性能を評価する指標として、BLEUを計算します。
BLEUは機械翻訳の分野において最も一般的な自動評価基準の一つで、予め用意した複数の参照訳と、機械翻訳モデルが出力した訳のn-gramのマッチ率に基づく指標です。
NLTK (Natural Language Tool Kit) という自然言語処理で用いられるライブラリを用いて簡単に計算することができます。
def calc_bleu(refs, hyps):
"""
BLEUスコアを計算する関数
:param refs: list, 参照訳。単語のリストのリスト (例: [['I', 'have', 'a', 'pen'], ...])
:param hyps: list, モデルの生成した訳。単語のリストのリスト (例: ['I', 'have', 'a', 'pen'])
:return: float, BLEUスコア(0~100)
"""
refs = [[ref[:ref.index(EOS)]] for ref in refs] # EOSは評価しないで良いので切り捨てる, refsのほうは複数なのでlistが一個多くかかっている
hyps = [hyp[:hyp.index(EOS)] if EOS in hyp else hyp for hyp in hyps]
return 100 * bleu_score.corpus_bleu(refs, hyps)
それではモデルの訓練を行います。
# 訓練
best_valid_bleu = 0.
for epoch in range(1, num_epochs+1):
train_loss = 0.
train_refs = []
train_hyps = []
valid_loss = 0.
valid_refs = []
valid_hyps = []
# train
for batch in train_dataloader:
batch_X, batch_Y, lengths_X = batch
loss, gold, pred = compute_loss(
batch_X, batch_Y, lengths_X, model, optimizer,
is_train=True
)
train_loss += loss
train_refs += gold
train_hyps += pred
# valid
for batch in valid_dataloader:
batch_X, batch_Y, lengths_X = batch
loss, gold, pred = compute_loss(
batch_X, batch_Y, lengths_X, model,
is_train=False
)
valid_loss += loss
valid_refs += gold
valid_hyps += pred
# 損失をサンプル数で割って正規化
train_loss = np.sum(train_loss) / len(train_dataloader.data)
valid_loss = np.sum(valid_loss) / len(valid_dataloader.data)
# BLEUを計算
train_bleu = calc_bleu(train_refs, train_hyps)
valid_bleu = calc_bleu(valid_refs, valid_hyps)
# validationデータでBLEUが改善した場合にはモデルを保存
if valid_bleu > best_valid_bleu:
ckpt = model.state_dict()
torch.save(ckpt, ckpt_path)
best_valid_bleu = valid_bleu
print('Epoch {}: train_loss: {:5.2f} train_bleu: {:2.2f} valid_loss: {:5.2f} valid_bleu: {:2.2f}'.format(
epoch, train_loss, train_bleu, valid_loss, valid_bleu))
print('-'*80)
Epoch 1: train_loss: 52.44 train_bleu: 3.30 valid_loss: 48.78 valid_bleu: 5.10 -------------------------------------------------------------------------------- Epoch 2: train_loss: 44.48 train_bleu: 7.57 valid_loss: 44.77 valid_bleu: 8.37 -------------------------------------------------------------------------------- Epoch 3: train_loss: 40.05 train_bleu: 11.49 valid_loss: 41.95 valid_bleu: 8.68 -------------------------------------------------------------------------------- Epoch 4: train_loss: 37.40 train_bleu: 14.04 valid_loss: 41.00 valid_bleu: 13.31 -------------------------------------------------------------------------------- Epoch 5: train_loss: 34.79 train_bleu: 17.00 valid_loss: 40.30 valid_bleu: 14.62 -------------------------------------------------------------------------------- Epoch 6: train_loss: 32.96 train_bleu: 19.18 valid_loss: 39.93 valid_bleu: 15.41 -------------------------------------------------------------------------------- Epoch 7: train_loss: 31.71 train_bleu: 20.88 valid_loss: 39.90 valid_bleu: 16.35 -------------------------------------------------------------------------------- Epoch 8: train_loss: 30.40 train_bleu: 22.62 valid_loss: 40.41 valid_bleu: 17.56 -------------------------------------------------------------------------------- Epoch 9: train_loss: 29.20 train_bleu: 24.48 valid_loss: 40.64 valid_bleu: 18.55 -------------------------------------------------------------------------------- Epoch 10: train_loss: 27.63 train_bleu: 27.09 valid_loss: 40.98 valid_bleu: 19.21 --------------------------------------------------------------------------------
# 学習済みモデルの読み込み
ckpt = torch.load(ckpt_path) # cpuで処理する場合はmap_locationで指定する必要があります。
model.load_state_dict(ckpt)
model.eval()
EncoderDecoder( (encoder): Encoder( (embedding): Embedding(3725, 256, padding_idx=0) (gru): GRU(256, 256) ) (decoder): Decoder( (embedding): Embedding(4405, 256, padding_idx=0) (gru): GRU(256, 256) (out): Linear(in_features=256, out_features=4405, bias=True) ) )
def ids_to_sentence(vocab, ids):
# IDのリストを単語のリストに変換する
return [vocab.id2word[_id] for _id in ids]
def trim_eos(ids):
# IDのリストからEOS以降の単語を除外する
if EOS in ids:
return ids[:ids.index(EOS)]
else:
return ids
# テストデータの読み込み
test_X = load_data('./data/dev.en')
test_Y = load_data('./data/dev.ja')
test_X = [sentence_to_ids(vocab_X, sentence) for sentence in test_X]
test_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in test_Y]
test_dataloader = DataLoader(test_X, test_Y, batch_size=1, shuffle=False)
# 生成
batch_X, batch_Y, lengths_X = next(test_dataloader)
sentence_X = ' '.join(ids_to_sentence(vocab_X, batch_X.data.cpu().numpy()[:-1, 0]))
sentence_Y = ' '.join(ids_to_sentence(vocab_Y, batch_Y.data.cpu().numpy()[:-1, 0]))
print('src: {}'.format(sentence_X))
print('tgt: {}'.format(sentence_Y))
output = model(batch_X, lengths_X, max_length=20)
output = output.max(dim=-1)[1].view(-1).data.cpu().tolist()
output_sentence = ' '.join(ids_to_sentence(vocab_Y, trim_eos(output)))
output_sentence_without_trim = ' '.join(ids_to_sentence(vocab_Y, output))
print('out: {}'.format(output_sentence))
print('without trim: {}'.format(output_sentence_without_trim))
src: we went to boston , where we stayed a week . tgt: 私 たち は ボストン に 行 き 、 そこ に 一 週間 滞在 し た 。 out: 私 たち は 、 、 、 、 、 、 、 、 た た た 。 た without trim: 私 たち は 、 、 、 、 、 、 、 、 た た た 。 た </S> </S> </S> </S>
# BLEUの計算
test_dataloader = DataLoader(test_X, test_Y, batch_size=1, shuffle=False)
refs_list = []
hyp_list = []
for batch in test_dataloader:
batch_X, batch_Y, lengths_X = batch
pred_Y = model(batch_X, lengths_X, max_length=20)
pred = pred_Y.max(dim=-1)[1].view(-1).data.cpu().tolist()
refs = batch_Y.view(-1).data.cpu().tolist()
refs_list.append(refs)
hyp_list.append(pred)
bleu = calc_bleu(refs_list, hyp_list)
print(bleu)
19.696735593854097
テストデータに対して新たな文を生成する際、これまでは各時刻で最も確率の高い単語を正解として採用し、次のステップでの入力として使っていました。 ただ、本当にやりたいのは、文全体の尤度が最も高くなるような文を生成することです。そのため、ただ近視眼的に確率の高い単語を採用していくより、もう少し大局的に評価していく必要があります。
Beam Searchでは、各時刻において一定の数Kのそれまでのスコア(対数尤度など)の高い文を保持しながら選択を行っていきます。
図はSlack上のものを参照してください。
class BeamEncoderDecoder(EncoderDecoder):
"""
Beam Searchでdecodeを行うためのクラス
"""
def __init__(self, input_size, output_size, hidden_size, beam_size=4):
"""
:param input_size: int, 入力言語の語彙数
:param output_size: int, 出力言語の語彙数
:param hidden_size: int, 隠れ層のユニット数
:param beam_size: int, ビーム数
"""
super(BeamEncoderDecoder, self).__init__(input_size, output_size, hidden_size)
self.beam_size = beam_size
def forward(self, batch_X, lengths_X, max_length):
"""
:param batch_X: tensor, 入力系列のバッチ, size=(max_length, batch_size)
:param lengths_X: list, 入力系列のバッチ内の各サンプルの文長
:param max_length: int, Decoderの最大文長
:return decoder_outputs: list, 各ビームのDecoderの出力
:return finished_scores: list of float, 各ビームのスコア
"""
_, encoder_hidden = self.encoder(batch_X, lengths_X)
# decoderの入力と隠れ層の初期状態を定義
decoder_input = torch.tensor([BOS] * self.beam_size, dtype=torch.long, device=device)
decoder_input = decoder_input.unsqueeze(0) # (1, batch_size)
decoder_hidden = encoder_hidden
# beam_sizeの数だけrepeatする
decoder_input = decoder_input.expand(1, beam_size)
decoder_hidden = decoder_hidden.expand(1, beam_size, -1).contiguous()
k = beam_size
finished_beams = []
finished_scores = []
prev_probs = torch.zeros(beam_size, 1, dtype=torch.float, device=device) # 前の時刻の各ビームの対数尤度を保持しておく
output_size = self.decoder.output_size
# 各時刻ごとに処理
for t in range(max_length):
# decoder_input: (1, k)
decoder_output, decoder_hidden = self.decoder(decoder_input[-1:], decoder_hidden)
# decoder_output: (1, k, output_size)
# decoder_hidden: (1, k, hidden_size)
decoder_output_t = decoder_output[-1] # (k, output_size)
log_probs = prev_probs + F.log_softmax(decoder_output_t, dim=-1) # (k, output_size)
scores = log_probs # 対数尤度をスコアとする
# スコアの高いビームとその単語を取得
flat_scores = scores.view(-1) # (k*output_size,)
if t == 0:
flat_scores = flat_scores[:output_size] # t=0のときは後半の同じ値の繰り返しを除外
top_vs, top_is = flat_scores.data.topk(k)
beam_indices = top_is / output_size # (k,)
word_indices = top_is % output_size # (k,)
# ビームを更新する
_next_beam_indices = []
_next_word_indices = []
for b, w in zip(beam_indices, word_indices):
if w.item() == EOS: # EOSに到達した場合はそのビームは更新して終了
k -= 1
beam = torch.cat([decoder_input.t()[b], w.view(1,)]) # (t+2,)
score = scores[b, w].item()
finished_beams.append(beam)
finished_scores.append(score)
else: # それ以外の場合はビームを更新
_next_beam_indices.append(b)
_next_word_indices.append(w)
if k == 0:
break
# tensornに変換
next_beam_indices = torch.tensor(_next_beam_indices, device=device)
next_word_indices = torch.tensor(_next_word_indices, device=device)
# 次の時刻のDecoderの入力を更新
decoder_input = torch.index_select(
decoder_input, dim=-1, index=next_beam_indices)
decoder_input = torch.cat(
[decoder_input, next_word_indices.unsqueeze(0)], dim=0)
# 次の時刻のDecoderの隠れ層を更新
decoder_hidden = torch.index_select(
decoder_hidden, dim=1, index=next_beam_indices)
# 各ビームの対数尤度を更新
flat_probs = log_probs.view(-1) # (k*output_size,)
next_indices = (next_beam_indices + 1) * next_word_indices
prev_probs = torch.index_select(
flat_probs, dim=0, index=next_indices).unsqueeze(1) # (k, 1)
# すべてのビームが完了したらデータを整形
decoder_outputs = [[idx.item() for idx in beam[1:-1]] for beam in finished_beams]
return decoder_outputs, finished_scores
# 学習済みモデルの読み込み
beam_size = 3
beam_model = BeamEncoderDecoder(**model_args, beam_size=beam_size).to(device)
beam_model.load_state_dict(ckpt)
beam_model.eval()
test_dataloader = DataLoader(test_X, test_Y, batch_size=1, shuffle=False)
# 生成
batch_X, batch_Y, lengths_X = next(test_dataloader)
sentence_X = ' '.join(ids_to_sentence(vocab_X, batch_X.data.cpu().numpy()[:-1, 0]))
sentence_Y = ' '.join(ids_to_sentence(vocab_Y, batch_Y.data.cpu().numpy()[:-1, 0]))
print('src: {}'.format(sentence_X))
print('tgt: {}'.format(sentence_Y))
# 普通のdecode
output = model(batch_X, lengths_X, max_length=20)
output = output.max(dim=-1)[1].view(-1).data.cpu().tolist()
output_sentence = ' '.join(ids_to_sentence(vocab_Y, trim_eos(output)))
print('out: {}'.format(output_sentence))
# beam decode
outputs, scores = beam_model(batch_X, lengths_X, max_length=20)
# scoreの良い順にソート
outputs, scores = zip(*sorted(zip(outputs, scores), key=lambda x: -x[1]))
for o, output in enumerate(outputs):
output_sentence = ' '.join(ids_to_sentence(vocab_Y, output))
print('out{}: {}'.format(o+1, output_sentence))
TeacherForcingを使うことによって、誤差が極端に拡大してしまうことを防ぐ。 これを使った場合はデコーダの出力を用いずに教師データを使うので誤りがなくなる しかしこれだけに頼ってしまうと、学習時と実用時の環境がかけ離れてしまい、推論を行うときに汎化性能のないモデルになってしまうため、一定の割合に制限して使用するようにしている。 データローダーの仕組みは重要。 こういう仕組みを使わないと、膨大なデータをすべて前処理した状態のものをメモリに一括で載せないといけなくなってしまい、メモリが大抵足らなくなる。
# 実行結果
'''
src: he lived a hard life .
tgt: 彼 は つら い 人生 を 送 っ た 。
out: 彼 は 人生 を 生活 を 送 っ た 。
src: no . i 'm sorry , i 've got to go back early .
tgt: ごめん なさ い 。 早 く 帰 ら な く ちゃ 。
out: いいえ 、 帰 っ たら 、 行 き き き だ 。
src: she wrote to me to come at once .
tgt: 彼女 は 私 に すぐ 来 い と の 便り を よこ し た 。
out: 彼女 は すぐ に 来る と に し た た 。
src: i can 't swim at all .
tgt: 私 は 少し も 泳げ な い 。
out: 私 は 泳げ 泳げ な い 。
src: is there any hope of his success ?
tgt: 彼 の 成功 の 見込み は あ り ま す か 。
out: 彼 の 成功 は どう か あ り ま す か 。
src: i 'll pick him up at 5 .
tgt: 私 は 5 時 に 彼 を 迎え に 行 く つもり で す 。
out: 私 は 彼 を に に を し ま 。 。
src: it 's so lovely a day .
tgt: 本当 に い い 天気 だ 。
out: それ 一 日 で す 。
src: show your own business .
tgt: 自分 の 事 を しろ 。
out: 君 の 商売 を <UNK> し て い 。
src: i study english every day .
tgt: 私 は 毎日 英語 の 勉強 を する 。
out: 私 は 毎日 英語 を 勉強 し ま す 。
src: i like spring the best of the seasons .
tgt: 私 は 季節 の 中 で 春 が 好き だ 。
out: 私 は 一番 が が 一番 好き が 一番 が が 好き 好き で 。
src: you can 't have lost your coat in the house .
tgt: 家 の 中 で コート が 無 くな る はず は な い 。
out: コート を コート を コート を の は で で な な な い 。 。
src: there are some oranges on the tree .
tgt: 木 に オレンジ が いく つ か な っ て い る 。
out: その の 木 の 木 木 が が い い い 。
src: she died of shock .
tgt: 彼女 は ショック 死 し た 。
out: 彼女 は 死 ん で い た 。
src: i like french food very much .
tgt: 私 は フランス 料理 が 好き で す 。
out: 私 は フランス 語 が 好き で す 。
src: we went to boston , where we stayed a week .
tgt: 私 たち は ボストン に 行 き 、 そこ に 一 週間 滞在 し た 。
out: 私 たち は 、 、 、 、 、 、 、 、 た た た 。 た
'''
# 惜しいものもいくつかあるが、かなりわけのわからないものもある。
# おなじ文字や単語が複数連続して出力されてしまうのが特徴的である。