# after https://github.com/hans/ipython-notebooks/blob/master/tf/TF%20tutorial.ipynb
import numpy as np
import tensorflow as tf
# %matplotlib inline   # IPython magic from the original notebook; it is a
#                      # syntax error in plain Python, so kept only as a comment.
import matplotlib.pyplot as plt
import tempfile

# Scratch directory for TensorBoard event files (a second one is created
# again further down, just before the SummaryWriter).
logdir = tempfile.mkdtemp()
print(logdir)
/tmp/tmpmn29ne8n
# Interactive session: lets tensors be evaluated directly while the graph is built.
sess = tf.InteractiveSession()

# Model hyperparameters.
seq_length = 5 # number of timesteps
batch_size = 64
vocab_size = 7
embedding_size = 50
state_size = 100

# The seq2seq API is time-major: it takes a Python list with one
# (batch,)-shaped int32 placeholder per timestep.
enc_inp = []
for step in range(seq_length):
    enc_inp.append(tf.placeholder(tf.int32, shape=(None,), name="inp%i" % step))
labels = []
for step in range(seq_length):
    labels.append(tf.placeholder(tf.int32, shape=(None,), name="labels%i" % step))
enc_inp
[<tf.Tensor 'inp0:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp1:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp2:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp3:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp4:0' shape=(?,) dtype=int32>]
# Uniform per-timestep weights for the sequence loss, one tensor per label step.
weights = []
for step_labels in labels:
    weights.append(tf.ones_like(step_labels, dtype=tf.float32))
