Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p tensorflow
Sebastian Raschka 

CPython 3.7.3
IPython 7.6.1

tensorflow 1.13.1

Convolutional Generative Adversarial Networks

Implementation of a Generative Adversarial Network (GAN) in which the discriminator consists of convolutional layers and the generator of deconvolutional (transposed convolution) layers. In this example, the GAN generator was trained to generate MNIST images.
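
For reference, the tensor shapes through the two networks (mirroring the comments in the graph-definition cell below) are:

  generator:     z (100) -> dense (3136) -> reshape (7x7x64) -> deconv (14x14x32) -> deconv (28x28x16) -> deconv (28x28x8) -> deconv + tanh (28x28x1)
  discriminator: image (28x28x1) -> conv (14x14x8) -> conv (7x7x32) -> flatten (1568) -> dense (1) -> sigmoid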

Uses

  • generator inputs sampled from a uniform random distribution (range [-1, 1])
  • dropout
  • leaky ReLU activations
  • batch normalization
  • separate batches for "fake" and "real" images (labels: 1 = real images, 0 = fake images; see the loss formulation after this list)
  • MNIST images normalized to [-1, 1] range
  • generator with tanh output
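
With the label convention above (1 = real, 0 = fake), the graph below implements the standard non-saturating GAN objectives via tf.nn.sigmoid_cross_entropy_with_logits. Writing D(x) for the discriminator's sigmoid output and G(z) for a generated image, the two costs being minimized are

  L_D = -E_x[log D(x)] - E_z[log(1 - D(G(z)))]
  L_G = -E_z[log D(G(z))]

That is, the discriminator is penalized for misclassifying real and fake images, while the generator is updated so that its samples are classified as real.
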
In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import pickle as pkl

tf.test.gpu_device_name()
Out[2]:
'/device:GPU:0'
In [3]:
### Abbreviations
# dis_*: discriminator network
# gen_*: generator network

########################
### Helper functions
########################

def leaky_relu(x, alpha=0.0001):
    # leaky ReLU: identity for positive inputs, slope alpha for negative inputs
    return tf.maximum(alpha * x, x)


########################
### DATASET
########################

mnist = input_data.read_data_sets('MNIST_data')


#########################
### SETTINGS
#########################

# Hyperparameters
learning_rate = 0.001
training_epochs = 50
batch_size = 64
dropout_rate = 0.5

# Architecture
dis_input_size = 784
gen_input_size = 100

# Other settings
print_interval = 200

#########################
### GRAPH DEFINITION
#########################

g = tf.Graph()
with g.as_default():
    
    # Placeholders for settings
    dropout = tf.placeholder(tf.float32, shape=None, name='dropout')
    is_training = tf.placeholder(tf.bool, shape=None, name='is_training')
    
    # Input data
    dis_x = tf.placeholder(tf.float32, shape=[None, dis_input_size],
                           name='discriminator_inputs')     
    gen_x = tf.placeholder(tf.float32, [None, gen_input_size],
                           name='generator_inputs')


    ##################
    # Generator Model
    ##################

    with tf.variable_scope('generator'):
        
        # 100 => 784 => 7x7x64
        gen_fc = tf.layers.dense(inputs=gen_x, units=3136,
                                 bias_initializer=None, # no bias required when using batch_norm
                                 activation=None)
        gen_fc = tf.layers.batch_normalization(gen_fc, training=is_training)
        gen_fc = leaky_relu(gen_fc)
        gen_fc = tf.reshape(gen_fc, (-1, 7, 7, 64))
        
        # 7x7x64 => 14x14x32
        deconv1 = tf.layers.conv2d_transpose(gen_fc, filters=32, 
                                             kernel_size=(3, 3), strides=(2, 2), 
                                             padding='same',
                                             bias_initializer=None,
                                             activation=None)
        deconv1 = tf.layers.batch_normalization(deconv1, training=is_training)
        deconv1 = leaky_relu(deconv1)     
        deconv1 = tf.layers.dropout(deconv1, rate=dropout_rate)
        
        # 14x14x32 => 28x28x16
        deconv2 = tf.layers.conv2d_transpose(deconv1, filters=16, 
                                             kernel_size=(3, 3), strides=(2, 2), 
                                             padding='same',
                                             bias_initializer=None,
                                             activation=None)
        deconv2 = tf.layers.batch_normalization(deconv2, training=is_training)
        deconv2 = leaky_relu(deconv2)     
        deconv2 = tf.layers.dropout(deconv2, rate=dropout_rate)
        
        # 28x28x16 => 28x28x8
        deconv3 = tf.layers.conv2d_transpose(deconv2, filters=8, 
                                             kernel_size=(3, 3), strides=(1, 1), 
                                             padding='same',
                                             bias_initializer=None,
                                             activation=None)
        deconv3 = tf.layers.batch_normalization(deconv3, training=is_training)
        deconv3 = leaky_relu(deconv3)     
        deconv3 = tf.layers.dropout(deconv3, rate=dropout_rate)
        
        # 28x28x8 => 28x28x1
        gen_logits = tf.layers.conv2d_transpose(deconv3, filters=1, 
                                                kernel_size=(3, 3), strides=(1, 1), 
                                                padding='same',
                                                bias_initializer=None,
                                                activation=None)
        gen_out = tf.tanh(gen_logits, 'generator_outputs')


    ######################
    # Discriminator Model
    ######################
    
    def build_discriminator_graph(input_x, reuse=None):

        with tf.variable_scope('discriminator', reuse=reuse):
            
            # 28x28x1 => 14x14x8
            conv_input = tf.reshape(input_x, (-1, 28, 28, 1))
            conv1 = tf.layers.conv2d(conv_input, filters=8, kernel_size=(3, 3),
                                     strides=(2, 2), padding='same',
                                     bias_initializer=None,
                                     activation=None)
            conv1 = tf.layers.batch_normalization(conv1, training=is_training)
            conv1 = leaky_relu(conv1)
            conv1 = tf.layers.dropout(conv1, rate=dropout_rate)
            
            # 14x14x8 => 7x7x32
            conv2 = tf.layers.conv2d(conv1, filters=32, kernel_size=(3, 3),
                                     strides=(2, 2), padding='same',
                                     bias_initializer=None,
                                     activation=None)
            conv2 = tf.layers.batch_normalization(conv2, training=is_training)
            conv2 = leaky_relu(conv2)
            conv2 = tf.layers.dropout(conv2, rate=dropout_rate)

            # fully connected layer
            fc_input = tf.reshape(conv2, (-1, 7*7*32))
            logits = tf.layers.dense(inputs=fc_input, units=1, activation=None)
            out = tf.sigmoid(logits)
            
        return logits, out    

    # Apply the discriminator to real images and, with shared variables, to generated (fake) images
    dis_real_logits, dis_real_out = build_discriminator_graph(dis_x, reuse=False)
    dis_fake_logits, dis_fake_out = build_discriminator_graph(gen_out, reuse=True)


    #####################################
    # Generator and Discriminator Losses
    #####################################
    
    # Two discriminator cost components: loss on real data + loss on fake data
    # Real data has class label 1, fake data has class label 0
    dis_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_real_logits, 
                                                            labels=tf.ones_like(dis_real_logits))
    dis_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits, 
                                                            labels=tf.zeros_like(dis_fake_logits))
    dis_cost = tf.add(tf.reduce_mean(dis_fake_loss), 
                      tf.reduce_mean(dis_real_loss), 
                      name='discriminator_cost')
 
    # Generator cost: cross-entropy between the discriminator's predictions on fake images and the label "1"
    # (the generator is trained to make its fakes pass as real)
    gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits,
                                                       labels=tf.ones_like(dis_fake_logits))
    gen_cost = tf.reduce_mean(gen_loss, name='generator_cost')
    
    
    #########################################
    # Generator and Discriminator Optimizers
    #########################################
      
    dis_optimizer = tf.train.AdamOptimizer(learning_rate)
    dis_train_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')
    dis_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
    
    with tf.control_dependencies(dis_update_ops): # required to upd. batch_norm params
        dis_train = dis_optimizer.minimize(dis_cost, var_list=dis_train_vars,
                                           name='train_discriminator')
    
    gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    gen_train_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
    gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
    
    with tf.control_dependencies(gen_update_ops): # required to upd. batch_norm params
        gen_train = gen_optimizer.minimize(gen_cost, var_list=gen_train_vars,
                                           name='train_generator')
    
    # Saver to save session for reuse
    saver = tf.train.Saver()
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:17: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:64: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dense instead.
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:65: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.batch_normalization instead.
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:74: conv2d_transpose (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.conv2d_transpose instead.
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:77: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dropout instead.
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:121: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.conv2d instead.
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
In [4]:
##########################
### TRAINING & EVALUATION
##########################

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    
    avg_costs = {'discriminator': [], 'generator': []}

    for epoch in range(training_epochs):
        dis_avg_cost, gen_avg_cost = 0., 0.
        total_batch = mnist.train.num_examples // batch_size

        for i in range(total_batch):
            
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            batch_x = batch_x*2 - 1 # rescale from [0, 1] to [-1, 1] to match the generator's tanh output
            batch_randsample = np.random.uniform(-1, 1, size=(batch_size, gen_input_size))
            
            # Train
            
            _, dc = sess.run(['train_discriminator', 'discriminator_cost:0'],
                             feed_dict={'discriminator_inputs:0': batch_x, 
                                        'generator_inputs:0': batch_randsample,
                                        'dropout:0': dropout_rate,
                                        'is_training:0': True})
            
            _, gc = sess.run(['train_generator', 'generator_cost:0'],
                             feed_dict={'generator_inputs:0': batch_randsample,
                                        'dropout:0': dropout_rate,
                                        'is_training:0': True})
            
            dis_avg_cost += dc
            gen_avg_cost += gc

            if not i % print_interval:
                print("Minibatch: %04d | Dis/Gen Cost:    %.3f/%.3f" % (i + 1, dc, gc))
                

        print("Epoch:     %04d | Dis/Gen AvgCost: %.3f/%.3f" % 
              (epoch + 1, dis_avg_cost / total_batch, gen_avg_cost / total_batch))
        
        avg_costs['discriminator'].append(dis_avg_cost / total_batch)
        avg_costs['generator'].append(gen_avg_cost / total_batch)
    
    
    saver.save(sess, save_path='./gan-conv.ckpt')
Minibatch: 0001 | Dis/Gen Cost:    1.630/0.866
Minibatch: 0201 | Dis/Gen Cost:    0.850/1.879
Minibatch: 0401 | Dis/Gen Cost:    0.606/2.467
Minibatch: 0601 | Dis/Gen Cost:    0.695/1.661
Minibatch: 0801 | Dis/Gen Cost:    1.149/1.297
Epoch:     0001 | Dis/Gen AvgCost: 0.820/1.887
Minibatch: 0001 | Dis/Gen Cost:    0.707/1.486
Minibatch: 0201 | Dis/Gen Cost:    0.924/1.438
Minibatch: 0401 | Dis/Gen Cost:    0.751/1.508
Minibatch: 0601 | Dis/Gen Cost:    0.899/1.611
Minibatch: 0801 | Dis/Gen Cost:    0.914/1.535
Epoch:     0002 | Dis/Gen AvgCost: 0.954/1.510
Minibatch: 0001 | Dis/Gen Cost:    0.498/1.955
Minibatch: 0201 | Dis/Gen Cost:    0.757/1.670
Minibatch: 0401 | Dis/Gen Cost:    1.100/1.204
Minibatch: 0601 | Dis/Gen Cost:    0.656/2.054
Minibatch: 0801 | Dis/Gen Cost:    1.036/1.174
Epoch:     0003 | Dis/Gen AvgCost: 0.784/1.720
Minibatch: 0001 | Dis/Gen Cost:    1.576/0.992
Minibatch: 0201 | Dis/Gen Cost:    0.663/2.002
Minibatch: 0401 | Dis/Gen Cost:    0.869/1.773
Minibatch: 0601 | Dis/Gen Cost:    0.675/1.772
Minibatch: 0801 | Dis/Gen Cost:    0.881/1.489
Epoch:     0004 | Dis/Gen AvgCost: 0.898/1.575
Minibatch: 0001 | Dis/Gen Cost:    1.201/1.386
Minibatch: 0201 | Dis/Gen Cost:    1.245/1.606
Minibatch: 0401 | Dis/Gen Cost:    1.281/1.015
Minibatch: 0601 | Dis/Gen Cost:    0.925/1.124
Minibatch: 0801 | Dis/Gen Cost:    1.126/1.634
Epoch:     0005 | Dis/Gen AvgCost: 1.037/1.435
Minibatch: 0001 | Dis/Gen Cost:    0.853/1.626
Minibatch: 0201 | Dis/Gen Cost:    1.204/0.929
Minibatch: 0401 | Dis/Gen Cost:    1.070/1.365
Minibatch: 0601 | Dis/Gen Cost:    1.366/0.927
Minibatch: 0801 | Dis/Gen Cost:    1.253/1.500
Epoch:     0006 | Dis/Gen AvgCost: 1.168/1.186
Minibatch: 0001 | Dis/Gen Cost:    1.590/0.945
Minibatch: 0201 | Dis/Gen Cost:    0.822/1.563
Minibatch: 0401 | Dis/Gen Cost:    0.894/1.410
Minibatch: 0601 | Dis/Gen Cost:    1.292/1.131
Minibatch: 0801 | Dis/Gen Cost:    1.361/1.005
Epoch:     0007 | Dis/Gen AvgCost: 1.248/1.103
Minibatch: 0001 | Dis/Gen Cost:    1.860/0.697
Minibatch: 0201 | Dis/Gen Cost:    1.291/0.986
Minibatch: 0401 | Dis/Gen Cost:    1.097/0.934
Minibatch: 0601 | Dis/Gen Cost:    1.316/0.788
Minibatch: 0801 | Dis/Gen Cost:    1.437/0.885
Epoch:     0008 | Dis/Gen AvgCost: 1.298/0.995
Minibatch: 0001 | Dis/Gen Cost:    1.150/1.072
Minibatch: 0201 | Dis/Gen Cost:    1.177/1.148
Minibatch: 0401 | Dis/Gen Cost:    1.351/0.884
Minibatch: 0601 | Dis/Gen Cost:    1.434/0.797
Minibatch: 0801 | Dis/Gen Cost:    1.291/0.929
Epoch:     0009 | Dis/Gen AvgCost: 1.333/0.968
Minibatch: 0001 | Dis/Gen Cost:    1.324/0.764
Minibatch: 0201 | Dis/Gen Cost:    1.255/0.942
Minibatch: 0401 | Dis/Gen Cost:    1.181/1.007
Minibatch: 0601 | Dis/Gen Cost:    1.132/1.134
Minibatch: 0801 | Dis/Gen Cost:    1.170/1.249
Epoch:     0010 | Dis/Gen AvgCost: 1.328/0.922
Minibatch: 0001 | Dis/Gen Cost:    1.539/0.739
Minibatch: 0201 | Dis/Gen Cost:    1.181/1.186
Minibatch: 0401 | Dis/Gen Cost:    1.014/1.331
Minibatch: 0601 | Dis/Gen Cost:    1.380/0.884
Minibatch: 0801 | Dis/Gen Cost:    1.441/0.893
Epoch:     0011 | Dis/Gen AvgCost: 1.306/0.949
Minibatch: 0001 | Dis/Gen Cost:    1.248/0.953
Minibatch: 0201 | Dis/Gen Cost:    1.421/0.751
Minibatch: 0401 | Dis/Gen Cost:    1.323/0.891
Minibatch: 0601 | Dis/Gen Cost:    1.363/0.912
Minibatch: 0801 | Dis/Gen Cost:    1.174/1.112
Epoch:     0012 | Dis/Gen AvgCost: 1.334/0.931
Minibatch: 0001 | Dis/Gen Cost:    1.463/0.792
Minibatch: 0201 | Dis/Gen Cost:    1.296/0.992
Minibatch: 0401 | Dis/Gen Cost:    1.213/1.037
Minibatch: 0601 | Dis/Gen Cost:    1.273/0.899
Minibatch: 0801 | Dis/Gen Cost:    1.282/0.893
Epoch:     0013 | Dis/Gen AvgCost: 1.323/0.910
Minibatch: 0001 | Dis/Gen Cost:    1.192/0.921
Minibatch: 0201 | Dis/Gen Cost:    1.287/0.933
Minibatch: 0401 | Dis/Gen Cost:    1.292/0.898
Minibatch: 0601 | Dis/Gen Cost:    1.164/0.945
Minibatch: 0801 | Dis/Gen Cost:    1.469/0.776
Epoch:     0014 | Dis/Gen AvgCost: 1.312/0.890
Minibatch: 0001 | Dis/Gen Cost:    1.363/0.876
Minibatch: 0201 | Dis/Gen Cost:    1.398/0.759
Minibatch: 0401 | Dis/Gen Cost:    1.099/1.088
Minibatch: 0601 | Dis/Gen Cost:    1.415/0.831
Minibatch: 0801 | Dis/Gen Cost:    1.287/0.813
Epoch:     0015 | Dis/Gen AvgCost: 1.310/0.896
Minibatch: 0001 | Dis/Gen Cost:    1.309/0.910
Minibatch: 0201 | Dis/Gen Cost:    1.397/0.829
Minibatch: 0401 | Dis/Gen Cost:    1.221/0.949
Minibatch: 0601 | Dis/Gen Cost:    1.284/0.918
Minibatch: 0801 | Dis/Gen Cost:    1.315/0.737
Epoch:     0016 | Dis/Gen AvgCost: 1.306/0.860
Minibatch: 0001 | Dis/Gen Cost:    1.193/0.901
Minibatch: 0201 | Dis/Gen Cost:    1.339/0.908
Minibatch: 0401 | Dis/Gen Cost:    1.119/0.969
Minibatch: 0601 | Dis/Gen Cost:    1.293/0.907
Minibatch: 0801 | Dis/Gen Cost:    1.368/0.882
Epoch:     0017 | Dis/Gen AvgCost: 1.320/0.892
Minibatch: 0001 | Dis/Gen Cost:    1.308/1.014
Minibatch: 0201 | Dis/Gen Cost:    1.194/0.936
Minibatch: 0401 | Dis/Gen Cost:    1.536/0.755
Minibatch: 0601 | Dis/Gen Cost:    1.443/0.810
Minibatch: 0801 | Dis/Gen Cost:    1.288/0.730
Epoch:     0018 | Dis/Gen AvgCost: 1.315/0.867
Minibatch: 0001 | Dis/Gen Cost:    1.259/0.979
Minibatch: 0201 | Dis/Gen Cost:    1.307/0.822
Minibatch: 0401 | Dis/Gen Cost:    1.242/0.845
Minibatch: 0601 | Dis/Gen Cost:    1.422/0.891
Minibatch: 0801 | Dis/Gen Cost:    1.263/0.904
Epoch:     0019 | Dis/Gen AvgCost: 1.306/0.866
Minibatch: 0001 | Dis/Gen Cost:    1.204/0.811
Minibatch: 0201 | Dis/Gen Cost:    1.340/0.810
Minibatch: 0401 | Dis/Gen Cost:    1.278/0.963
Minibatch: 0601 | Dis/Gen Cost:    1.249/0.936
Minibatch: 0801 | Dis/Gen Cost:    1.285/0.945
Epoch:     0020 | Dis/Gen AvgCost: 1.316/0.853
Minibatch: 0001 | Dis/Gen Cost:    1.370/0.772
Minibatch: 0201 | Dis/Gen Cost:    1.478/0.762
Minibatch: 0401 | Dis/Gen Cost:    1.440/0.822
Minibatch: 0601 | Dis/Gen Cost:    1.269/0.809
Minibatch: 0801 | Dis/Gen Cost:    1.260/0.923
Epoch:     0021 | Dis/Gen AvgCost: 1.324/0.837
Minibatch: 0001 | Dis/Gen Cost:    1.401/0.892
Minibatch: 0201 | Dis/Gen Cost:    1.361/0.762
Minibatch: 0401 | Dis/Gen Cost:    1.121/1.012
Minibatch: 0601 | Dis/Gen Cost:    1.366/0.822
Minibatch: 0801 | Dis/Gen Cost:    1.484/0.744
Epoch:     0022 | Dis/Gen AvgCost: 1.314/0.851
Minibatch: 0001 | Dis/Gen Cost:    1.207/0.829
Minibatch: 0201 | Dis/Gen Cost:    1.320/0.786
Minibatch: 0401 | Dis/Gen Cost:    1.327/0.807
Minibatch: 0601 | Dis/Gen Cost:    1.250/0.909
Minibatch: 0801 | Dis/Gen Cost:    1.339/0.769
Epoch:     0023 | Dis/Gen AvgCost: 1.323/0.833
Minibatch: 0001 | Dis/Gen Cost:    1.363/0.825
Minibatch: 0201 | Dis/Gen Cost:    1.416/0.738
Minibatch: 0401 | Dis/Gen Cost:    1.290/0.876
Minibatch: 0601 | Dis/Gen Cost:    1.257/0.825
Minibatch: 0801 | Dis/Gen Cost:    1.510/0.633
Epoch:     0024 | Dis/Gen AvgCost: 1.323/0.841
Minibatch: 0001 | Dis/Gen Cost:    1.291/0.694
Minibatch: 0201 | Dis/Gen Cost:    1.400/0.720
Minibatch: 0401 | Dis/Gen Cost:    1.340/0.802
Minibatch: 0601 | Dis/Gen Cost:    1.339/0.784
Minibatch: 0801 | Dis/Gen Cost:    1.211/0.886
Epoch:     0025 | Dis/Gen AvgCost: 1.339/0.811
Minibatch: 0001 | Dis/Gen Cost:    1.395/0.865
Minibatch: 0201 | Dis/Gen Cost:    1.400/0.823
Minibatch: 0401 | Dis/Gen Cost:    1.357/0.811
Minibatch: 0601 | Dis/Gen Cost:    1.404/0.741
Minibatch: 0801 | Dis/Gen Cost:    1.298/0.930
Epoch:     0026 | Dis/Gen AvgCost: 1.340/0.819
Minibatch: 0001 | Dis/Gen Cost:    1.257/0.833
Minibatch: 0201 | Dis/Gen Cost:    1.359/0.772
Minibatch: 0401 | Dis/Gen Cost:    1.453/0.798
Minibatch: 0601 | Dis/Gen Cost:    1.389/0.853
Minibatch: 0801 | Dis/Gen Cost:    1.447/0.754
Epoch:     0027 | Dis/Gen AvgCost: 1.340/0.808
Minibatch: 0001 | Dis/Gen Cost:    1.353/0.764
Minibatch: 0201 | Dis/Gen Cost:    1.353/0.811
Minibatch: 0401 | Dis/Gen Cost:    1.458/0.748
Minibatch: 0601 | Dis/Gen Cost:    1.448/0.753
Minibatch: 0801 | Dis/Gen Cost:    1.475/0.696
Epoch:     0028 | Dis/Gen AvgCost: 1.349/0.792
Minibatch: 0001 | Dis/Gen Cost:    1.271/0.932
Minibatch: 0201 | Dis/Gen Cost:    1.294/0.894
Minibatch: 0401 | Dis/Gen Cost:    1.156/0.866
Minibatch: 0601 | Dis/Gen Cost:    1.292/0.778
Minibatch: 0801 | Dis/Gen Cost:    1.309/0.817
Epoch:     0029 | Dis/Gen AvgCost: 1.347/0.799
Minibatch: 0001 | Dis/Gen Cost:    1.459/0.727
Minibatch: 0201 | Dis/Gen Cost:    1.396/0.753
Minibatch: 0401 | Dis/Gen Cost:    1.367/0.754
Minibatch: 0601 | Dis/Gen Cost:    1.336/0.785
Minibatch: 0801 | Dis/Gen Cost:    1.304/0.756
Epoch:     0030 | Dis/Gen AvgCost: 1.347/0.780
Minibatch: 0001 | Dis/Gen Cost:    1.431/0.726
Minibatch: 0201 | Dis/Gen Cost:    1.348/0.793
Minibatch: 0401 | Dis/Gen Cost:    1.102/0.823
Minibatch: 0601 | Dis/Gen Cost:    1.276/0.772
Minibatch: 0801 | Dis/Gen Cost:    1.390/0.776
Epoch:     0031 | Dis/Gen AvgCost: 1.337/0.801
Minibatch: 0001 | Dis/Gen Cost:    1.507/0.704
Minibatch: 0201 | Dis/Gen Cost:    1.295/0.873
Minibatch: 0401 | Dis/Gen Cost:    1.312/0.835
Minibatch: 0601 | Dis/Gen Cost:    1.346/0.842
Minibatch: 0801 | Dis/Gen Cost:    1.328/0.721
Epoch:     0032 | Dis/Gen AvgCost: 1.342/0.792
Minibatch: 0001 | Dis/Gen Cost:    1.401/0.717
Minibatch: 0201 | Dis/Gen Cost:    1.436/0.737
Minibatch: 0401 | Dis/Gen Cost:    1.332/0.774
Minibatch: 0601 | Dis/Gen Cost:    1.311/0.804
Minibatch: 0801 | Dis/Gen Cost:    1.391/0.650
Epoch:     0033 | Dis/Gen AvgCost: 1.352/0.783
Minibatch: 0001 | Dis/Gen Cost:    1.317/0.740
Minibatch: 0201 | Dis/Gen Cost:    1.343/0.810
Minibatch: 0401 | Dis/Gen Cost:    1.394/0.717
Minibatch: 0601 | Dis/Gen Cost:    1.455/0.779
Minibatch: 0801 | Dis/Gen Cost:    1.445/0.704
Epoch:     0034 | Dis/Gen AvgCost: 1.348/0.785
Minibatch: 0001 | Dis/Gen Cost:    1.294/0.791
Minibatch: 0201 | Dis/Gen Cost:    1.277/0.886
Minibatch: 0401 | Dis/Gen Cost:    1.349/0.721
Minibatch: 0601 | Dis/Gen Cost:    1.297/0.717
Minibatch: 0801 | Dis/Gen Cost:    1.320/0.777
Epoch:     0035 | Dis/Gen AvgCost: 1.353/0.780
Minibatch: 0001 | Dis/Gen Cost:    1.338/0.756
Minibatch: 0201 | Dis/Gen Cost:    1.273/0.778
Minibatch: 0401 | Dis/Gen Cost:    1.325/0.865
Minibatch: 0601 | Dis/Gen Cost:    1.438/0.717
Minibatch: 0801 | Dis/Gen Cost:    1.328/0.785
Epoch:     0036 | Dis/Gen AvgCost: 1.352/0.770
Minibatch: 0001 | Dis/Gen Cost:    1.375/0.764
Minibatch: 0201 | Dis/Gen Cost:    1.453/0.723
Minibatch: 0401 | Dis/Gen Cost:    1.270/0.807
Minibatch: 0601 | Dis/Gen Cost:    1.392/0.775
Minibatch: 0801 | Dis/Gen Cost:    1.318/0.824
Epoch:     0037 | Dis/Gen AvgCost: 1.353/0.773
Minibatch: 0001 | Dis/Gen Cost:    1.270/0.874
Minibatch: 0201 | Dis/Gen Cost:    1.214/0.833
Minibatch: 0401 | Dis/Gen Cost:    1.456/0.666
Minibatch: 0601 | Dis/Gen Cost:    1.400/0.824
Minibatch: 0801 | Dis/Gen Cost:    1.328/0.736
Epoch:     0038 | Dis/Gen AvgCost: 1.354/0.776
Minibatch: 0001 | Dis/Gen Cost:    1.332/0.743
Minibatch: 0201 | Dis/Gen Cost:    1.389/0.710
Minibatch: 0401 | Dis/Gen Cost:    1.375/0.708
Minibatch: 0601 | Dis/Gen Cost:    1.296/0.758
Minibatch: 0801 | Dis/Gen Cost:    1.337/0.783
Epoch:     0039 | Dis/Gen AvgCost: 1.356/0.765
Minibatch: 0001 | Dis/Gen Cost:    1.388/0.706
Minibatch: 0201 | Dis/Gen Cost:    1.371/0.712
Minibatch: 0401 | Dis/Gen Cost:    1.349/0.698
Minibatch: 0601 | Dis/Gen Cost:    1.380/0.723
Minibatch: 0801 | Dis/Gen Cost:    1.371/0.746
Epoch:     0040 | Dis/Gen AvgCost: 1.358/0.759
Minibatch: 0001 | Dis/Gen Cost:    1.349/0.702
Minibatch: 0201 | Dis/Gen Cost:    1.315/0.742
Minibatch: 0401 | Dis/Gen Cost:    1.353/0.760
Minibatch: 0601 | Dis/Gen Cost:    1.335/0.799
Minibatch: 0801 | Dis/Gen Cost:    1.403/0.726
Epoch:     0041 | Dis/Gen AvgCost: 1.362/0.755
Minibatch: 0001 | Dis/Gen Cost:    1.363/0.782
Minibatch: 0201 | Dis/Gen Cost:    1.335/0.742
Minibatch: 0401 | Dis/Gen Cost:    1.344/0.751
Minibatch: 0601 | Dis/Gen Cost:    1.338/0.740
Minibatch: 0801 | Dis/Gen Cost:    1.460/0.735
Epoch:     0042 | Dis/Gen AvgCost: 1.361/0.764
Minibatch: 0001 | Dis/Gen Cost:    1.308/0.767
Minibatch: 0201 | Dis/Gen Cost:    1.367/0.764
Minibatch: 0401 | Dis/Gen Cost:    1.382/0.764
Minibatch: 0601 | Dis/Gen Cost:    1.419/0.625
Minibatch: 0801 | Dis/Gen Cost:    1.393/0.777
Epoch:     0043 | Dis/Gen AvgCost: 1.361/0.753
Minibatch: 0001 | Dis/Gen Cost:    1.413/0.749
Minibatch: 0201 | Dis/Gen Cost:    1.370/0.724
Minibatch: 0401 | Dis/Gen Cost:    1.314/0.756
Minibatch: 0601 | Dis/Gen Cost:    1.321/0.763
Minibatch: 0801 | Dis/Gen Cost:    1.354/0.771
Epoch:     0044 | Dis/Gen AvgCost: 1.364/0.752
Minibatch: 0001 | Dis/Gen Cost:    1.363/0.748
Minibatch: 0201 | Dis/Gen Cost:    1.365/0.727
Minibatch: 0401 | Dis/Gen Cost:    1.439/0.714
Minibatch: 0601 | Dis/Gen Cost:    1.429/0.696
Minibatch: 0801 | Dis/Gen Cost:    1.427/0.699
Epoch:     0045 | Dis/Gen AvgCost: 1.363/0.745
Minibatch: 0001 | Dis/Gen Cost:    1.398/0.713
Minibatch: 0201 | Dis/Gen Cost:    1.408/0.717
Minibatch: 0401 | Dis/Gen Cost:    1.298/0.734
Minibatch: 0601 | Dis/Gen Cost:    1.345/0.805
Minibatch: 0801 | Dis/Gen Cost:    1.331/0.828
Epoch:     0046 | Dis/Gen AvgCost: 1.366/0.752
Minibatch: 0001 | Dis/Gen Cost:    1.319/0.751
Minibatch: 0201 | Dis/Gen Cost:    1.482/0.713
Minibatch: 0401 | Dis/Gen Cost:    1.341/0.803
Minibatch: 0601 | Dis/Gen Cost:    1.386/0.651
Minibatch: 0801 | Dis/Gen Cost:    1.428/0.701
Epoch:     0047 | Dis/Gen AvgCost: 1.369/0.758
Minibatch: 0001 | Dis/Gen Cost:    1.378/0.747
Minibatch: 0201 | Dis/Gen Cost:    1.355/0.716
Minibatch: 0401 | Dis/Gen Cost:    1.357/0.686
Minibatch: 0601 | Dis/Gen Cost:    1.333/0.767
Minibatch: 0801 | Dis/Gen Cost:    1.380/0.712
Epoch:     0048 | Dis/Gen AvgCost: 1.370/0.735
Minibatch: 0001 | Dis/Gen Cost:    1.409/0.706
Minibatch: 0201 | Dis/Gen Cost:    1.307/0.789
Minibatch: 0401 | Dis/Gen Cost:    1.396/0.731
Minibatch: 0601 | Dis/Gen Cost:    1.375/0.711
Minibatch: 0801 | Dis/Gen Cost:    1.365/0.782
Epoch:     0049 | Dis/Gen AvgCost: 1.371/0.733
Minibatch: 0001 | Dis/Gen Cost:    1.409/0.701
Minibatch: 0201 | Dis/Gen Cost:    1.369/0.728
Minibatch: 0401 | Dis/Gen Cost:    1.315/0.730
Minibatch: 0601 | Dis/Gen Cost:    1.321/0.774
Minibatch: 0801 | Dis/Gen Cost:    1.336/0.735
Epoch:     0050 | Dis/Gen AvgCost: 1.372/0.735
In [5]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.plot(range(len(avg_costs['discriminator'])), 
         avg_costs['discriminator'], label='discriminator')
plt.plot(range(len(avg_costs['generator'])),
         avg_costs['generator'], label='generator')
plt.legend()
plt.show()
In [6]:
####################################
### RELOAD & GENERATE SAMPLE IMAGES
####################################


n_examples = 25

with tf.Session(graph=g) as sess:
    saver.restore(sess, save_path='./gan-conv.ckpt')

    batch_randsample = np.random.uniform(-1, 1, size=(n_examples, gen_input_size))
    new_examples = sess.run('generator/generator_outputs:0',
                            feed_dict={'generator_inputs:0': batch_randsample,
                                       'dropout:0': 0.0,
                                       'is_training:0': False})

fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(8, 8),
                         sharey=True, sharex=True)

for image, ax in zip(new_examples, axes.flatten()):
    ax.imshow(image.reshape((dis_input_size // 28, dis_input_size // 28)), cmap='binary')

plt.show()
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ./gan-conv.ckpt