#!/usr/bin/env python
# coding: utf-8

# In[1]:

# after https://github.com/hans/ipython-notebooks/blob/master/tf/TF%20tutorial.ipynb

# In[2]:

import numpy as np
import tensorflow as tf
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt
import tempfile

logdir = tempfile.mkdtemp()
print(logdir)

# In[3]:

sess = tf.InteractiveSession()

# In[4]:

seq_length = 5  # number of timesteps
batch_size = 64
vocab_size = 7
embedding_size = 50
state_size = 100

# In[5]:

# Inputs are fed as a list of placeholders, one per timestep (length seq_length).
enc_inp = [tf.placeholder(tf.int32, shape=(None,), name="inp%i" % t)
           for t in range(seq_length)]
labels = [tf.placeholder(tf.int32, shape=(None,), name="labels%i" % t)
          for t in range(seq_length)]
enc_inp

# In[6]:

weights = [tf.ones_like(labels_t, dtype=tf.float32) for labels_t in labels]
weights

# In[7]:

# Decoder input: prepend a "GO" token and drop the final token of the encoder input.
dec_inp = ([tf.zeros_like(enc_inp[0], dtype=np.int32, name="GO")] + enc_inp[:-1])
dec_inp

# In[8]:

# Initial memory value for recurrence.
prev_mem = tf.zeros((batch_size, state_size))
prev_mem

# We can use different kinds of RNN cells for seq2seq.

# In[30]:

cell = tf.nn.rnn_cell.GRUCell(state_size)

# tf.nn.seq2seq.embedding_rnn_seq2seq:
#
# > Embedding RNN sequence-to-sequence model.
# This model first embeds encoder_inputs by a newly created embedding (of shape
# [num_encoder_symbols x input_size]). Then it runs an RNN to encode
# embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
# by another newly created embedding (of shape [num_decoder_symbols x
# input_size]). Then it runs RNN decoder, initialized with the last
# encoder state, on embedded decoder_inputs.

# In[10]:

# Inputs will be embedded, so we have to specify the maximum number of symbols
# that can appear (vocab_size).
dec_outputs, dec_state = tf.nn.seq2seq.embedding_rnn_seq2seq(
    enc_inp, dec_inp, cell, vocab_size, vocab_size, embedding_size)
dec_outputs, dec_state

# In[11]:

loss = tf.nn.seq2seq.sequence_loss(dec_outputs, labels, weights, vocab_size)

# In[12]:

tf.scalar_summary("loss", loss)

# In[13]:

magnitude = tf.sqrt(tf.reduce_sum(tf.square(dec_state[1])))
tf.scalar_summary("magnitude at t=1", magnitude)

# In[14]:

summary_op = tf.merge_all_summaries()

# In[15]:

learning_rate = 0.05
momentum = 0.9
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
train_op = optimizer.minimize(loss)

# In[16]:

logdir = tempfile.mkdtemp()
print(logdir)
summary_writer = tf.train.SummaryWriter(logdir, sess.graph)

# In[17]:

sess.run(tf.initialize_all_variables())

# ### Train network, step-by-step

# Generate input:

# In[18]:

X = [np.random.choice(vocab_size, size=(seq_length,), replace=False)
     for _ in range(batch_size)]
Y = X[:]
X[:5]

# In[19]:

X = np.array(X).T
Y = np.array(Y).T
X[:5]

# In[20]:

[X[t] for t in range(seq_length)]

# Feed input:

# In[21]:

feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}

# In[22]:

feed_dict.update({labels[t]: Y[t] for t in range(seq_length)})

# One training step:

# In[23]:

_, loss_t, summary = sess.run([train_op, loss, summary_op], feed_dict)
loss_t, summary

# ### Test case

# In[24]:

X_batch = [np.random.choice(vocab_size, size=(seq_length,), replace=False)
           for _ in range(10)]
X_batch

# In[25]:

X_batch = np.array(X_batch).T
X_batch

# In[26]:

feed_dict = {enc_inp[t]: X_batch[t] for t in range(seq_length)}
dec_outputs_batch = sess.run(dec_outputs, feed_dict)
dec_outputs_batch

# In[27]:

[logits_t.argmax(axis=1) for logits_t in dec_outputs_batch]
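# The argmax over the vocabulary dimension of each timestep's logits gives the
# decoded token sequence. As a quick sanity check (a sketch added here, not part
# of the original notebook), we can compare these greedy predictions with the
# encoder inputs; after only a single training step most tokens will still be wrong.

predictions = np.array([logits_t.argmax(axis=1) for logits_t in dec_outputs_batch])
print("inputs:\n", X_batch.T)           # back to batch-major for readability
print("predictions:\n", predictions.T)
print("per-token accuracy: %.3f" % (predictions == X_batch).mean())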
# ### Training function

# In[28]:

def train_batch(batch_size):
    X = [np.random.choice(vocab_size, size=(seq_length,), replace=False)
         for _ in range(batch_size)]
    Y = X[:]

    # Dimshuffle to seq_len * batch_size
    X = np.array(X).T
    Y = np.array(Y).T

    feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}
    feed_dict.update({labels[t]: Y[t] for t in range(seq_length)})

    _, loss_t, summary = sess.run([train_op, loss, summary_op], feed_dict)
    return loss_t, summary

# In[29]:

for t in range(500):
    loss_t, summary = train_batch(batch_size)
    summary_writer.add_summary(summary, t)
    summary_writer.flush()
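# After a few hundred batches the model should have learned to copy its input.
# A minimal evaluation sketch (not in the original notebook, using only names
# defined above): decode a fresh batch and compare the greedy argmax
# predictions with the inputs.

X_eval = np.array([np.random.choice(vocab_size, size=(seq_length,), replace=False)
                   for _ in range(10)]).T          # seq_len x batch
feed_dict = {enc_inp[t]: X_eval[t] for t in range(seq_length)}
outputs_eval = sess.run(dec_outputs, feed_dict)
predictions_eval = np.array([logits_t.argmax(axis=1) for logits_t in outputs_eval])
print("inputs:\n", X_eval.T)
print("predictions:\n", predictions_eval.T)
print("per-token accuracy: %.3f" % (predictions_eval == X_eval).mean())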