weights
[<tf.Tensor 'ones_like:0' shape=(?,) dtype=float32>, <tf.Tensor 'ones_like_1:0' shape=(?,) dtype=float32>, <tf.Tensor 'ones_like_2:0' shape=(?,) dtype=float32>, <tf.Tensor 'ones_like_3:0' shape=(?,) dtype=float32>, <tf.Tensor 'ones_like_4:0' shape=(?,) dtype=float32>]
# Decoder input: prepend a "GO" token (all zeros) and drop the final encoder
# symbol, i.e. the encoder input shifted right by one step.
go_step = tf.zeros_like(enc_inp[0], dtype=np.int32, name="GO")
dec_inp = [go_step] + enc_inp[:-1]
dec_inp
[<tf.Tensor 'GO:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp0:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp1:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp2:0' shape=(?,) dtype=int32>, <tf.Tensor 'inp3:0' shape=(?,) dtype=int32>]
# Initial memory value for recurrence.
# Zero hidden state: one (state_size,) row per batch element.
# NOTE(review): prev_mem is never referenced again in this file --
# embedding_rnn_seq2seq builds its own zero initial state -- so this appears
# to be kept purely for illustration.
prev_mem = tf.zeros((batch_size, state_size))
prev_mem
<tf.Tensor 'zeros:0' shape=(64, 100) dtype=float32>
We can use different kinds of RNN cells for seq2seq.
# A single GRU cell with `state_size` units; other tf.nn.rnn_cell variants
# (LSTM, basic RNN) would slot in here as well.
cell = tf.nn.rnn_cell.GRUCell(state_size)
tf.nn.seq2seq.embedding_rnn_seq2seq:
Embedding RNN sequence-to-sequence model.
This model first embeds encoder_inputs by a newly created embedding (of shape [num_encoder_symbols x input_size]). Then it runs an RNN to encode embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs by another newly created embedding (of shape [num_decoder_symbols x input_size]). Then it runs RNN decoder, initialized with the last encoder state, on embedded decoder_inputs.
# inputs will be embedded, so have to specify maximum number of symbols that can appear (vocab_size)
# embedding_rnn_seq2seq embeds enc_inp, encodes it with `cell`, then decodes
# dec_inp starting from the final encoder state. num_encoder_symbols and
# num_decoder_symbols are both vocab_size since input and output share one
# vocabulary. Yields one (batch, vocab_size) logits tensor per timestep plus
# the final decoder state (batch, state_size).
dec_outputs, dec_state = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp, cell, vocab_size, vocab_size, embedding_size)
dec_outputs, dec_state
([<tf.Tensor 'embedding_rnn_seq2seq/embedding_rnn_decoder/rnn_decoder/OutputProjectionWrapper/add:0' shape=(?, 7) dtype=float32>, <tf.Tensor 'embedding_rnn_seq2seq/embedding_rnn_decoder/rnn_decoder/OutputProjectionWrapper_1/add:0' shape=(?, 7) dtype=float32>, <tf.Tensor 'embedding_rnn_seq2seq/embedding_rnn_decoder/rnn_decoder/OutputProjectionWrapper_2/add:0' shape=(?, 7) dtype=float32>, <tf.Tensor 'embedding_rnn_seq2seq/embedding_rnn_decoder/rnn_decoder/OutputProjectionWrapper_3/add:0' shape=(?, 7) dtype=float32>, <tf.Tensor 'embedding_rnn_seq2seq/embedding_rnn_decoder/rnn_decoder/OutputProjectionWrapper_4/add:0' shape=(?, 7) dtype=float32>], <tf.Tensor 'embedding_rnn_seq2seq/embedding_rnn_decoder/rnn_decoder/GRUCell_4/add:0' shape=(?, 100) dtype=float32>)
# Weighted cross-entropy loss averaged over timesteps and batch.
# NOTE(review): the fourth positional argument follows the early
# tf.nn.seq2seq.sequence_loss signature (num_decoder_symbols); later TF
# versions dropped that parameter -- confirm it matches this TF release.
loss = tf.nn.seq2seq.sequence_loss(dec_outputs, labels, weights, vocab_size)
# Register the scalar loss for TensorBoard.
tf.scalar_summary("loss", loss)
<tf.Tensor 'ScalarSummary:0' shape=() dtype=string>
# L2 norm of one row of the final decoder state.
# NOTE(review): dec_state is the final state with shape (batch, state_size),
# so dec_state[1] selects batch element 1, not timestep 1 -- the summary tag
# "magnitude at t=1" looks misleading; confirm the intent.
magnitude = tf.sqrt(tf.reduce_sum(tf.square(dec_state[1])))
tf.scalar_summary("magnitude at t=1", magnitude)
<tf.Tensor 'ScalarSummary_1:0' shape=() dtype=string>
# One op that evaluates every summary registered above (loss + magnitude).
summary_op = tf.merge_all_summaries()
# Momentum SGD on the sequence loss.
learning_rate = 0.05
momentum = 0.9
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
train_op = optimizer.minimize(loss)
# Fresh event-log directory for this run; note the logdir created at the top
# of the file is left unused.
logdir = tempfile.mkdtemp()
print(logdir)
summary_writer = tf.train.SummaryWriter(logdir, sess.graph)
/tmp/tmpo_41xu3j
# Allocate and initialize all graph variables (embeddings, GRU and projection weights).
sess.run(tf.initialize_all_variables())
Generate input:
# Each training example is seq_length distinct symbols drawn from the
# vocabulary (replace=False, so vocab_size >= seq_length must hold).
X = []
for _ in range(batch_size):
    X.append(np.random.choice(vocab_size, size=(seq_length,), replace=False))
