Simple example of Seq2Seq (machine translation) with attention, using a bi-directional RNN encoder and an RNN decoder.
tf.data
padding technique
by user-defined functions (pad_seq_enc, pad_seq_dec)
tf.nn.embedding_lookup
for getting vectors of tokens (e.g. words, characters)
tf.contrib.seq2seq.sequence_loss
tf.sequence_mask
tf.contrib.seq2seq.LuongAttention, tf.contrib.seq2seq.AttentionWrapper
tf.contrib.seq2seq.dynamic_decode
tf.contrib.seq2seq.TrainingHelper
tf.contrib.seq2seq.GreedyEmbeddingHelper
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import string
from pprint import pprint
%matplotlib inline
# Short alias for the contrib seq2seq module; every decoder/attention API below is reached through it.
# tf.contrib is a TF 1.x API, so this notebook targets TF 1.x (see the printed version below).
s2s = tf.contrib.seq2seq
print(tf.__version__)
1.10.0
# Toy parallel corpus: tokenised English source sentences and their Korean translations.
sources = [['I', 'feel', 'hungry'],
           ['tensorflow', 'is', 'very', 'difficult'],
           ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],
           ['tensorflow', 'is', 'very', 'fast', 'changing']]
targets = [['나는', '배가', '고프다'],
           ['텐서플로우는', '매우', '어렵다'],
           ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다'],
           ['텐서플로우는', '매우', '빠르게', '변화한다']]

# Source vocabulary: '<pad>' is index 0, followed by every distinct source token in sorted order.
source_words = ['<pad>'] + sorted({word for sentence in sources for word in sentence})
source_dic = dict(zip(source_words, range(len(source_words))))
print(source_dic)
print(len(source_dic))
{'<pad>': 0, 'I': 1, 'a': 2, 'changing': 3, 'deep': 4, 'difficult': 5, 'fast': 6, 'feel': 7, 'for': 8, 'framework': 9, 'hungry': 10, 'is': 11, 'learning': 12, 'tensorflow': 13, 'very': 14} 15
# Reverse source vocabulary: index -> token, used to turn index sequences back into words.
source_idx_dic = {idx: word for word, idx in source_dic.items()}
source_idx_dic
{0: '<pad>', 1: 'I', 2: 'a', 3: 'changing', 4: 'deep', 5: 'difficult', 6: 'fast', 7: 'feel', 8: 'for', 9: 'framework', 10: 'hungry', 11: 'is', 12: 'learning', 13: 'tensorflow', 14: 'very'}
# Target vocabulary. The special tokens come first ('<pad>' = 0, '<start>' = 1, '<end>' = 2);
# '<start>' and '<end>' mark where a translation begins and ends. They are followed by every
# distinct target token in sorted order.
target_words = (['<pad>', '<start>', '<end>']
                + sorted({token for sentence in targets for token in sentence}))
target_dic = dict(zip(target_words, range(len(target_words))))
print(target_dic)
print(len(target_dic))
{'<pad>': 0, '<start>': 1, '<end>': 2, '고프다': 3, '나는': 4, '딥러닝을': 5, '매우': 6, '배가': 7, '변화한다': 8, '빠르게': 9, '어렵다': 10, '위한': 11, '텐서플로우는': 12, '프레임워크이다': 13} 14
# Reverse target vocabulary: index -> token, used to read decoder outputs as words.
target_idx_dic = {idx: token for token, idx in target_dic.items()}
target_idx_dic
{0: '<pad>', 1: '<start>', 2: '<end>', 3: '고프다', 4: '나는', 5: '딥러닝을', 6: '매우', 7: '배가', 8: '변화한다', 9: '빠르게', 10: '어렵다', 11: '위한', 12: '텐서플로우는', 13: '프레임워크이다'}
def pad_seq_enc(sequences, max_len, dic):
    """Convert tokenised source sentences into equal-length lists of indices.

    Args:
        sequences: list of token lists, one list per sentence.
        max_len: length every index list is right-padded to with ``dic['<pad>']``.
        dic: token -> index vocabulary; must contain '<pad>' and every token
            that appears in ``sequences``.

    Returns:
        (seq_len, seq_indices): the unpadded length of each sentence, and its
        index list padded to exactly ``max_len`` entries.

    Raises:
        KeyError: a token (or '<pad>') is missing from ``dic``. Previously this
            silently produced ``None`` entries.
        ValueError: a sentence has more than ``max_len`` tokens. Previously such
            a sentence was left longer than ``max_len``, giving a ragged batch.
    """
    seq_len = []
    seq_indices = []
    for seq in sequences:
        if len(seq) > max_len:
            raise ValueError('sequence of length {} exceeds max_len={}: {!r}'
                             .format(len(seq), max_len, seq))
        seq_idx = [dic[word] for word in seq]
        seq_len.append(len(seq_idx))
        # Right-pad so every row has exactly max_len entries.
        seq_indices.append(seq_idx + [dic['<pad>']] * (max_len - len(seq_idx)))
    return seq_len, seq_indices
