The Transformer is a Seq2Seq model that uses only attention, with no RNNs or CNNs.
Because its computation can be parallelized it is faster than an RNN, and the mechanism called Self-Attention lets it reference information at any position in the sequence, unlike a CNN, which can only see a local neighborhood.
Several other refinements are built in as well, and the architecture is known to perform extremely well on a wide range of NLP tasks beyond translation.
Reference implementation: https://github.com/jadore801120/attention-is-all-you-need-pytorch
! wget https://www.dropbox.com/s/9narw5x4uizmehh/utils.py
! mkdir images data
# download the data
! wget https://www.dropbox.com/s/o4kyc52a8we25wy/dev.en -P data/
! wget https://www.dropbox.com/s/kdgskm5hzg6znuc/dev.ja -P data/
! wget https://www.dropbox.com/s/gyyx4gohv9v65uh/test.en -P data/
! wget https://www.dropbox.com/s/hotxwbgoe2n013k/test.ja -P data/
! wget https://www.dropbox.com/s/5lsftkmb20ay9e1/train.en -P data/
! wget https://www.dropbox.com/s/ak53qirssci6f1j/train.ja -P data/
import time
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from nltk import bleu_score
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils import Vocab
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1)
random_state = 42
print(torch.__version__)
1.9.0+cu102
PAD = 0
UNK = 1
BOS = 2
EOS = 3
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
BOS_TOKEN = '<S>'
EOS_TOKEN = '</S>'
def load_data(file_path):
"""
Load data from a text file
:param file_path: str, path to the text file
:return data: list, list of sentences (each a list of words)
"""
data = []
for line in open(file_path, encoding='utf-8'):
words = line.strip().split() # split on whitespace
data.append(words)
return data
train_X = load_data('./data/train.en')
train_Y = load_data('./data/train.ja')
# split into training and validation sets
train_X, valid_X, train_Y, valid_Y = train_test_split(train_X, train_Y, test_size=0.2, random_state=random_state)
# check the contents of the dataset
print('train_X:', train_X[:5])
print('train_Y:', train_Y[:5])
train_X: [['where', 'shall', 'we', 'eat', 'tonight', '?'], ['i', 'made', 'a', 'big', 'mistake', 'in', 'choosing', 'my', 'wife', '.'], ['i', "'ll", 'have', 'to', 'think', 'about', 'it', '.'], ['it', 'is', 'called', 'a', 'lily', '.'], ['could', 'you', 'lend', 'me', 'some', 'money', 'until', 'this', 'weekend', '?']]
train_Y: [['今夜', 'は', 'どこ', 'で', '食事', 'を', 'し', 'よ', 'う', 'か', '。'], ['僕', 'は', '妻', 'を', '選', 'ぶ', 'の', 'に', '大変', 'な', '間違い', 'を', 'し', 'た', '。'], ['考え', 'と', 'く', 'よ', '。'], ['lily', 'と', '呼', 'ば', 'れ', 'て', 'い', 'ま', 'す', '。'], ['今週末', 'まで', 'いくら', 'か', '金', 'を', '貸', 'し', 'て', 'くれ', 'ま', 'せ', 'ん', 'か', '。']]
MIN_COUNT = 2 # minimum number of occurrences for a word to be included in the vocabulary
word2id = {
PAD_TOKEN: PAD,
BOS_TOKEN: BOS,
EOS_TOKEN: EOS,
UNK_TOKEN: UNK,
}
vocab_X = Vocab(word2id=word2id)
vocab_Y = Vocab(word2id=word2id)
vocab_X.build_vocab(train_X, min_count=MIN_COUNT)
vocab_Y.build_vocab(train_Y, min_count=MIN_COUNT)
vocab_size_X = len(vocab_X.id2word)
vocab_size_Y = len(vocab_Y.id2word)
def sentence_to_ids(vocab, sentence):
"""
Convert a list of words into a list of word IDs
:param vocab: instance of Vocab
:param sentence: list of str
:return ids: list of int
"""
ids = [vocab.word2id.get(word, UNK) for word in sentence]
ids = [BOS] + ids + [EOS] # add BOS at the beginning and EOS at the end
return ids
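As a quick sanity check, a minimal sketch (not part of the original notebook) of converting one tokenized sentence; the exact IDs depend on the vocabulary built above, so only the BOS (=2) at the start and EOS (=3) at the end are predictable.
# illustrative only: unknown words map to UNK, BOS/EOS are attached at both ends
_sample = ['where', 'shall', 'we', 'eat', 'tonight', '?']
print(sentence_to_ids(vocab_X, _sample))  # starts with 2 (BOS) and ends with 3 (EOS); the other IDs depend on the built vocabulary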
train_X = [sentence_to_ids(vocab_X, sentence) for sentence in train_X]
train_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in train_Y]
valid_X = [sentence_to_ids(vocab_X, sentence) for sentence in valid_X]
valid_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in valid_Y]
class DataLoader(object):
def __init__(self, src_insts, tgt_insts, batch_size, shuffle=True):
"""
:param src_insts: list, sentences (lists of word IDs) in the source language
:param tgt_insts: list, sentences (lists of word IDs) in the target language
:param batch_size: int, batch size
:param shuffle: bool, whether to shuffle the order of the samples
"""
self.data = list(zip(src_insts, tgt_insts))
self.batch_size = batch_size
self.shuffle = shuffle
self.start_index = 0
self.reset()
def reset(self):
if self.shuffle:
self.data = shuffle(self.data, random_state=random_state)
self.start_index = 0
def __iter__(self):
return self
def __next__(self):
def preprocess_seqs(seqs):
# pad every sequence to the maximum length in the batch
max_length = max([len(s) for s in seqs])
data = [s + [PAD] * (max_length - len(s)) for s in seqs]
# build vectors that represent each word's position (0 for PAD)
positions = [[pos+1 if w != PAD else 0 for pos, w in enumerate(seq)] for seq in data]
# convert to tensors
data_tensor = torch.tensor(data, dtype=torch.long, device=device)
position_tensor = torch.tensor(positions, dtype=torch.long, device=device)
return data_tensor, position_tensor
# reset once the pointer has reached the end of the data
if self.start_index >= len(self.data):
self.reset()
raise StopIteration()
# fetch a batch and preprocess it
src_seqs, tgt_seqs = zip(*self.data[self.start_index:self.start_index+self.batch_size])
src_data, src_pos = preprocess_seqs(src_seqs)
tgt_data, tgt_pos = preprocess_seqs(tgt_seqs)
# advance the pointer
self.start_index += self.batch_size
return (src_data, src_pos), (tgt_data, tgt_pos)
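A minimal usage sketch, assuming the cells above have been run (this check is not part of the original notebook): each iteration yields the padded ID tensors together with the matching position tensors, already placed on `device`.
_loader = DataLoader(train_X, train_Y, batch_size=4)
(_src, _src_pos), (_tgt, _tgt_pos) = next(_loader)
print(_src.size(), _src_pos.size())  # both (4, max_length_in_batch)
print(_src[0])      # word IDs, right-padded with PAD (=0)
print(_src_pos[0])  # 1-based positions, 0 wherever the token is PAD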
The Transformer model also has an encoder-decoder structure. The Encoder and Decoder are built from several modules (such as the Positional Encoding, Multi-Head Attention, and the Position-Wise Feed Forward Network), so we define each module separately.
Because the Transformer does not use an RNN to process the sequence, by itself it cannot take word order into account.
Therefore a Positional Encoding, which embeds each word's position, is added to the embedding matrix of the input sequence.
Each entry of the Positional Encoding matrix PE is given by

$$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)$$
$$PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)$$

where $pos$ is the position of the word and $i$ is the dimension index.
Each dimension of the Positional Encoding corresponds to a sinusoid whose wavelength grows geometrically from $2\pi$ to $10000 \cdot 2\pi$.
def position_encoding_init(n_position, d_pos_vec):
"""
Initialize the matrix used for the Positional Encoding
:param n_position: int, sequence length
:param d_pos_vec: int, hidden dimension
:return torch.tensor, size=(n_position, d_pos_vec)
"""
# positions that hold PAD are kept at pos=0, and their encoding is also set to zero
position_enc = np.array([
[pos / np.power(10000, 2 * (j // 2) / d_pos_vec) for j in range(d_pos_vec)]
if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
return torch.tensor(position_enc, dtype=torch.float)
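As a spot check of the formula above, a minimal sketch (not in the original notebook) comparing one entry of the matrix returned by position_encoding_init against the closed-form expression:
# check a single even-dimension entry: pos=3, i=5, d_model=256
_pe = position_encoding_init(50, 256)
_pos, _i, _d_model = 3, 5, 256
_expected = np.sin(_pos / np.power(10000, 2 * _i / _d_model))
print(float(_pe[_pos, 2 * _i]), _expected)  # the two values should agree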
As an aside, visualizing the Positional Encoding gives the following.
pe = position_encoding_init(50, 256).numpy()
plt.figure(figsize=(16,8))
sns.heatmap(pe, cmap='Blues')
plt.show()
The vertical axis is the word position, the horizontal axis is the dimension, and the shading shows the value that gets added.
Here the maximum sequence length is 50 and the hidden dimension is 256.
Source-target attention and self-attention: in general, attention computes the similarity between a query vector and key vectors, normalizes those weights, and applies them to the value vectors to retrieve a weighted value.
The attention used in typical translation models is called source-target attention; there, the query is the decoder hidden state (target), while both the key and the value are the encoder hidden states (source). In the diagram of the full model, this corresponds to the attention block in the middle of the decoder stack on the right.
In addition to source-target attention, the Transformer uses self-attention, where query, key, and value all come from the same sequence. Because the output at any position can then attend to every other position, it is said to outperform convolutional layers, which can only see a local neighborhood. In the diagram of the full model, this corresponds to the attention blocks at the bottom of the encoder stack on the left and the decoder stack on the right.
The Transformer implements both source-target attention and self-attention with Multi-Head Attention, which runs an attention called Scaled Dot-Product Attention in parallel over several heads.
Two kinds of attention are commonly distinguished: additive attention, which computes the attention weights with a feed-forward network with one hidden layer, and dot-product attention, which computes them with a dot product. Dot-product attention generally has fewer parameters and is faster, and the Transformer uses it.
As a further refinement, the Transformer divides the dot product of the query (Q) and the key (K) by a scaling factor $\sqrt{d_k}$:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This keeps the dot products from growing too large when $d_k$ (the dimensionality of the key vectors) is large, which would otherwise push the softmax into a region where its gradient during backpropagation becomes extremely small.
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, attn_dropout=0.1):
"""
:param d_model: int, hidden dimension
:param attn_dropout: float, dropout rate
"""
super(ScaledDotProductAttention, self).__init__()
self.temper = np.power(d_model, 0.5) # scaling factor
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(dim=-1)
def forward(self, q, k, v, attn_mask):
"""
:param q: torch.tensor, query vectors,
size=(n_head*batch_size, len_q, d_model/n_head)
:param k: torch.tensor, key vectors,
size=(n_head*batch_size, len_k, d_model/n_head)
:param v: torch.tensor, value vectors,
size=(n_head*batch_size, len_v, d_model/n_head)
:param attn_mask: torch.tensor, mask applied to the attention weights,
size=(n_head*batch_size, len_q, len_k)
:return output: output vectors,
size=(n_head*batch_size, len_q, d_model/n_head)
:return attn: attention weights,
size=(n_head*batch_size, len_q, len_k)
"""
# compute the attention weights as the dot product of Q and K, then scale
attn = torch.bmm(q, k.transpose(1, 2)) / self.temper # (n_head*batch_size, len_q, len_k)
# where attention should be blocked, send the score to negative infinity so the softmax output becomes 0
attn.data.masked_fill_(attn_mask, -float('inf'))
attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.bmm(attn, v)
return output, attn
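A minimal sketch (not in the original notebook) of calling this module on random tensors; the mask blocks the last key position, and the shapes follow the docstring above.
_sdpa = ScaledDotProductAttention(d_model=128)
_q = torch.randn(2, 5, 16, device=device)  # (n_head*batch_size, len_q, d_k)
_k = torch.randn(2, 7, 16, device=device)
_v = torch.randn(2, 7, 16, device=device)
_mask = torch.zeros(2, 5, 7, dtype=torch.bool, device=device)
_mask[:, :, -1] = True  # forbid attention to the last key position
_out, _attn = _sdpa(_q, _k, _v, _mask)
print(_out.size(), _attn.size())  # torch.Size([2, 5, 16]) torch.Size([2, 5, 7])
print(_attn[0, 0, -1].item())     # 0.0: the masked position receives no weight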
The Transformer adopts Multi-Head Attention, which runs attention in parallel over several heads.
Using multiple heads lets each head process a different subspace of the representation, which is reported to improve accuracy.
class MultiHeadAttention(nn.Module):
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
"""
:param n_head: int, number of heads
:param d_model: int, hidden dimension
:param d_k: int, dimension of the key vectors
:param d_v: int, dimension of the value vectors
:param dropout: float, dropout rate
"""
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
# weights for the per-head linear projections
# nn.Parameter registers these tensors as parameters of the Module (in TensorFlow, trainable variables are wrapped in tf.Variable in the same spirit)
self.w_qs = nn.Parameter(torch.empty([n_head, d_model, d_k], dtype=torch.float))
self.w_ks = nn.Parameter(torch.empty([n_head, d_model, d_k], dtype=torch.float))
self.w_vs = nn.Parameter(torch.empty([n_head, d_model, d_v], dtype=torch.float))
# initialize the weights with nn.init.xavier_normal_
nn.init.xavier_normal_(self.w_qs)
nn.init.xavier_normal_(self.w_ks)
nn.init.xavier_normal_(self.w_vs)
self.attention = ScaledDotProductAttention(d_model)
self.layer_norm = nn.LayerNorm(d_model) # normalize the inputs to each sub-layer to zero mean and unit variance
self.proj = nn.Linear(n_head*d_v, d_model) # linear layer that maps the concatenated multi-head outputs back to d_model
# initialize the weights with nn.init.xavier_normal_
nn.init.xavier_normal_(self.proj.weight)
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, attn_mask=None):
"""
:param q: torch.tensor, query vectors,
size=(batch_size, len_q, d_model)
:param k: torch.tensor, key vectors,
size=(batch_size, len_k, d_model)
:param v: torch.tensor, value vectors,
size=(batch_size, len_v, d_model)
:param attn_mask: torch.tensor, mask applied to the attention weights,
size=(batch_size, len_q, len_k)
:return outputs: output vectors,
size=(batch_size, len_q, d_model)
:return attns: attention weights,
size=(n_head*batch_size, len_q, len_k)
"""
d_k, d_v = self.d_k, self.d_v
n_head = self.n_head
# keep the input for the residual connection (it is added back to the output)
residual = q
batch_size, len_q, d_model = q.size()
batch_size, len_k, d_model = k.size()
batch_size, len_v, d_model = v.size()
# replicate the inputs for the multiple heads
# .repeat copies the tensor along the specified dimensions
q_s = q.repeat(n_head, 1, 1) # (n_head*batch_size, len_q, d_model)
k_s = k.repeat(n_head, 1, 1) # (n_head*batch_size, len_k, d_model)
v_s = v.repeat(n_head, 1, 1) # (n_head*batch_size, len_v, d_model)
# move n_head to dim=0 and fold batch_size into dim=1 so the heads can be processed in parallel
q_s = q_s.view(n_head, -1, d_model) # (n_head, batch_size*len_q, d_model)
k_s = k_s.view(n_head, -1, d_model) # (n_head, batch_size*len_k, d_model)
v_s = v_s.view(n_head, -1, d_model) # (n_head, batch_size*len_v, d_model)
# apply each head's linear projection in parallel (`Linear` on the left of p.16)
q_s = torch.bmm(q_s, self.w_qs) # (n_head, batch_size*len_q, d_k)
k_s = torch.bmm(k_s, self.w_ks) # (n_head, batch_size*len_k, d_k)
v_s = torch.bmm(v_s, self.w_vs) # (n_head, batch_size*len_v, d_v)
# move batch_size back to dim=0 so attention is computed per batch and per head
q_s = q_s.view(-1, len_q, d_k) # (n_head*batch_size, len_q, d_k)
k_s = k_s.view(-1, len_k, d_k) # (n_head*batch_size, len_k, d_k)
v_s = v_s.view(-1, len_v, d_v) # (n_head*batch_size, len_v, d_v)
# compute attention (`Scaled Dot-Product Attention` x h on the left of p.16)
outputs, attns = self.attention(q_s, k_s, v_s, attn_mask=attn_mask.repeat(n_head, 1, 1))
# concatenate the results of the heads (`Concat` on the left of p.16)
# torch.split divides the tensor into n_head chunks of batch_size along dim=0
outputs = torch.split(outputs, batch_size, dim=0) # n_head tensors of size (batch_size, len_q, d_v)
# concatenate along dim=-1
outputs = torch.cat(outputs, dim=-1) # (batch_size, len_q, n_head*d_v)
# map back to the original size for the residual connection (the final `Linear` on the left of p.16)
outputs = self.proj(outputs) # (batch_size, len_q, d_model)
outputs = self.dropout(outputs)
outputs = self.layer_norm(outputs + residual)
return outputs, attns
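A small shape check, a sketch not in the original notebook, running the module as self-attention on a random batch with a mask that allows every position:
_mha = MultiHeadAttention(n_head=6, d_model=128, d_k=32, d_v=32).to(device)
_x = torch.randn(4, 10, 128, device=device)                         # (batch_size, len, d_model)
_no_mask = torch.zeros(4, 10, 10, dtype=torch.bool, device=device)  # nothing is masked
_out, _attns = _mha(_x, _x, _x, attn_mask=_no_mask)                 # self-attention: q = k = v
print(_out.size(), _attns.size())  # torch.Size([4, 10, 128]) torch.Size([24, 10, 10])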
Next we define the Position-Wise Feed Forward Network, a two-layer network applied independently at every position in the sequence.
class PositionwiseFeedForward(nn.Module):
"""
:param d_hid: int, dimension of the first (input/output) layer
:param d_inner_hid: int, dimension of the second (inner) layer
:param dropout: float, dropout rate
"""
def __init__(self, d_hid, d_inner_hid, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
# a conv layer with window size 1 implements a position-wise fully connected layer
self.w_1 = nn.Conv1d(d_hid, d_inner_hid, 1)
self.w_2 = nn.Conv1d(d_inner_hid, d_hid, 1)
self.layer_norm = nn.LayerNorm(d_hid)
self.dropout = nn.Dropout(dropout)
self.relu = nn.ReLU()
def forward(self, x):
"""
:param x: torch.tensor,
size=(batch_size, max_length, d_hid)
:return: torch.tensor,
size=(batch_size, max_length, d_hid)
"""
residual = x
output = self.relu(self.w_1(x.transpose(1, 2)))
output = self.w_2(output).transpose(2, 1)
output = self.dropout(output)
return self.layer_norm(output + residual)
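As a quick illustration of why a kernel-size-1 convolution works here (a sketch not in the original notebook): applied over the length dimension it computes exactly the same thing as an nn.Linear applied independently at every position.
# compare Conv1d(kernel_size=1) with an nn.Linear that shares the same weights
_conv = nn.Conv1d(8, 16, 1)
_lin = nn.Linear(8, 16)
_lin.weight.data = _conv.weight.data.squeeze(-1).clone()  # (16, 8, 1) -> (16, 8)
_lin.bias.data = _conv.bias.data.clone()
_x = torch.randn(2, 5, 8)                            # (batch, length, features)
_y_conv = _conv(_x.transpose(1, 2)).transpose(1, 2)  # convolve over the length dimension
_y_lin = _lin(_x)                                    # applied position by position
print(torch.allclose(_y_conv, _y_lin, atol=1e-5))    # True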
def get_attn_padding_mask(seq_q, seq_k):
"""
Create a mask that zeroes out attention to PAD positions in the key
:param seq_q: tensor, query sequence, size=(batch_size, len_q)
:param seq_k: tensor, key sequence, size=(batch_size, len_k)
:return pad_attn_mask: tensor, size=(batch_size, len_q, len_k)
"""
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
pad_attn_mask = seq_k.data.eq(PAD).unsqueeze(1) # (N, 1, len_k) True only where the key id is PAD
pad_attn_mask = pad_attn_mask.expand(batch_size, len_q, len_k) # (N, len_q, len_k)
return pad_attn_mask
_seq_q = torch.tensor([[1, 2, 3]])
_seq_k = torch.tensor([[4, 5, 6, 7, PAD]])
_mask = get_attn_padding_mask(_seq_q, _seq_k) # rows correspond to queries and columns to keys; entries are True exactly at the time steps where the key is PAD (=0)
print('query:\n', _seq_q)
print('key:\n', _seq_k)
print('mask:\n', _mask)
query: tensor([[1, 2, 3]])
key: tensor([[4, 5, 6, 7, 0]])
mask: tensor([[[False, False, False, False, True], [False, False, False, False, True], [False, False, False, False, True]]])
The other mask is used in the decoder's self-attention to prevent attending to future information at each time step.
def get_attn_subsequent_mask(seq):
"""
Create a mask that zeroes out attention to future positions
:param seq: tensor, size=(batch_size, length)
:return subsequent_mask: tensor, size=(batch_size, length, length)
"""
attn_shape = (seq.size(1), seq.size(1))
# upper-triangular matrix (diagonal=1: entries above the diagonal are 1, the rest are 0)
subsequent_mask = torch.triu(torch.ones(attn_shape, dtype=torch.uint8, device=device), diagonal=1)
subsequent_mask = subsequent_mask.repeat(seq.size(0), 1, 1)
return subsequent_mask
_seq = torch.tensor([[1,2,3,4]])
_mask = get_attn_subsequent_mask(_seq) # rows correspond to queries and columns to keys; entries are 1 where the key lies in the query's future, 0 elsewhere
print('seq:\n', _seq)
print('mask:\n', _mask)
seq: tensor([[1, 2, 3, 4]])
mask: tensor([[[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)
class EncoderLayer(nn.Module):
"""Encoderのブロックのクラス"""
def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
"""
:param d_model: int, hidden dimension
:param d_inner_hid: int, dimension of the second layer of the Position-Wise Feed Forward Network
:param n_head: int, number of heads
:param d_k: int, dimension of the key vectors
:param d_v: int, dimension of the value vectors
:param dropout: float, dropout rate
"""
super(EncoderLayer, self).__init__()
# self-attention inside the Encoder
self.slf_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
# Position-wise FFN
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
"""
:param enc_input: tensor, input to the Encoder,
size=(batch_size, max_length, d_model)
:param slf_attn_mask: tensor, mask applied to the self-attention weights,
size=(batch_size, len_q, len_k)
:return enc_output: tensor, output of the Encoder block,
size=(batch_size, max_length, d_model)
:return enc_slf_attn: tensor, the Encoder's self-attention weights,
size=(n_head*batch_size, len_q, len_k)
"""
# in self-attention, the Encoder input (enc_input) is used as query, key, and value
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, attn_mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn
class Encoder(nn.Module):
"""EncoderLayerブロックからなるEncoderのクラス"""
def __init__(
self, n_src_vocab, max_length, n_layers=6, n_head=8, d_k=64, d_v=64,
d_word_vec=512, d_model=512, d_inner_hid=1024, dropout=0.1):
"""
:param n_src_vocab: int, vocabulary size of the source language
:param max_length: int, maximum sequence length
:param n_layers: int, number of layers
:param n_head: int, number of heads
:param d_k: int, dimension of the key vectors
:param d_v: int, dimension of the value vectors
:param d_word_vec: int, dimension of the word embeddings
:param d_model: int, hidden dimension
:param d_inner_hid: int, dimension of the second layer of the Position-Wise Feed Forward Network
:param dropout: float, dropout rate
"""
super(Encoder, self).__init__()
n_position = max_length + 1
self.max_length = max_length
self.d_model = d_model
# embedding used for the Positional Encoding
self.position_enc = nn.Embedding(n_position, d_word_vec, padding_idx=PAD)
self.position_enc.weight.data = position_encoding_init(n_position, d_word_vec)
# ordinary word embedding
self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=PAD)
# stack n_layers EncoderLayer blocks
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner_hid, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
def forward(self, src_seq, src_pos):
"""
:param src_seq: tensor, input sequence,
size=(batch_size, max_length)
:param src_pos: tensor, positions of the words in the input sequence,
size=(batch_size, max_length)
:return enc_output: tensor, final output of the Encoder,
size=(batch_size, max_length, d_model)
:return enc_slf_attns: list, the Encoder's self-attention weight matrices
"""
# ordinary word embedding
enc_input = self.src_word_emb(src_seq)
# add the Positional Encoding embedding
enc_input += self.position_enc(src_pos)
enc_slf_attns = []
enc_output = enc_input
# create a mask that is True only where the key (=enc_input) is PAD
enc_slf_attn_mask = get_attn_padding_mask(src_seq, src_seq)
# pass the input through the n_layers EncoderLayer blocks
for enc_layer in self.layer_stack:
enc_output, enc_slf_attn = enc_layer(
enc_output, slf_attn_mask=enc_slf_attn_mask)
enc_slf_attns += [enc_slf_attn]
return enc_output, enc_slf_attns
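A minimal shape check for the Encoder on its own (a sketch not in the original notebook), using a hand-made mini batch; position 0 marks PAD, matching the DataLoader above.
_enc = Encoder(n_src_vocab=vocab_size_X, max_length=20, n_layers=2, n_head=4,
               d_k=16, d_v=16, d_word_vec=64, d_model=64, d_inner_hid=128).to(device)
_src_seq = torch.tensor([[BOS, 5, 6, 7, EOS, PAD, PAD]], dtype=torch.long, device=device)
_src_pos = torch.tensor([[1, 2, 3, 4, 5, 0, 0]], dtype=torch.long, device=device)
_enc_out, _enc_attns = _enc(_src_seq, _src_pos)
print(_enc_out.size(), len(_enc_attns))  # torch.Size([1, 7, 64]) 2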
The Decoder likewise repeats blocks made of self-attention, source-target attention, and a Position-Wise Feed Forward Network, so we first define the block class DecoderLayer and then the Decoder.
class DecoderLayer(nn.Module):
"""Decoderのブロックのクラス"""
def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
"""
:param d_model: int, hidden dimension
:param d_inner_hid: int, dimension of the second layer of the Position-Wise Feed Forward Network
:param n_head: int, number of heads
:param d_k: int, dimension of the key vectors
:param d_v: int, dimension of the value vectors
:param dropout: float, dropout rate
"""
super(DecoderLayer, self).__init__()
# self-attention inside the Decoder
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
# source-target attention between the Encoder and the Decoder
self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
# Positionwise FFN
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)
def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
"""
:param dec_input: tensor, input to the Decoder,
size=(batch_size, max_length, d_model)
:param enc_output: tensor, output of the Encoder,
size=(batch_size, max_length, d_model)
:param slf_attn_mask: tensor, mask applied to the self-attention weights,
size=(batch_size, len_q, len_k)
:param dec_enc_attn_mask: tensor, mask applied to the source-target attention weights,
size=(batch_size, len_q, len_k)
:return dec_output: tensor, output of the Decoder block,
size=(batch_size, max_length, d_model)
:return dec_slf_attn: tensor, the Decoder's self-attention weights,
size=(n_head*batch_size, len_q, len_k)
:return dec_enc_attn: tensor, the Decoder's source-target attention weights,
size=(n_head*batch_size, len_q, len_k)
"""
# in self-attention, the Decoder input (dec_input) is used as query, key, and value
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, attn_mask=slf_attn_mask)
# in source-target attention, the Decoder output (dec_output) is the query, and the Encoder output (enc_output) provides the key and value
dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask)
dec_output = self.pos_ffn(dec_output)
return dec_output, dec_slf_attn, dec_enc_attn
class Decoder(nn.Module):
"""DecoderLayerブロックからなるDecoderのクラス"""
def __init__(
self, n_tgt_vocab, max_length, n_layers=6, n_head=8, d_k=64, d_v=64,
d_word_vec=512, d_model=512, d_inner_hid=1024, dropout=0.1):
"""
:param n_tgt_vocab: int, vocabulary size of the target language
:param max_length: int, maximum sequence length
:param n_layers: int, number of layers
:param n_head: int, number of heads
:param d_k: int, dimension of the key vectors
:param d_v: int, dimension of the value vectors
:param d_word_vec: int, dimension of the word embeddings
:param d_model: int, hidden dimension
:param d_inner_hid: int, dimension of the second layer of the Position-Wise Feed Forward Network
:param dropout: float, dropout rate
"""
super(Decoder, self).__init__()
n_position = max_length + 1
self.max_length = max_length
self.d_model = d_model
# embedding used for the Positional Encoding
self.position_enc = nn.Embedding(
n_position, d_word_vec, padding_idx=PAD)
self.position_enc.weight.data = position_encoding_init(n_position, d_word_vec)
# ordinary word embedding
self.tgt_word_emb = nn.Embedding(
n_tgt_vocab, d_word_vec, padding_idx=PAD)
self.dropout = nn.Dropout(dropout)
# stack n_layers DecoderLayer blocks
self.layer_stack = nn.ModuleList([
DecoderLayer(d_model, d_inner_hid, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
def forward(self, tgt_seq, tgt_pos, src_seq, enc_output):
"""
:param tgt_seq: tensor, target sequence,
size=(batch_size, max_length)
:param tgt_pos: tensor, positions of the words in the target sequence,
size=(batch_size, max_length)
:param src_seq: tensor, source sequence,
size=(batch_size, max_length)
:param enc_output: tensor, output of the Encoder,
size=(batch_size, max_length, d_model)
:return dec_output: tensor, final output of the Decoder,
size=(batch_size, max_length, d_model)
:return dec_slf_attns: list, the Decoder's self-attention weight matrices
:return dec_enc_attns: list, the Decoder's source-target attention weight matrices
"""
# ordinary word embedding
dec_input = self.tgt_word_emb(tgt_seq)
# add the Positional Encoding embedding
dec_input += self.position_enc(tgt_pos)
# create the mask for self-attention:
# the OR of a mask that is 1 where the key (=dec_input) is PAD and a mask that is 1 where the key lies in the query's future
dec_slf_attn_pad_mask = get_attn_padding_mask(tgt_seq, tgt_seq) # (N, max_length, max_length)
dec_slf_attn_sub_mask = get_attn_subsequent_mask(tgt_seq) # (N, max_length, max_length)
dec_slf_attn_mask = torch.gt(dec_slf_attn_pad_mask + dec_slf_attn_sub_mask, 0) # take the OR
# create a mask that is 1 only where the key (=src_seq, the encoder side) is PAD
dec_enc_attn_pad_mask = get_attn_padding_mask(tgt_seq, src_seq) # (N, max_length, max_length)
dec_slf_attns, dec_enc_attns = [], []
dec_output = dec_input
# pass the input through the n_layers DecoderLayer blocks
for dec_layer in self.layer_stack:
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output, enc_output,
slf_attn_mask=dec_slf_attn_mask,
dec_enc_attn_mask=dec_enc_attn_pad_mask)
dec_slf_attns += [dec_slf_attn]
dec_enc_attns += [dec_enc_attn]
return dec_output, dec_slf_attns, dec_enc_attns
class Transformer(nn.Module):
"""Transformerのモデル全体のクラス"""
def __init__(
self, n_src_vocab, n_tgt_vocab, max_length, n_layers=6, n_head=8,
d_word_vec=512, d_model=512, d_inner_hid=1024, d_k=64, d_v=64,
dropout=0.1, proj_share_weight=True):
"""
:param n_src_vocab: int, vocabulary size of the source language
:param n_tgt_vocab: int, vocabulary size of the target language
:param max_length: int, maximum sequence length
:param n_layers: int, number of layers
:param n_head: int, number of heads
:param d_k: int, dimension of the key vectors
:param d_v: int, dimension of the value vectors
:param d_word_vec: int, dimension of the word embeddings
:param d_model: int, hidden dimension
:param d_inner_hid: int, dimension of the second layer of the Position-Wise Feed Forward Network
:param dropout: float, dropout rate
:param proj_share_weight: bool, whether to share weights between the target word embedding and the output projection
"""
super(Transformer, self).__init__()
self.encoder = Encoder(
n_src_vocab, max_length, n_layers=n_layers, n_head=n_head,
d_word_vec=d_word_vec, d_model=d_model,
d_inner_hid=d_inner_hid, dropout=dropout)
self.decoder = Decoder(
n_tgt_vocab, max_length, n_layers=n_layers, n_head=n_head,
d_word_vec=d_word_vec, d_model=d_model,
d_inner_hid=d_inner_hid, dropout=dropout)
self.tgt_word_proj = nn.Linear(d_model, n_tgt_vocab, bias=False)
nn.init.xavier_normal_(self.tgt_word_proj.weight)
self.dropout = nn.Dropout(dropout)
assert d_model == d_word_vec # the outputs of all modules must have the same size
if proj_share_weight:
# share the weights between the target word embedding and the output projection
assert d_model == d_word_vec
self.tgt_word_proj.weight = self.decoder.tgt_word_emb.weight
def get_trainable_parameters(self):
# update all parameters except the (frozen) Positional Encoding
enc_freezed_param_ids = set(map(id, self.encoder.position_enc.parameters()))
dec_freezed_param_ids = set(map(id, self.decoder.position_enc.parameters()))
freezed_param_ids = enc_freezed_param_ids | dec_freezed_param_ids
return (p for p in self.parameters() if id(p) not in freezed_param_ids)
def forward(self, src, tgt):
src_seq, src_pos = src
tgt_seq, tgt_pos = tgt
src_seq = src_seq[:, 1:] # drop BOS from the source sequence
src_pos = src_pos[:, 1:]
tgt_seq = tgt_seq[:, :-1] # the decoder input is the target without its final token (shifted for teacher forcing)
tgt_pos = tgt_pos[:, :-1]
enc_output, *_ = self.encoder(src_seq, src_pos)
dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
seq_logit = self.tgt_word_proj(dec_output)
return seq_logit
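A minimal sketch (not in the original notebook) that pushes a dummy batch through the full model; the output has one logit vector over the target vocabulary for every decoder input position.
_tmp_model = Transformer(n_src_vocab=vocab_size_X, n_tgt_vocab=vocab_size_Y, max_length=20,
                         n_layers=2, n_head=4, d_k=16, d_v=16,
                         d_word_vec=64, d_model=64, d_inner_hid=128).to(device)
_src = (torch.tensor([[BOS, 5, 6, EOS, PAD]], dtype=torch.long, device=device),
        torch.tensor([[1, 2, 3, 4, 0]], dtype=torch.long, device=device))
_tgt = (torch.tensor([[BOS, 8, 9, EOS]], dtype=torch.long, device=device),
        torch.tensor([[1, 2, 3, 4]], dtype=torch.long, device=device))
_logits = _tmp_model(_src, _tgt)
print(_logits.size())  # torch.Size([1, 3, vocab_size_Y]): tgt_seq[:, :-1] has 3 positions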
def compute_loss(batch_X, batch_Y, model, criterion, optimizer=None, is_train=True):
# compute the loss for one batch
model.train(is_train)
pred_Y = model(batch_X, batch_Y)
gold = batch_Y[0][:, 1:].contiguous() # the gold labels are the target shifted left by one token
# gold = batch_Y[0].contiguous()
loss = criterion(pred_Y.view(-1, pred_Y.size(2)), gold.view(-1))
if is_train: # update the parameters during training
optimizer.zero_grad()
loss.backward()
optimizer.step()
gold = gold.data.cpu().numpy().tolist()
pred = pred_Y.max(dim=-1)[1].data.cpu().numpy().tolist()
return loss.item(), gold, pred
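To make the teacher-forcing offset explicit (this illustration is not in the original notebook): Transformer.forward feeds the decoder tgt[:, :-1], while the loss above compares its predictions with tgt[:, 1:], so every position is trained to predict the next token.
# illustrative only: how one target sentence splits into decoder input and gold labels
_tgt_example = torch.tensor([[BOS, 10, 11, 12, EOS]])
print(_tgt_example[:, :-1])  # decoder input:  tensor([[ 2, 10, 11, 12]])
print(_tgt_example[:, 1:])   # gold labels:    tensor([[10, 11, 12,  3]])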
MAX_LENGTH = 20
batch_size = 64
num_epochs = 15
lr = 0.001
ckpt_path = 'transformer.pth'
max_length = MAX_LENGTH + 2
model_args = {
'n_src_vocab': vocab_size_X,
'n_tgt_vocab': vocab_size_Y,
'max_length': max_length,
'proj_share_weight': True,
'd_k': 32,
'd_v': 32,
'd_model': 128,
'd_word_vec': 128,
'd_inner_hid': 256,
'n_layers': 3,
'n_head': 6,
'dropout': 0.1,
}
# define the DataLoaders and the model
train_dataloader = DataLoader(
train_X, train_Y, batch_size
)
valid_dataloader = DataLoader(
valid_X, valid_Y, batch_size,
shuffle=False
)
model = Transformer(**model_args).to(device)
optimizer = optim.Adam(model.get_trainable_parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index=PAD, reduction='sum').to(device)
def calc_bleu(refs, hyps):
"""
Compute the BLEU score
:param refs: list, reference translations; a list of token-ID lists
:param hyps: list, translations generated by the model; a list of token-ID lists
:return: float, BLEU score (0-100)
"""
refs = [[ref[:ref.index(EOS)]] for ref in refs]
hyps = [hyp[:hyp.index(EOS)] if EOS in hyp else hyp for hyp in hyps]
return 100 * bleu_score.corpus_bleu(refs, hyps)
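A toy call (not in the original notebook) showing the expected input format: both references and hypotheses are lists of token IDs, and everything from EOS onward is trimmed before scoring.
_refs = [[5, 6, 7, 8, EOS, PAD, PAD]]  # gold IDs, still padded
_hyps = [[5, 6, 7, 8, EOS]]            # model output IDs
print(calc_bleu(_refs, _hyps))         # 100.0 for an exact match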
# training
best_valid_bleu = 0.
for epoch in range(1, num_epochs+1):
start = time.time()
train_loss = 0.
train_refs = []
train_hyps = []
valid_loss = 0.
valid_refs = []
valid_hyps = []
# train
for batch in train_dataloader:
batch_X, batch_Y = batch
loss, gold, pred = compute_loss(
batch_X, batch_Y, model, criterion, optimizer, is_train=True
)
train_loss += loss
train_refs += gold
train_hyps += pred
# valid
for batch in valid_dataloader:
batch_X, batch_Y = batch
loss, gold, pred = compute_loss(
batch_X, batch_Y, model, criterion, is_train=False
)
valid_loss += loss
valid_refs += gold
valid_hyps += pred
# normalize the loss by the number of samples
train_loss /= len(train_dataloader.data)
valid_loss /= len(valid_dataloader.data)
# compute BLEU
train_bleu = calc_bleu(train_refs, train_hyps)
valid_bleu = calc_bleu(valid_refs, valid_hyps)
# save the model whenever the validation BLEU improves
if valid_bleu > best_valid_bleu:
ckpt = model.state_dict()
torch.save(ckpt, ckpt_path)
best_valid_bleu = valid_bleu
elapsed_time = (time.time()-start) / 60
print('Epoch {} [{:.1f}min]: train_loss: {:5.2f} train_bleu: {:2.2f} valid_loss: {:5.2f} valid_bleu: {:2.2f}'.format(
epoch, elapsed_time, train_loss, train_bleu, valid_loss, valid_bleu))
print('-'*80)
Epoch 1 [0.4min]: train_loss: 77.35 train_bleu: 4.68 valid_loss: 41.18 valid_bleu: 10.77
Epoch 2 [0.3min]: train_loss: 39.40 train_bleu: 12.26 valid_loss: 32.24 valid_bleu: 17.36
Epoch 3 [0.3min]: train_loss: 32.03 train_bleu: 18.01 valid_loss: 28.11 valid_bleu: 22.09
Epoch 4 [0.4min]: train_loss: 28.30 train_bleu: 21.60 valid_loss: 25.78 valid_bleu: 24.94
Epoch 5 [0.4min]: train_loss: 25.82 train_bleu: 24.42 valid_loss: 24.29 valid_bleu: 27.22
Epoch 6 [0.3min]: train_loss: 23.99 train_bleu: 26.68 valid_loss: 23.03 valid_bleu: 29.07
Epoch 7 [0.3min]: train_loss: 22.54 train_bleu: 28.64 valid_loss: 22.22 valid_bleu: 30.07
Epoch 8 [0.4min]: train_loss: 21.37 train_bleu: 30.12 valid_loss: 21.53 valid_bleu: 31.17
Epoch 9 [0.3min]: train_loss: 20.35 train_bleu: 31.42 valid_loss: 20.95 valid_bleu: 31.49
Epoch 10 [0.3min]: train_loss: 19.42 train_bleu: 32.88 valid_loss: 20.39 valid_bleu: 33.22
Epoch 11 [0.4min]: train_loss: 18.67 train_bleu: 33.86 valid_loss: 20.12 valid_bleu: 33.81
Epoch 12 [0.4min]: train_loss: 17.93 train_bleu: 35.10 valid_loss: 19.78 valid_bleu: 33.89
Epoch 13 [0.3min]: train_loss: 17.29 train_bleu: 36.01 valid_loss: 19.39 valid_bleu: 34.61
Epoch 14 [0.3min]: train_loss: 16.71 train_bleu: 36.88 valid_loss: 19.14 valid_bleu: 35.31
Epoch 15 [0.4min]: train_loss: 16.15 train_bleu: 37.92 valid_loss: 18.95 valid_bleu: 35.35
def test(model, src, max_length=20):
# generate sequences with the trained model (greedy decoding)
model.eval()
src_seq, src_pos = src
batch_size = src_seq.size(0)
enc_output, enc_slf_attns = model.encoder(src_seq, src_pos)
tgt_seq = torch.full([batch_size, 1], BOS, dtype=torch.long, device=device)
tgt_pos = torch.arange(1, dtype=torch.long, device=device)
tgt_pos = tgt_pos.unsqueeze(0).repeat(batch_size, 1)
# process one time step at a time
for t in range(1, max_length+1):
dec_output, dec_slf_attns, dec_enc_attns = model.decoder(
tgt_seq, tgt_pos, src_seq, enc_output)
dec_output = model.tgt_word_proj(dec_output)
out = dec_output[:, -1, :].max(dim=-1)[1].unsqueeze(1)
# feed the model's own output back in as the input for the next time step
tgt_seq = torch.cat([tgt_seq, out], dim=-1)
tgt_pos = torch.arange(t+1, dtype=torch.long, device=device)
tgt_pos = tgt_pos.unsqueeze(0).repeat(batch_size, 1)
return tgt_seq[:, 1:], enc_slf_attns, dec_slf_attns, dec_enc_attns
def ids_to_sentence(vocab, ids):
# convert a list of IDs to a list of words
return [vocab.id2word[_id] for _id in ids]
def trim_eos(ids):
# remove everything from EOS onward from a list of IDs
if EOS in ids:
return ids[:ids.index(EOS)]
else:
return ids
# load the trained model
model = Transformer(**model_args).to(device)
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt)
# load the evaluation data (the dev split is used here)
test_X = load_data('./data/dev.en')
test_Y = load_data('./data/dev.ja')
test_X = [sentence_to_ids(vocab_X, sentence) for sentence in test_X]
test_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in test_Y]
test_dataloader = DataLoader(
test_X, test_Y, 1,
shuffle=False
)
src, tgt = next(test_dataloader)
src_ids = src[0][0].cpu().numpy()
tgt_ids = tgt[0][0].cpu().numpy()
print('src: {}'.format(' '.join(ids_to_sentence(vocab_X, src_ids[1:-1]))))
print('tgt: {}'.format(' '.join(ids_to_sentence(vocab_Y, tgt_ids[1:-1]))))
preds, enc_slf_attns, dec_slf_attns, dec_enc_attns = test(model, src)
pred_ids = preds[0].data.cpu().numpy().tolist()
print('out: {}'.format(' '.join(ids_to_sentence(vocab_Y, trim_eos(pred_ids)))))
src: show your own business . </S> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
tgt: 自分 の 事 を しろ 。 </S> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
out: 自分 の こと を <UNK> し て い た 。
# BLEU evaluation
test_dataloader = DataLoader(
test_X, test_Y, 128,
shuffle=False
)
refs_list = []
hyp_list = []
for batch in test_dataloader:
batch_X, batch_Y = batch
preds, *_ = test(model, batch_X)
preds = preds.data.cpu().numpy().tolist()
refs = batch_Y[0].data.cpu().numpy()[:, 1:].tolist()
refs_list += refs
hyp_list += preds
bleu = calc_bleu(refs_list, hyp_list)
print(bleu)
24.052339616406652
### Discussion
Since this is an attention-only (Transformer) model, I expected the performance to improve over the previous RNN-based model, but it did not change dramatically.
It reminded me of the paper released around 2020 by OpenAI (well known for GPT) arguing that attention-based models keep scaling up as they grow.
In other words, at roughly this model size and with the small Tanaka Corpus, the difference between an attention-only model and an RNN-based one probably does not show up very clearly.
My impression is that the standout performance only emerges when a huge model is trained on large-scale data.
Sample outputs obtained from running the inference a few times:
src: show your own business .
tgt: 自分 の 事 を しろ 。
out: 自分 の こと を <UNK> し て い た 。
src: the water was cut off yesterday . </S> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
tgt: 昨日 水道 を 止め られ た 。 </S> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
out: 昨日 水 は <UNK> を 切 っ た
src: i should like to see you this afternoon . </S> <PAD> <PAD> <PAD> <PAD>
tgt: 今日 の 午後 お 会 い し た い の で す が 。 </S> <PAD>
out: 午後 あなた に 会 う の を 見 る べ き だ 。
src: i gave him a call . </S> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
tgt: 私 は 彼 に 電話 を し た 。 </S> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
out: 彼 に 電話 を くれ た 。