Y = list(X)  # autoencoding task: targets are the inputs (shallow copy)
X[:5]
[array([3, 4, 2, 6, 5]), array([5, 3, 4, 6, 1]), array([4, 2, 6, 5, 1]), array([2, 4, 5, 6, 1]), array([5, 4, 1, 2, 3])]
# Dimshuffle to time-major (seq_length, batch_size): row t holds every batch
# element's symbol at timestep t, matching the per-step placeholders.
X = np.asarray(X).T
Y = np.asarray(Y).T
X[:5]
array([[3, 5, 4, 2, 5, 3, 1, 2, 3, 3, 0, 6, 4, 0, 5, 1, 4, 4, 1, 1, 0, 3, 4, 3, 6, 4, 4, 0, 0, 2, 6, 5, 3, 0, 5, 2, 4, 5, 1, 3, 5, 6, 2, 0, 2, 3, 0, 6, 2, 6, 1, 1, 4, 5, 4, 0, 0, 2, 3, 0, 5, 2, 1, 0], [4, 3, 2, 4, 4, 1, 2, 3, 0, 1, 1, 4, 5, 2, 0, 5, 0, 2, 4, 6, 4, 5, 6, 1, 2, 2, 3, 1, 5, 5, 2, 2, 5, 3, 6, 1, 5, 3, 5, 1, 0, 1, 1, 5, 5, 0, 5, 2, 5, 4, 5, 6, 2, 6, 1, 2, 2, 4, 0, 2, 6, 1, 3, 3], [2, 4, 6, 5, 1, 0, 3, 1, 1, 2, 4, 0, 1, 4, 1, 6, 1, 3, 2, 5, 1, 4, 3, 2, 1, 3, 2, 3, 6, 1, 1, 0, 4, 4, 2, 6, 1, 4, 0, 2, 1, 0, 3, 4, 3, 6, 4, 4, 6, 5, 4, 5, 1, 4, 3, 5, 5, 0, 6, 1, 3, 5, 4, 4], [6, 6, 5, 6, 2, 6, 5, 5, 6, 5, 5, 2, 6, 6, 4, 4, 3, 1, 6, 2, 5, 6, 0, 5, 0, 6, 0, 2, 4, 3, 3, 6, 2, 2, 3, 3, 3, 1, 6, 5, 4, 2, 4, 3, 0, 2, 2, 5, 3, 0, 0, 4, 6, 3, 2, 6, 3, 5, 2, 5, 1, 3, 2, 1], [5, 1, 1, 1, 3, 4, 6, 4, 5, 4, 6, 1, 0, 3, 2, 0, 6, 6, 0, 4, 2, 2, 2, 0, 3, 0, 6, 4, 1, 0, 4, 4, 1, 5, 1, 0, 0, 0, 4, 4, 2, 3, 0, 1, 1, 4, 6, 1, 0, 2, 3, 3, 5, 1, 5, 3, 1, 1, 1, 4, 4, 6, 6, 5]])
# Preview the per-timestep rows that will be fed (one (batch_size,) array per step).
[X[t] for t in range(seq_length)]
[array([3, 5, 4, 2, 5, 3, 1, 2, 3, 3, 0, 6, 4, 0, 5, 1, 4, 4, 1, 1, 0, 3, 4, 3, 6, 4, 4, 0, 0, 2, 6, 5, 3, 0, 5, 2, 4, 5, 1, 3, 5, 6, 2, 0, 2, 3, 0, 6, 2, 6, 1, 1, 4, 5, 4, 0, 0, 2, 3, 0, 5, 2, 1, 0]), array([4, 3, 2, 4, 4, 1, 2, 3, 0, 1, 1, 4, 5, 2, 0, 5, 0, 2, 4, 6, 4, 5, 6, 1, 2, 2, 3, 1, 5, 5, 2, 2, 5, 3, 6, 1, 5, 3, 5, 1, 0, 1, 1, 5, 5, 0, 5, 2, 5, 4, 5, 6, 2, 6, 1, 2, 2, 4, 0, 2, 6, 1, 3, 3]), array([2, 4, 6, 5, 1, 0, 3, 1, 1, 2, 4, 0, 1, 4, 1, 6, 1, 3, 2, 5, 1, 4, 3, 2, 1, 3, 2, 3, 6, 1, 1, 0, 4, 4, 2, 6, 1, 4, 0, 2, 1, 0, 3, 4, 3, 6, 4, 4, 6, 5, 4, 5, 1, 4, 3, 5, 5, 0, 6, 1, 3, 5, 4, 4]), array([6, 6, 5, 6, 2, 6, 5, 5, 6, 5, 5, 2, 6, 6, 4, 4, 3, 1, 6, 2, 5, 6, 0, 5, 0, 6, 0, 2, 4, 3, 3, 6, 2, 2, 3, 3, 3, 1, 6, 5, 4, 2, 4, 3, 0, 2, 2, 5, 3, 0, 0, 4, 6, 3, 2, 6, 3, 5, 2, 5, 1, 3, 2, 1]), array([5, 1, 1, 1, 3, 4, 6, 4, 5, 4, 6, 1, 0, 3, 2, 0, 6, 6, 0, 4, 2, 2, 2, 0, 3, 0, 6, 4, 1, 0, 4, 4, 1, 5, 1, 0, 0, 0, 4, 4, 2, 3, 0, 1, 1, 4, 6, 1, 0, 2, 3, 3, 5, 1, 5, 3, 1, 1, 1, 4, 4, 6, 6, 5])]
Feed input:
# Map each per-timestep placeholder to the matching row of the data arrays.
feed_dict = {}
for step in range(seq_length):
    feed_dict[enc_inp[step]] = X[step]
