In [24]:
%matplotlib inline

import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
In [25]:
# Download and load the MNIST dataset

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data')
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [26]:
# Create model inputs

def model_inputs(discriminator_dim, generator_dim):
    # Placeholder for real, flattened images fed to the discriminator
    discriminator_inputs = tf.placeholder(
        tf.float32,
        shape=[None, discriminator_dim],
        name='discriminator_input'
    )
    # Placeholder for noise vectors fed to the generator
    generator_inputs = tf.placeholder(
        tf.float32,
        shape=[None, generator_dim],
        name='generator_input'
    )
    return discriminator_inputs, generator_inputs
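As a quick sanity check (a minimal sketch, assuming TensorFlow 1.x as used throughout), the leading None dimension lets the same placeholders accept any batch size, so training batches of 100 and sampling batches of 16 both work:

d_in, g_in = model_inputs(784, 100)
with tf.Session() as s:
    fed = s.run(d_in, feed_dict={d_in: np.zeros((5, 784), np.float32)})
    print(fed.shape)  # (5, 784) -- the batch size is not fixed by the graph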
In [27]:
# Create generator

def generator(_input, out_dim, num_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('generator', reuse=reuse):
        # Hidden layer with a leaky ReLU activation
        hidden_layer = tf.layers.dense(_input, num_units)
        leaky_relu = tf.maximum(hidden_layer * alpha, hidden_layer)
        
        # Output layer with tanh, so samples lie in [-1, 1] like the rescaled images
        logits = tf.layers.dense(leaky_relu, out_dim)
        return tf.tanh(logits)
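The tf.maximum(h * alpha, h) expression implements a leaky ReLU by hand (older TF 1.x releases had no built-in op for it): negative inputs pass through a small fraction alpha of the signal instead of being zeroed, which keeps gradients flowing. A NumPy sketch of the same element-wise rule:

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(np.maximum(0.01 * x, x))  # [-0.02  -0.005  0.     0.5    2.   ]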
In [28]:
# Create discriminator

def discriminator(_input, num_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('discriminator', reuse=reuse):
        # Hidden layer with a leaky ReLU activation
        hidden_layer = tf.layers.dense(_input, num_units)
        leaky_relu = tf.maximum(hidden_layer * alpha, hidden_layer)
        
        # Return both the probability and the raw logits; the losses use the logits
        logits = tf.layers.dense(leaky_relu, 1)
        return tf.sigmoid(logits), logits
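The discriminator returns the raw logits alongside the sigmoid probability because the losses below use tf.nn.sigmoid_cross_entropy_with_logits, which folds the sigmoid into the cross-entropy in a numerically stable form, max(x, 0) - x*z + log(1 + exp(-|x|)). A NumPy sketch of the equivalence:

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x, z = 2.0, 1.0  # logit, label
naive = -z * np.log(sigmoid(x)) - (1 - z) * np.log(1 - sigmoid(x))
stable = max(x, 0) - x * z + np.log1p(np.exp(-abs(x)))
print(naive, stable)  # both ~0.12693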
In [41]:
# Set hyperparameters

discriminator_input_size = 784  # 28x28 MNIST images, flattened
generator_input_size = 100      # size of the noise vector fed to the generator

generator_hidden_size = discriminator_hidden_size = 128

alpha = 0.01  # leaky ReLU slope

learning_rate = 0.002

smooth = 0.1  # label smoothing applied to the real-image labels
In [42]:
# Building the network

tf.reset_default_graph()
## Create input placeholders
discriminator_input, generator_input = model_inputs(
    discriminator_input_size, generator_input_size)


## Build the model

generator_model = generator(generator_input, discriminator_input_size,
                            num_units=generator_hidden_size, alpha=alpha)

discriminator_model_real, discriminator_logits_real = discriminator(
    discriminator_input, num_units=discriminator_hidden_size, alpha=alpha)
discriminator_model_fake, discriminator_logits_fake = discriminator(
    generator_model, num_units=discriminator_hidden_size, reuse=True, alpha=alpha)
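A quick way to confirm that reuse=True shares weights rather than creating a second discriminator is to list the trainable variables under the 'discriminator' scope; only one set of dense-layer weights should appear despite the two calls (the exact names depend on TF's automatic layer naming):

print([v.name for v in tf.trainable_variables()
       if v.name.startswith('discriminator')])
# e.g. ['discriminator/dense/kernel:0', 'discriminator/dense/bias:0',
#       'discriminator/dense_1/kernel:0', 'discriminator/dense_1/bias:0']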
In [43]:
# Losses

## Discriminator loss real

discriminator_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=discriminator_logits_real,
    labels=tf.ones_like(discriminator_logits_real) * (1 - smooth)
)
discriminator_loss_real = tf.reduce_mean(discriminator_loss_real)

## Discriminator loss fake

discriminator_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=discriminator_logits_fake,
    labels=tf.zeros_like(discriminator_logits_fake)
)
discriminator_loss_fake = tf.reduce_mean(discriminator_loss_fake)

## Discriminator loss

discriminator_loss = discriminator_loss_real + discriminator_loss_fake

## Generator loss

generator_loss = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=discriminator_logits_fake,
    labels=tf.ones_like(discriminator_logits_fake)
)

generator_loss = tf.reduce_mean(generator_loss)
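The effect of the smoothed real labels is easy to see numerically: with hard labels the loss on a confidently-real prediction goes to zero, while with labels of 0.9 it bottoms out and grows again, penalizing an overconfident discriminator. A NumPy sketch:

logits = np.array([1.0, 3.0, 6.0])
p = 1.0 / (1.0 + np.exp(-logits))
hard = -np.log(p)                                # labels = 1.0
soft = -(0.9 * np.log(p) + 0.1 * np.log(1 - p))  # labels = 0.9
print(hard)  # ~[0.3133 0.0486 0.0025] -- vanishes as confidence grows
print(soft)  # ~[0.4133 0.3486 0.6025] -- rises again past the optimum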
In [47]:
# Optimizers

## Get the trainable variables to split into generator and discriminator vars
train_variables = tf.trainable_variables()

generator_variables = [var for var in train_variables if var.name.startswith('generator')]
discriminator_variables = [var for var in train_variables if var.name.startswith('discriminator')]

## Create optimizers with var lists

generator_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(
    generator_loss, var_list=generator_variables)

discriminator_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(
    discriminator_loss, var_list=discriminator_variables)
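The var_list argument is what lets each optimizer update only its own network: gradients still flow through the other network's ops, but its weights stay fixed. A minimal sketch in a throwaway graph (so it doesn't touch the GAN graph above):

g = tf.Graph()
with g.as_default():
    a = tf.get_variable('a', initializer=1.0)
    b = tf.get_variable('b', initializer=1.0)
    loss = tf.square(a) + tf.square(b)
    step = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=[a])
    with tf.Session() as s:
        s.run(tf.global_variables_initializer())
        s.run(step)
        print(s.run([a, b]))  # [0.8, 1.0] -- only a was updated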
In [48]:
!mkdir -p checkpoints
In [51]:
# Training

batch_size = 100
epochs = 100
samples = []
losses = []

# Only save generator variables
saver = tf.train.Saver(var_list=generator_variables)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for epoch in range(epochs):
        for i in range(mnist.train.num_examples // batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get the images and rescale from [0, 1] to [-1, 1]
            # to match the range of the generator's tanh output
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images * 2 - 1
            
            # Sample random noise for generator
            batch_generator = np.random.uniform(-1, 1, size=(batch_size, generator_input_size))
            
            # Run optimizers
            session.run(
                discriminator_optimizer,
                feed_dict={
                    discriminator_input: batch_images,
                    generator_input: batch_generator,
                }
            )
            
            session.run(
                generator_optimizer,
                feed_dict={generator_input: batch_generator}
            )
            
        # At the end of each epoch, get the losses and print them out
        train_loss_discriminator = session.run(
            discriminator_loss,
            feed_dict={
                discriminator_input: batch_images,
                generator_input: batch_generator,
            }
        )
        train_loss_generator = generator_loss.eval({generator_input: batch_generator})
        
        print("Epoch {}/{}...".format(epoch + 1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_discriminator),
              "Generator Loss: {:.4f}".format(train_loss_generator))
        
        # Save losses to view after training
        losses.append((train_loss_discriminator, train_loss_generator))
        
        # Sample from the generator as we train, for viewing afterwards.
        # Reuse the existing generator_model op rather than rebuilding it each epoch.
        sample_generator = np.random.uniform(-1, 1, size=(16, generator_input_size))
        generator_samples = session.run(
            generator_model,
            feed_dict={generator_input: sample_generator}
        )
        samples.append(generator_samples)
        saver.save(session, './checkpoints/generator.ckpt')
        
with open('train_samples.pkl', 'wb') as _file:
    pkl.dump(samples, _file)
    
Epoch 1/100... Discriminator Loss: 0.3610... Generator Loss: 3.7139
Epoch 2/100... Discriminator Loss: 0.3537... Generator Loss: 3.9365
Epoch 3/100... Discriminator Loss: 0.4888... Generator Loss: 3.7891
Epoch 4/100... Discriminator Loss: 0.4025... Generator Loss: 3.9308
Epoch 5/100... Discriminator Loss: 0.9150... Generator Loss: 3.6636
Epoch 6/100... Discriminator Loss: 1.5940... Generator Loss: 1.7054
Epoch 7/100... Discriminator Loss: 1.1450... Generator Loss: 5.1282
Epoch 8/100... Discriminator Loss: 1.2245... Generator Loss: 1.8742
Epoch 9/100... Discriminator Loss: 1.0641... Generator Loss: 1.2291
Epoch 10/100... Discriminator Loss: 1.0782... Generator Loss: 2.1048
Epoch 11/100... Discriminator Loss: 1.0230... Generator Loss: 2.1220
Epoch 12/100... Discriminator Loss: 0.8871... Generator Loss: 2.4581
Epoch 13/100... Discriminator Loss: 1.0302... Generator Loss: 2.2410
Epoch 14/100... Discriminator Loss: 1.3629... Generator Loss: 1.5072
Epoch 15/100... Discriminator Loss: 1.4571... Generator Loss: 2.5060
Epoch 16/100... Discriminator Loss: 1.3771... Generator Loss: 1.5331
Epoch 17/100... Discriminator Loss: 2.1457... Generator Loss: 0.8336
Epoch 18/100... Discriminator Loss: 0.8136... Generator Loss: 2.1608
Epoch 19/100... Discriminator Loss: 1.0390... Generator Loss: 1.8266
Epoch 20/100... Discriminator Loss: 1.2297... Generator Loss: 1.6069
Epoch 21/100... Discriminator Loss: 1.1496... Generator Loss: 1.4269
Epoch 22/100... Discriminator Loss: 0.8517... Generator Loss: 2.0660
Epoch 23/100... Discriminator Loss: 0.8255... Generator Loss: 1.9235
Epoch 24/100... Discriminator Loss: 1.1326... Generator Loss: 1.8619
Epoch 25/100... Discriminator Loss: 1.0850... Generator Loss: 2.3271
Epoch 26/100... Discriminator Loss: 0.9823... Generator Loss: 2.0439
Epoch 27/100... Discriminator Loss: 0.8183... Generator Loss: 2.3446
Epoch 28/100... Discriminator Loss: 0.9516... Generator Loss: 1.6764
Epoch 29/100... Discriminator Loss: 0.9301... Generator Loss: 2.0445
Epoch 30/100... Discriminator Loss: 0.8673... Generator Loss: 1.9762
Epoch 31/100... Discriminator Loss: 1.1893... Generator Loss: 1.4226
Epoch 32/100... Discriminator Loss: 0.8114... Generator Loss: 2.3016
Epoch 33/100... Discriminator Loss: 1.0492... Generator Loss: 1.8329
Epoch 34/100... Discriminator Loss: 0.9259... Generator Loss: 2.8954
Epoch 35/100... Discriminator Loss: 1.0282... Generator Loss: 2.0221
Epoch 36/100... Discriminator Loss: 0.8891... Generator Loss: 2.0999
Epoch 37/100... Discriminator Loss: 0.8232... Generator Loss: 2.6474
Epoch 38/100... Discriminator Loss: 1.0443... Generator Loss: 2.1272
Epoch 39/100... Discriminator Loss: 0.9505... Generator Loss: 2.1103
Epoch 40/100... Discriminator Loss: 0.7978... Generator Loss: 2.3140
Epoch 41/100... Discriminator Loss: 1.0642... Generator Loss: 1.9769
Epoch 42/100... Discriminator Loss: 0.8561... Generator Loss: 2.2247
Epoch 43/100... Discriminator Loss: 1.0182... Generator Loss: 2.0321
Epoch 44/100... Discriminator Loss: 0.9855... Generator Loss: 1.5972
Epoch 45/100... Discriminator Loss: 1.4146... Generator Loss: 1.1930
Epoch 46/100... Discriminator Loss: 1.2144... Generator Loss: 2.7811
Epoch 47/100... Discriminator Loss: 0.8786... Generator Loss: 2.1741
Epoch 48/100... Discriminator Loss: 1.0341... Generator Loss: 1.9743
Epoch 49/100... Discriminator Loss: 1.1616... Generator Loss: 1.8697
Epoch 50/100... Discriminator Loss: 0.9131... Generator Loss: 1.8982
Epoch 51/100... Discriminator Loss: 1.1300... Generator Loss: 1.8999
Epoch 52/100... Discriminator Loss: 1.0230... Generator Loss: 1.7355
Epoch 53/100... Discriminator Loss: 0.9176... Generator Loss: 1.6912
Epoch 54/100... Discriminator Loss: 0.8761... Generator Loss: 2.3600
Epoch 55/100... Discriminator Loss: 0.9307... Generator Loss: 1.7238
Epoch 56/100... Discriminator Loss: 0.7768... Generator Loss: 2.4752
Epoch 57/100... Discriminator Loss: 0.9234... Generator Loss: 2.0741
Epoch 58/100... Discriminator Loss: 0.9194... Generator Loss: 1.7478
Epoch 59/100... Discriminator Loss: 0.9222... Generator Loss: 1.8859
Epoch 60/100... Discriminator Loss: 1.0376... Generator Loss: 1.8256
Epoch 61/100... Discriminator Loss: 0.9389... Generator Loss: 2.2363
Epoch 62/100... Discriminator Loss: 1.0775... Generator Loss: 1.8134
Epoch 63/100... Discriminator Loss: 1.0164... Generator Loss: 1.4308
Epoch 64/100... Discriminator Loss: 0.9118... Generator Loss: 2.2420
Epoch 65/100... Discriminator Loss: 0.9385... Generator Loss: 1.8752
Epoch 66/100... Discriminator Loss: 0.8307... Generator Loss: 1.9120
Epoch 67/100... Discriminator Loss: 0.8071... Generator Loss: 2.5180
Epoch 68/100... Discriminator Loss: 0.9992... Generator Loss: 2.0578
Epoch 69/100... Discriminator Loss: 1.0193... Generator Loss: 2.0279
Epoch 70/100... Discriminator Loss: 0.8511... Generator Loss: 1.7309
Epoch 71/100... Discriminator Loss: 0.9783... Generator Loss: 1.8921
Epoch 72/100... Discriminator Loss: 0.9808... Generator Loss: 1.8034
Epoch 73/100... Discriminator Loss: 0.9443... Generator Loss: 1.7591
Epoch 74/100... Discriminator Loss: 0.9698... Generator Loss: 1.5257
Epoch 75/100... Discriminator Loss: 0.9707... Generator Loss: 1.6312
Epoch 76/100... Discriminator Loss: 0.8781... Generator Loss: 2.6984
Epoch 77/100... Discriminator Loss: 0.9168... Generator Loss: 1.9443
Epoch 78/100... Discriminator Loss: 0.9764... Generator Loss: 1.8984
Epoch 79/100... Discriminator Loss: 0.9344... Generator Loss: 1.8933
Epoch 80/100... Discriminator Loss: 1.1613... Generator Loss: 1.9193
Epoch 81/100... Discriminator Loss: 1.0227... Generator Loss: 1.7076
Epoch 82/100... Discriminator Loss: 0.8183... Generator Loss: 1.8793
Epoch 83/100... Discriminator Loss: 0.9112... Generator Loss: 1.8770
Epoch 84/100... Discriminator Loss: 0.9264... Generator Loss: 2.1068
Epoch 85/100... Discriminator Loss: 0.8174... Generator Loss: 2.0643
Epoch 86/100... Discriminator Loss: 0.8852... Generator Loss: 2.1594
Epoch 87/100... Discriminator Loss: 1.0170... Generator Loss: 1.5330
Epoch 88/100... Discriminator Loss: 0.9006... Generator Loss: 1.9815
Epoch 89/100... Discriminator Loss: 1.0992... Generator Loss: 1.5161
Epoch 90/100... Discriminator Loss: 1.0275... Generator Loss: 1.7076
Epoch 91/100... Discriminator Loss: 0.8108... Generator Loss: 1.6832
Epoch 92/100... Discriminator Loss: 0.9183... Generator Loss: 1.8923
Epoch 93/100... Discriminator Loss: 1.0284... Generator Loss: 2.6486
Epoch 94/100... Discriminator Loss: 0.8037... Generator Loss: 1.9870
Epoch 95/100... Discriminator Loss: 0.9714... Generator Loss: 1.8809
Epoch 96/100... Discriminator Loss: 0.9094... Generator Loss: 1.9679
Epoch 97/100... Discriminator Loss: 0.8892... Generator Loss: 1.7226
Epoch 98/100... Discriminator Loss: 1.0335... Generator Loss: 1.8607
Epoch 99/100... Discriminator Loss: 0.8608... Generator Loss: 1.9716
Epoch 100/100... Discriminator Loss: 0.9031... Generator Loss: 1.9182
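Note the pixel rescaling inside the loop: MNIST pixels arrive in [0, 1], and x * 2 - 1 maps them to [-1, 1] so real images live in the same range as the generator's tanh output. The inverse, (x + 1) / 2, recovers display-range values if needed:

pixels = np.array([0.0, 0.25, 1.0])
scaled = pixels * 2 - 1  # [-1.  -0.5  1. ]
print((scaled + 1) / 2)  # [0.   0.25 1.  ] -- round trip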
In [52]:
# Training loss

fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()
Out[52]:
[figure: training losses for discriminator and generator, plotted per epoch]
In [53]:
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
    
    return fig, axes
In [56]:
# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)

view_samples(-1, samples)
Out[56]:
[figure: a 4x4 grid of generated digit samples from the final epoch]
In [57]:
# Sample from every 10 epochs
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::len(samples) // rows], axes):
    for img, ax in zip(sample[::len(sample) // cols], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
In [60]:
# Sampling from the generator

saver = tf.train.Saver(var_list=generator_variables)
with tf.Session() as session:
    saver.restore(session, tf.train.latest_checkpoint('checkpoints'))
    sample_generator = np.random.uniform(-1, 1, size=(16, generator_input_size))
    samples = session.run(
        generator_model,
        feed_dict={generator_input: sample_generator})
view_samples(0, [samples])
Out[60]:
[figure: a 4x4 grid of digit samples from the restored generator]
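To sample in a fresh process, where the graph above no longer exists, the generator has to be rebuilt under the same variable scope before restoring, since the checkpoint stores weights by variable name. A minimal sketch using the functions defined earlier:

tf.reset_default_graph()
_, generator_input = model_inputs(discriminator_input_size, generator_input_size)
generator_model = generator(generator_input, discriminator_input_size)

saver = tf.train.Saver(var_list=[v for v in tf.trainable_variables()
                                 if v.name.startswith('generator')])
with tf.Session() as session:
    saver.restore(session, tf.train.latest_checkpoint('checkpoints'))
    noise = np.random.uniform(-1, 1, size=(16, generator_input_size))
    view_samples(0, [session.run(generator_model,
                                 feed_dict={generator_input: noise})])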