Update, Dec. 17th 2019: This notebook is superseded by the following two notebooks:
The following, older variational autoencoder code [1] is still used in other notebooks and is kept here for reference.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import keras
from keras import backend as K
from keras import layers
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.utils import to_categorical
%matplotlib inline
Using TensorFlow backend.
# Dimensions of MNIST images: 28x28 pixels, single grayscale channel.
image_shape = (28, 28, 1)
# Dimension of latent space (2-D so the latent manifold can be visualized).
latent_dim = 2
# Mini-batch size for training
batch_size = 128
def create_encoder():
    '''
    Creates a convolutional encoder model for MNIST images.

    - Input for the created model are MNIST images.
    - Output of the created model are the sufficient statistics
      of the variational distribution q(t|x;phi), mean and log
      variance.
    '''
    # Fixed typo: local name was 'encoder_iput'.
    encoder_input = layers.Input(shape=image_shape)
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(encoder_input)
    # Strided convolution downsamples 28x28 -> 14x14.
    x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(32, activation='relu')(x)
    # Two linear heads: mean and log variance of q(t|x).
    t_mean = layers.Dense(latent_dim)(x)
    t_log_var = layers.Dense(latent_dim)(x)
    return Model(encoder_input, [t_mean, t_log_var], name='encoder')
def create_decoder():
    '''
    Creates a (de-)convolutional decoder model for MNIST images.

    - Input for the created model are latent vectors t.
    - Output of the model are images of shape (28, 28, 1) where
      the value of each pixel is the probability of being white.
    '''
    latent_input = layers.Input(shape=(latent_dim,))
    # Project the latent vector to a 14x14x64 feature volume
    # (14 * 14 * 64 = 12544 units).
    h = layers.Dense(12544, activation='relu')(latent_input)
    h = layers.Reshape((14, 14, 64))(h)
    # Strided transposed convolution upsamples 14x14 -> 28x28.
    h = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(h)
    # Sigmoid output gives per-pixel Bernoulli probabilities.
    output = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(h)
    return Model(latent_input, output, name='decoder')
def sample(args):
    '''
    Draws samples from a standard normal and scales the samples with
    standard deviation of the variational distribution and shifts them
    by the mean.

    Args:
        args: sufficient statistics of the variational distribution.

    Returns:
        Samples from the variational distribution.
    '''
    mean, log_var = args
    # Standard deviation recovered from the log variance:
    # sigma = sqrt(exp(log_var)).
    sigma = K.sqrt(K.exp(log_var))
    # Reparameterization trick: t = mu + sigma * eps with eps ~ N(0, 1),
    # keeping the sampling step differentiable w.r.t. mean and log_var.
    eps = K.random_normal(shape=K.shape(mean), mean=0., stddev=1.)
    return mean + sigma * eps
def create_sampler():
    '''Wraps the reparameterized sampling step in a Keras Lambda layer.'''
    return layers.Lambda(sample, name='sampler')
# Build the three sub-models and wire them into the full VAE graph.
encoder = create_encoder()
decoder = create_decoder()
sampler = create_sampler()
x = layers.Input(shape=image_shape)
# NOTE: t_mean and t_log_var are intentionally module-level names — the
# loss function neg_variational_lower_bound below reads them as globals
# to compute the KL term. Do not rename or make them local.
t_mean, t_log_var = encoder(x)
t = sampler([t_mean, t_log_var])
t_decoded = decoder(t)
vae = Model(x, t_decoded, name='vae')
def neg_variational_lower_bound(x, t_decoded):
    '''
    Negative variational lower bound used as loss function
    for training the variational autoencoder.

    The signature (y_true, y_pred) is fixed by the Keras loss contract.

    NOTE(review): the KL term reads t_mean and t_log_var from module
    scope (the tensors created when the VAE graph was assembled above),
    not from the function arguments — the loss is therefore coupled to
    that specific graph and cannot be reused standalone.

    Args:
        x: input images
        t_decoded: reconstructed images
    '''
    # Reconstruction loss: per-pixel binary cross-entropy, summed over
    # each flattened image (one scalar per sample).
    rc_loss = K.sum(K.binary_crossentropy(
        K.batch_flatten(x),
        K.batch_flatten(t_decoded)), axis=-1)

    # Regularization term (KL divergence): closed-form KL between the
    # diagonal-Gaussian q(t|x) and the standard-normal prior p(t).
    kl_loss = -0.5 * K.sum(1 + t_log_var \
                           - K.square(t_mean) \
                           - K.exp(t_log_var), axis=-1)

    # Average over mini-batch
    return K.mean(rc_loss + kl_loss)