for step in range(seq_length):
    feed_dict[labels[step]] = Y[step]
One training step:
# One training step: run the optimizer and fetch the loss plus serialized summaries.
_, loss_t, summary = sess.run([train_op, loss, summary_op], feed_dict)
loss_t, summary
(1.9546013, b'\n\x0b\n\x04loss\x15`0\xfa?\n\x17\n\x10magnitude at t=1\x15/\xc5\x8a?')
# A small fresh batch of 10 sequences to inspect the model's predictions.
X_batch = []
for _ in range(10):
    X_batch.append(np.random.choice(vocab_size, size=(seq_length,), replace=False))
X_batch
[array([5, 0, 4, 2, 1]), array([1, 2, 5, 3, 0]), array([3, 0, 5, 1, 4]), array([6, 3, 5, 2, 1]), array([6, 4, 1, 5, 2]), array([1, 5, 3, 0, 6]), array([4, 6, 0, 2, 3]), array([6, 4, 0, 5, 1]), array([0, 2, 1, 6, 4]), array([4, 3, 1, 5, 0])]
# Time-major layout matching the placeholders: shape (seq_length, 10).
X_batch = np.asarray(X_batch).T
X_batch
array([[5, 1, 3, 6, 6, 1, 4, 6, 0, 4], [0, 2, 0, 3, 4, 5, 6, 4, 2, 3], [4, 5, 5, 5, 1, 3, 0, 0, 1, 1], [2, 3, 1, 2, 5, 0, 2, 5, 6, 5], [1, 0, 4, 1, 2, 6, 3, 1, 4, 0]])
# Inference only: feed encoder inputs (no labels needed) and fetch the
# per-timestep decoder logits.
feed_dict = {}
for step in range(seq_length):
    feed_dict[enc_inp[step]] = X_batch[step]