def pad_seq_dec(sequences, max_len, dic):
    """Build padded decoder inputs and targets from tokenised target sentences.

    For each sentence ``[w1, ..., wn]``:
      * decoder input  = ``[<start>, w1, ..., wn]`` followed by padding
      * decoder target = ``[w1, ..., wn, <end>]`` followed by padding
    so the target is the input shifted left by one step (teacher forcing).

    Args:
        sequences: iterable of token lists, one list per sentence.
        max_len: length every index list is right-padded to with ``dic['<pad>']``.
        dic: token -> index vocabulary; must contain '<pad>', '<start>', '<end>'
            and every token that appears in ``sequences``.

    Returns:
        (seq_input_len, seq_input_indices, seq_target_indices): the unpadded
        length ``n + 1`` of each decoder input (equal to that of its target),
        and the input / target index lists, each padded to ``max_len`` entries.

    Raises:
        KeyError: a token or special token is missing from ``dic``. Previously
            this silently produced ``None`` entries.
        ValueError: a sentence plus its '<start>'/'<end>' token would exceed
            ``max_len``. Previously the row was left longer than ``max_len``.
    """
    seq_input_len = []
    seq_input_indices = []
    seq_target_indices = []
    # A single pass keeps inputs and targets aligned. It also works when
    # `sequences` is a one-shot iterator; the old two-loop version returned
    # empty targets in that case.
    for seq in sequences:
        n_steps = len(seq) + 1  # one extra slot for '<start>' (input) / '<end>' (target)
        if n_steps > max_len:
            raise ValueError('sequence of length {} (+1 special token) exceeds max_len={}: {!r}'
                             .format(len(seq), max_len, seq))
        token_idx = [dic[token] for token in seq]
        padding = [dic['<pad>']] * (max_len - n_steps)
        seq_input_len.append(n_steps)
        seq_input_indices.append([dic['<start>']] + token_idx + padding)
        seq_target_indices.append(token_idx + [dic['<end>']] + padding)
    return seq_input_len, seq_input_indices, seq_target_indices