# MNIST training and validation data
(x_train, _), (x_test, _) = mnist.load_data()

# Scale pixel values to [0, 1] and append a trailing channel axis,
# turning (N, 28, 28) arrays into (N, 28, 28, 1).
x_train = np.expand_dims(x_train.astype('float32') / 255., -1)
x_test = np.expand_dims(x_test.astype('float32') / 255., -1)
# Compile variational autoencoder model.
# Minimizing the negative variational lower bound (ELBO) trains both
# encoder and decoder end-to-end.
vae.compile(optimizer='rmsprop', loss=neg_variational_lower_bound)

# Train variational autoencoder with MNIST images.
# The autoencoder reconstructs its input, so targets equal inputs (y=x).
vae.fit(x=x_train,
        y=x_train,
        epochs=25,
        shuffle=True,
        batch_size=batch_size,
        validation_data=(x_test, x_test), verbose=2)
Train on 60000 samples, validate on 10000 samples Epoch 1/25 - 20s - loss: 195.2594 - val_loss: 174.2528 Epoch 2/25 - 19s - loss: 171.3044 - val_loss: 166.7014 Epoch 3/25 - 19s - loss: 164.3512 - val_loss: 161.6999 Epoch 4/25 - 19s - loss: 160.1977 - val_loss: 161.6477 Epoch 5/25 - 19s - loss: 157.7647 - val_loss: 156.3657 Epoch 6/25 - 19s - loss: 156.1030 - val_loss: 155.2672 Epoch 7/25 - 19s - loss: 154.8048 - val_loss: 155.3826 Epoch 8/25 - 19s - loss: 153.7678 - val_loss: 152.8838 Epoch 9/25 - 19s - loss: 152.8601 - val_loss: 153.5024 Epoch 10/25 - 19s - loss: 152.0742 - val_loss: 152.3024 Epoch 11/25 - 19s - loss: 151.3636 - val_loss: 152.1116 Epoch 12/25 - 19s - loss: 150.7016 - val_loss: 152.2413 Epoch 13/25 - 19s - loss: 150.1563 - val_loss: 151.3501 Epoch 14/25 - 19s - loss: 149.6736 - val_loss: 149.6158 Epoch 15/25 - 19s - loss: 149.1454 - val_loss: 149.2047 Epoch 16/25 - 19s - loss: 148.7548 - val_loss: 149.0340 Epoch 17/25 - 19s - loss: 148.3118 - val_loss: 148.8142 Epoch 18/25 - 19s - loss: 147.9456 - val_loss: 149.1827 Epoch 19/25 - 19s - loss: 147.6405 - val_loss: 151.5525 Epoch 20/25 - 19s - loss: 147.2038 - val_loss: 148.2406 Epoch 21/25 - 19s - loss: 146.8990 - val_loss: 150.3239 Epoch 22/25 - 19s - loss: 146.6422 - val_loss: 147.5277 Epoch 23/25 - 18s - loss: 146.3777 - val_loss: 148.2235 Epoch 24/25 - 19s - loss: 146.0520 - val_loss: 147.7608 Epoch 25/25 - 19s - loss: 145.8438 - val_loss: 147.9813
<keras.callbacks.History at 0x7fbaf4e40da0>
[1] François Chollet. Deep Learning with Python.