dec_outputs_batch = sess.run(dec_outputs, feed_dict)
dec_outputs_batch
[array([[ 0.17844655, 0.11353172, -0.18144946, 0.01085662, 0.2627669 , -0.19861746, -0.08941379], [ 0.21098135, -0.08187042, 0.01305156, -0.05247057, 0.07250485, -0.23508257, -0.2811504 ], [ 0.13883567, 0.09521072, -0.11679886, 0.00680578, 0.16948652, -0.12717277, -0.0371674 ], [ 0.21993272, 0.00940304, -0.16575341, -0.16260277, 0.1767858 , -0.18685289, -0.08371995], [ 0.10060885, 0.01945375, -0.21006839, -0.14021412, -0.01953002, -0.18472196, -0.08986312], [ 0.17035802, -0.07102758, 0.0302645 , -0.25397092, -0.00265401, -0.12531452, -0.2037717 ], [ 0.24822755, -0.21595982, 0.00301015, -0.27651039, 0.10579222, -0.24226177, -0.21743613], [ 0.27941915, 0.13427791, -0.13733003, -0.05824878, 0.17526907, -0.10152674, -0.1523411 ], [ 0.04663948, -0.0393581 , -0.08266062, -0.11853192, 0.05972272, -0.11571028, -0.0639744 ], [ 0.18015689, -0.02721097, -0.00587304, -0.02247566, 0.12168466, -0.14741473, -0.15037411]], dtype=float32), array([[ 0.10395906, 0.18637255, -0.11520418, 0.0176604 , 0.10534781, -0.07860883, 0.01924768], [ 0.15846533, -0.07039611, -0.04755091, -0.07138188, -0.02692247, -0.18470216, -0.17578526], [ 0.16245708, 0.15115945, -0.045459 , 0.0340042 , 0.1339983 , -0.09590427, 0.03113824], [ 0.10051248, -0.02949939, -0.06673641, -0.13614778, 0.16801459, -0.17466651, -0.063809 ], [ 0.03787023, -0.01964972, -0.11155554, -0.09965882, 0.03461789, -0.16733888, -0.08309884], [ 0.09131334, -0.06725299, -0.04503515, -0.23164834, -0.08556231, -0.11044674, -0.13894588], [ 0.16687544, -0.14045513, -0.03331609, -0.10184428, 0.17324263, -0.1585528 , -0.05186675], [ 0.18770847, 0.07122251, -0.04060577, -0.03737558, 0.17788059, -0.08480155, -0.10424838], [ 0.1132549 , -0.00362585, 0.04683149, 0.01941368, 0.17256692, 0.04148631, -0.14065626], [ 0.12920758, 0.01740728, -0.02748516, 0.07680587, 0.19532834, -0.09230802, -0.02468909]], dtype=float32), array([[ 1.33050695e-01, 1.48067713e-01, -9.75858513e-03, 1.31673515e-01, 1.92209393e-01, 6.93723485e-02, -4.91486937e-02], [ 
-5.59313111e-02, 5.76260835e-02, -6.22149697e-03, -5.92892915e-02, -6.83557987e-02, -1.39937550e-01, -1.34125471e-01], [ 2.13978752e-01, 1.23260379e-01, 4.88165878e-02, 1.21520363e-01, 2.33814687e-01, 5.97248301e-02, -5.82022667e-02], [ 9.52694789e-02, 5.56701086e-02, -4.42465693e-02, -8.04473609e-02, 9.91110578e-02, -1.50191337e-01, 8.31928663e-03], [ -1.48642331e-03, 1.77686252e-02, -1.26366407e-01, -1.18158665e-02, 1.07669294e-01, -1.01451002e-01, 1.92329716e-02], [ 4.51667160e-02, 3.05926315e-02, -4.32675816e-02, -1.06809497e-01, -1.46296367e-01, -6.12212159e-02, -2.95996219e-02], [ 1.39960676e-01, -1.28353313e-01, 3.92415486e-02, -3.38840075e-02, 1.64977491e-01, -1.28964454e-01, -8.23383965e-03], [ 1.13442115e-01, 8.32194015e-02, -5.56406938e-02, 2.53340248e-02, 2.11977690e-01, -3.01985200e-02, 7.91771710e-03], [ -1.17359094e-01, 9.14113894e-02, 1.13254704e-01, -3.58526148e-02, 9.18745771e-02, 4.33049612e-02, -9.80217829e-02], [ 1.86070725e-01, 8.21213499e-02, -1.28050480e-04, 1.13477461e-01, 1.50830865e-01, -8.51025060e-02, 3.86973992e-02]], dtype=float32), array([[ 0.04973944, 0.10831716, -0.00375871, 0.17339239, 0.22739363, 0.06614812, 0.05624339], [-0.04072684, 0.14192866, -0.01291926, -0.00893061, -0.14943664, -0.08095054, -0.00498922], [ 0.1481124 , 0.17895575, 0.07431188, 0.10259738, 0.09361922, 0.08516876, 0.05082116], [ 0.05775284, 0.13802256, -0.04122018, -0.05294784, -0.02876745, -0.0752691 , 0.09462182], [ 0.00873535, -0.01149929, -0.12302497, -0.06627826, -0.02505408, -0.05816741, 0.03803293], [ 0.05197376, 0.07711638, -0.02642946, -0.01482341, -0.1043431 , -0.06400553, 0.02498893], [ 0.2129983 , -0.06768389, 0.08495383, 0.10012042, 0.22794855, 0.02708369, -0.08005156], [ 0.19821872, 0.08864162, 0.02410168, 0.11796612, 0.26989272, 0.10462284, -0.07330685], [-0.09578598, 0.02733124, 0.08310014, -0.09925178, -0.03786328, 0.05727703, -0.04634201], [ 0.1499618 , 0.02167371, -0.02839148, 0.04336146, 0.03068283, -0.03333036, 0.056478 ]], dtype=float32), 
array([[-0.13720484, 0.17640325, 0.06879292, 0.07539334, 0.12553324, 0.05548202, 0.04008906], [ 0.00208345, 0.16644646, 0.00865782, 0.03311532, -0.12878664, -0.0752397 , 0.04306988], [ 0.07093671, 0.08025096, 0.04515825, 0.02599708, -0.02427215, 0.0889031 , 0.05317469], [-0.17283508, 0.18637198, 0.01248625, -0.08382478, -0.0740777 , -0.05373616, 0.04355513], [ 0.01205936, 0.07759147, -0.10704872, 0.00033424, -0.11591723, -0.00552052, 0.11702637], [ 0.09295554, 0.04739372, 0.02949173, 0.11420977, 0.05874372, 0.07074399, -0.06921744], [-0.01032462, 0.05197939, 0.12150411, 0.03734748, 0.10758823, 0.03083727, -0.04855074], [ 0.16330148, 0.16301754, 0.03916627, 0.10394222, 0.10825945, 0.11118951, 0.04324638], [-0.07732442, -0.03749544, 0.13659573, -0.06185998, 0.01066446, 0.02512036, -0.01716892], [ 0.10852097, 0.08285542, -0.02328821, 0.10114764, -0.06120555, 0.02239372, 0.13511136]], dtype=float32)]
# Greedy decode: argmax over the vocab axis gives each sequence's predicted
# symbol at every timestep.
[logits_t.argmax(axis=1) for logits_t in dec_outputs_batch]
[array([4, 0, 4, 0, 0, 0, 0, 0, 4, 0]), array([1, 0, 0, 4, 0, 0, 4, 0, 4, 4]), array([4, 1, 4, 4, 4, 0, 4, 4, 2, 0]), array([4, 1, 1, 1, 6, 1, 4, 4, 2, 0]), array([1, 1, 5, 1, 6, 3, 2, 0, 2, 6])]
def train_batch(batch_size):
    """Run one optimizer step on a freshly sampled random batch.

    Args:
        batch_size: number of random sequences to sample (shadows the
            module-level constant of the same name).

    Returns:
        (loss_t, summary): the scalar loss value and the serialized summary
        protobuf produced by this step.
    """
    # Sample sequences of distinct symbols; the task is to reproduce the input.
    X = [np.random.choice(vocab_size, size=(seq_length,), replace=False)
         for _ in range(batch_size)]
    Y = X[:]

    # Dimshuffle to seq_len * batch_size (time-major, as the placeholders expect)
    X = np.array(X).T
    Y = np.array(Y).T

    feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}
    feed_dict.update({labels[t]: Y[t] for t in range(seq_length)})

    _, loss_t, summary = sess.run([train_op, loss, summary_op], feed_dict)
    return loss_t, summary
# Train for 500 steps, logging one summary per step. The flush after the loop
# pushes any buffered events to the TensorBoard logdir.
for t in range(500):
    loss_t, summary = train_batch(batch_size)
    summary_writer.add_summary(summary, t)
summary_writer.flush()