# Encoder side: pad every source sentence to a fixed length of 10 indices.
source_max_len = 10
X_length, X_indices = pad_seq_enc(sources, source_max_len, source_dic)
print(X_length, np.shape(X_indices))
[3, 4, 7, 5] (4, 10)
# Decoder side: pad to 12 indices. Each target gets '<start>' prepended (decoder input)
# or '<end>' appended (decoder target).
target_max_len = 12
y_length, y_input_indices, y_target_indices = pad_seq_dec(targets, target_max_len, target_dic)
pprint(y_length)
pprint(y_input_indices)
pprint(y_target_indices)
[4, 4, 5, 5] [[1, 4, 7, 3, 0, 0, 0, 0, 0, 0, 0, 0], [1, 12, 6, 10, 0, 0, 0, 0, 0, 0, 0, 0], [1, 12, 5, 11, 13, 0, 0, 0, 0, 0, 0, 0], [1, 12, 6, 9, 8, 0, 0, 0, 0, 0, 0, 0]] [[4, 7, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0], [12, 6, 10, 2, 0, 0, 0, 0, 0, 0, 0, 0], [12, 5, 11, 13, 2, 0, 0, 0, 0, 0, 0, 0], [12, 6, 9, 8, 2, 0, 0, 0, 0, 0, 0, 0]]
Model: bi-directional encoder RNN, decoder RNN, and Luong attention
class SimpleNMT:
    """Seq2seq translator: bi-directional RNN encoder + Luong-attention RNN decoder (TF 1.x graph).

    The graph is built directly on the tensors passed to the constructor (in this
    notebook, the outputs of a ``tf.data`` iterator). It contains a teacher-forced
    training decoder, a greedy translation decoder that reuses the same attention
    cell and output layer, and a masked sequence loss.

    Args:
        s_len: source sentence lengths, one per example.
        s_indices: padded source token indices, batch-major (batch, time).
        t_len: decoder sequence lengths used to mask the loss (e.g. as produced
            by ``pad_seq_dec``).
        t_input_indices: padded decoder-input indices (teacher forcing).
        t_output_indices: padded decoder-target indices, compared against the
            training logits by ``sequence_loss``.
        t_max_len: padded target length. The training decoder runs exactly this
            many steps; the translation decoder runs at most twice as many.
        s_dic, t_dic: token -> index vocabularies for source and target. Their
            sizes fix the widths of the one-hot embeddings.
        n_of_classes: number of output logits per decoding step.
        enc_hdim: hidden size of each encoder direction.
        dec_hdim: hidden size of the decoder cell and of the attention.
    """

    def __init__(self, s_len, s_indices, t_len, t_input_indices, t_output_indices,
                 t_max_len = target_max_len, s_dic = source_dic, t_dic = target_dic,
                 n_of_classes = len(target_dic), enc_hdim = 8, dec_hdim = 4):
        with tf.variable_scope('input_layer'):
            # s : source, t : target
            self._s_len = s_len
            self._s_indices = s_indices
            self._t_len = t_len
            self._t_input_indices = t_input_indices
            self._t_output_indices = t_output_indices
            self._s_dic = s_dic
            self._t_dic = t_dic
            # Bug fix: this previously read the module-level `target_max_len`,
            # silently ignoring the `t_max_len` argument.
            self._t_max_len = t_max_len

            # Fixed (non-trainable) one-hot embeddings: row i of the identity matrix encodes token i.
            s_embeddings = tf.eye(num_rows = len(self._s_dic), dtype = tf.float32)
            s_embeddings = tf.get_variable(name = 's_embeddings', initializer = s_embeddings,
                                           trainable = False)
            s_batch = tf.nn.embedding_lookup(params = s_embeddings, ids = self._s_indices)

        with tf.variable_scope('encoder'):
            enc_fw_cell = tf.contrib.rnn.BasicRNNCell(num_units = enc_hdim, activation = tf.nn.tanh)
            enc_bw_cell = tf.contrib.rnn.BasicRNNCell(num_units = enc_hdim, activation = tf.nn.tanh)
            enc_outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw = enc_fw_cell, cell_bw = enc_bw_cell,
                                                             inputs = s_batch, sequence_length = self._s_len,
                                                             dtype = tf.float32)
            # Join forward and backward outputs on the feature axis (2 * enc_hdim features);
            # this is the memory the decoder attends over.
            enc_outputs = tf.concat(values = [enc_outputs[0], enc_outputs[1]], axis = 2)

        with tf.variable_scope('pipe'):
            t_embeddings = tf.eye(num_rows = len(self._t_dic), dtype = tf.float32)
            t_embeddings = tf.get_variable(name = 'embeddings',
                                           initializer = t_embeddings,
                                           trainable = False)
            t_batch = tf.nn.embedding_lookup(params = t_embeddings, ids = self._t_input_indices)
            # Batch size is read from the data at run time rather than fixed in the graph.
            batch_size = tf.reduce_sum(tf.ones_like(tensor = self._s_len))
            # Every training sequence is decoded for the full padded length, so the logits
            # line up with the padded targets; padding is masked out in the loss instead.
            tr_tokens = tf.tile(input = [self._t_max_len], multiples = [batch_size])
            # Greedy translation starts every sentence from '<start>'.
            trans_tokens = tf.tile(input = [self._t_dic.get('<start>')], multiples = [batch_size])

        with tf.variable_scope('decoder'):
            dec_cell = tf.contrib.rnn.BasicRNNCell(num_units = dec_hdim, activation = tf.nn.tanh)

            # Applying attention-mechanism (masked to each source sentence's true length)
            attn = s2s.LuongAttention(num_units = dec_hdim,
                                      memory = enc_outputs,
                                      memory_sequence_length = self._s_len, dtype = tf.float32)
            attn_cell = s2s.AttentionWrapper(cell = dec_cell, attention_mechanism = attn)
            dec_initial_state = attn_cell.zero_state(batch_size = batch_size, dtype = tf.float32)
            output_layer = tf.layers.Dense(units = n_of_classes,
                                           kernel_initializer = \
                                           tf.contrib.layers.xavier_initializer(uniform = False))

            # Training and translation below share attn_cell and output_layer (same objects).
            with tf.variable_scope('training'):
                # Teacher forcing: the ground-truth previous token is fed at every step.
                tr_helper = s2s.TrainingHelper(inputs = t_batch,
                                               sequence_length = tr_tokens)
                tr_decoder = s2s.BasicDecoder(cell = attn_cell, helper = tr_helper,
                                              initial_state = dec_initial_state,
                                              output_layer = output_layer)
                self._tr_outputs, _, _ = s2s.dynamic_decode(decoder = tr_decoder,
                                                            impute_finished = True,
                                                            maximum_iterations = self._t_max_len)

            with tf.variable_scope('translation'):
                # Greedy decoding: the argmax token is embedded and fed back until '<end>'.
                trans_helper = s2s.GreedyEmbeddingHelper(embedding = t_embeddings,
                                                         start_tokens = trans_tokens,
                                                         end_token = self._t_dic.get('<end>'))
                trans_decoder = s2s.BasicDecoder(cell = attn_cell, helper = trans_helper,
                                                 initial_state = dec_initial_state,
                                                 output_layer = output_layer)
                self._trans_outputs, _, _ = s2s.dynamic_decode(decoder = trans_decoder,
                                                               impute_finished = True,
                                                               maximum_iterations = self._t_max_len * 2)

        with tf.variable_scope('seq2seq_loss'):
            # Weight 1 for positions before t_len and 0 afterwards, so padding does not count.
            masking = tf.sequence_mask(lengths = self._t_len,
                                       maxlen = self._t_max_len, dtype = tf.float32)
            self.__seq2seq_loss = s2s.sequence_loss(logits = self._tr_outputs.rnn_output,
                                                    targets = self._t_output_indices,
                                                    weights = masking)

    def translate(self, sess, s_len, s_indices):
        """Greedily translate a batch of padded source sentences.

        The given arrays are fed in place of the source tensors the graph was
        built on; the translation branch reads no target tensors.

        Returns:
            The ``sample_id`` array of predicted target-token indices.
        """
        feed_translation = {self._s_len : s_len, self._s_indices : s_indices}
        return sess.run(self._trans_outputs.sample_id, feed_dict = feed_translation)

    @property
    def loss(self):
        """Masked sequence cross-entropy of the teacher-forced training decoder."""
        return self.__seq2seq_loss
# hyper-parameter#
lr = .003
epochs = 500
batch_size = 2
# Mini-batches per epoch. `Dataset.batch` below is called without drop_remainder, so
# a trailing partial batch is still produced: use ceiling division, not floor.
n_examples = np.shape(X_indices)[0]
total_step = (n_examples + batch_size - 1) // batch_size
print(total_step)
2
## create data pipeline with tf.data
# One element per sentence pair. Elements are shuffled with a 20-element buffer (larger than
# this toy corpus) and grouped into mini-batches.
tr_dataset = (tf.data.Dataset
              .from_tensor_slices((X_length, X_indices, y_length, y_input_indices, y_target_indices))
              .shuffle(buffer_size = 20)
              .batch(batch_size = batch_size))
# Initializable so the training loop can restart it at the beginning of every epoch.
tr_iterator = tr_dataset.make_initializable_iterator()
print(tr_dataset)
<BatchDataset shapes: ((?,), (?, 10), (?,), (?, 12), (?, 12)), types: (tf.int32, tf.int32, tf.int32, tf.int32, tf.int32)>
# Tensors for one mini-batch; the model graph is wired directly to them.
(X_length_mb, X_indices_mb, y_length_mb,
 y_input_indices_mb, y_target_indices_mb) = tr_iterator.get_next()
sim_nmt = SimpleNMT(
    s_len = X_length_mb,
    s_indices = X_indices_mb,
    t_len = y_length_mb,
    t_input_indices = y_input_indices_mb,
    t_output_indices = y_target_indices_mb,
)
WARNING:tensorflow:From /usr/local/var/pyenv/versions/3.6.5/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py:430: calling reverse_sequence (from tensorflow.python.ops.array_ops) with seq_dim is deprecated and will be removed in a future version. Instructions for updating: seq_dim is deprecated, use seq_axis instead WARNING:tensorflow:From /usr/local/var/pyenv/versions/3.6.5/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:454: calling reverse_sequence (from tensorflow.python.ops.array_ops) with batch_dim is deprecated and will be removed in a future version. Instructions for updating: batch_dim is deprecated, use batch_axis instead
## create training op
opt = tf.train.AdamOptimizer(learning_rate = lr)
training_op = opt.minimize(loss = sim_nmt.loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

tr_loss_hist = []
for epoch in range(epochs):
    # Restart the iterator so every epoch walks through the whole dataset once.
    sess.run(tr_iterator.initializer)
    avg_tr_loss, tr_step = 0, 0
    while True:
        try:
            _, tr_loss = sess.run(fetches = [training_op, sim_nmt.loss])
        except tf.errors.OutOfRangeError:
            break  # iterator exhausted -> end of epoch
        avg_tr_loss += tr_loss
        tr_step += 1
    avg_tr_loss /= tr_step  # mean mini-batch loss over the epoch
    tr_loss_hist.append(avg_tr_loss)

    if (epoch + 1) % 100 == 0:
        print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))
epoch : 100, tr_loss : 0.215 epoch : 200, tr_loss : 0.056 epoch : 300, tr_loss : 0.026 epoch : 400, tr_loss : 0.015 epoch : 500, tr_loss : 0.010
# Greedy-decode the whole (unshuffled) training corpus with the trained model.
yhat = sim_nmt.translate(sess, X_length, X_indices)
yhat
array([[ 4, 7, 3, 2, 0], [12, 6, 10, 2, 0], [12, 5, 11, 13, 2], [12, 6, 9, 8, 2]], dtype=int32)
# Original (reference) sentences: map each target index back to its token.
originals = [[target_idx_dic.get(idx) for idx in row] for row in y_target_indices]
for original in originals:
    print(original)
['나는', '배가', '고프다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'] ['텐서플로우는', '매우', '어렵다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'] ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'] ['텐서플로우는', '매우', '빠르게', '변화한다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
# Model translations: map each predicted index back to its Korean token.
translations = [[target_idx_dic.get(idx) for idx in row] for row in yhat]
for translation in translations:
    print(translation)
['나는', '배가', '고프다', '<end>', '<pad>'] ['텐서플로우는', '매우', '어렵다', '<end>', '<pad>'] ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다', '<end>'] ['텐서플로우는', '매우', '빠르게', '변화한다', '<end>']