From expectation maximization to stochastic variational inference

Update, Dec. 17th 2019: This notebook is superseded by the following two notebooks:

The old variational autoencoder code below[1] is still used in other notebooks and is kept here for reference.

In [1]:
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import cm

import keras
from keras import backend as K
from keras import layers
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.utils import to_categorical

%matplotlib inline
Using TensorFlow backend.
In [2]:
# Dimensions of MNIST images  
image_shape = (28, 28, 1)

# Dimension of latent space
latent_dim = 2

# Mini-batch size for training
batch_size = 128

def create_encoder():
    '''
    Creates a convolutional encoder model for MNIST images.
    
    - Inputs of the created model are MNIST images.
    - Outputs of the created model are the sufficient statistics
      of the variational distribution q(t|x;phi): mean and
      log variance.
    '''
    encoder_input = layers.Input(shape=image_shape)
    
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(encoder_input)
    x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(32, activation='relu')(x)

    t_mean = layers.Dense(latent_dim)(x)
    t_log_var = layers.Dense(latent_dim)(x)

    return Model(encoder_input, [t_mean, t_log_var], name='encoder')

def create_decoder():
    '''
    Creates a (de-)convolutional decoder model for MNIST images.
    
    - Inputs of the created model are latent vectors t.
    - Outputs of the model are images of shape (28, 28, 1) where
      the value of each pixel is the probability of being white.
    '''
    decoder_input = layers.Input(shape=(latent_dim,))
    
    x = layers.Dense(12544, activation='relu')(decoder_input)
    x = layers.Reshape((14, 14, 64))(x)
    x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)
    x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)
    
    return Model(decoder_input, x, name='decoder')
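
As a quick sanity check of the two factory functions above, the encoder maps a batch of images to two tensors of size latent_dim (mean and log variance), and the decoder maps latent vectors back to image-shaped tensors. A minimal sketch, where the dummy arrays and variable names are illustrative only:

# Sketch: shape check for the encoder and decoder defined above.
enc = create_encoder()
dec = create_decoder()

dummy_images = np.zeros((batch_size,) + image_shape, dtype='float32')
t_mean, t_log_var = enc.predict(dummy_images)
print(t_mean.shape, t_log_var.shape)       # expected: (128, 2) (128, 2)

dummy_latents = np.zeros((batch_size, latent_dim), dtype='float32')
print(dec.predict(dummy_latents).shape)    # expected: (128, 28, 28, 1)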
In [3]:
def sample(args):
    '''
    Draws samples from a standard normal distribution, scales them by
    the standard deviation of the variational distribution and shifts
    them by its mean (reparameterization trick).
    
    Args:
        args: sufficient statistics of the variational distribution.
        
    Returns:
        Samples from the variational distribution.
    '''
    t_mean, t_log_var = args
    t_sigma = K.sqrt(K.exp(t_log_var))
    epsilon = K.random_normal(shape=K.shape(t_mean), mean=0., stddev=1.)
    return t_mean + t_sigma * epsilon

def create_sampler():
    '''
    Creates a sampling layer.
    '''
    return layers.Lambda(sample, name='sampler')
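
The sampler implements the reparameterization trick: instead of drawing t ~ N(mu, sigma^2) directly, it draws epsilon ~ N(0, 1) and computes t = mu + sigma * epsilon, so gradients can flow through the mean and log variance. A plain NumPy sketch of the same idea, with example values only:

# Sketch: reparameterization trick with NumPy.
mu, log_var = 1.5, np.log(0.25)        # example mean and log variance
sigma = np.sqrt(np.exp(log_var))       # sigma = 0.5

epsilon = np.random.normal(0.0, 1.0, size=100000)
t = mu + sigma * epsilon               # samples from N(mu, sigma^2)

print(t.mean(), t.std())               # approximately 1.5 and 0.5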
In [4]:
encoder = create_encoder()
decoder = create_decoder()
sampler = create_sampler()

x = layers.Input(shape=image_shape)
t_mean, t_log_var = encoder(x)
t = sampler([t_mean, t_log_var])
t_decoded = decoder(t)

vae = Model(x, t_decoded, name='vae')
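
To confirm the wiring of encoder, sampler and decoder into the end-to-end model, it can be inspected directly (a small sketch; the expected shapes follow from the definitions above):

# Sketch: inspect the composed VAE.
vae.summary()            # encoder -> sampler -> decoder

print(vae.input_shape)   # expected: (None, 28, 28, 1)
print(vae.output_shape)  # expected: (None, 28, 28, 1)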
In [5]:
def neg_variational_lower_bound(x, t_decoded):
    '''
    Negative variational lower bound used as loss function
    for training the variational auto-encoder.
    
    Args:
        x: input images
        t_decoded: reconstructed images
    '''
    # Reconstruction loss
    rc_loss = K.sum(K.binary_crossentropy(
        K.batch_flatten(x), 
        K.batch_flatten(t_decoded)), axis=-1)

    # Regularization term: closed-form KL divergence between the
    # variational distribution and the standard normal prior. Note that
    # t_mean and t_log_var refer to the encoder output tensors defined above.
    kl_loss = -0.5 * K.sum(1 + t_log_var \
                             - K.square(t_mean) \
                             - K.exp(t_log_var), axis=-1)
    
    # Average over mini-batch
    return K.mean(rc_loss + kl_loss)
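
The regularization term is the closed-form KL divergence between the diagonal Gaussian q(t|x) = N(mu, sigma^2) and the standard normal prior, KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2). The following NumPy sketch checks this formula against a Monte Carlo estimate for a single latent dimension, with example values only:

# Sketch: closed-form KL term vs. Monte Carlo estimate.
mu, log_var = 0.8, np.log(0.5)
sigma2 = np.exp(log_var)

kl_closed_form = -0.5 * (1 + log_var - mu**2 - sigma2)

# Monte Carlo: E_q[log q(t) - log p(t)] with q = N(mu, sigma2), p = N(0, 1)
t = np.random.normal(mu, np.sqrt(sigma2), size=1000000)
log_q = -0.5 * (np.log(2 * np.pi * sigma2) + (t - mu)**2 / sigma2)
log_p = -0.5 * (np.log(2 * np.pi) + t**2)

print(kl_closed_form, np.mean(log_q - log_p))  # should roughly agree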
In [6]:
# MNIST training and validation data
(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))

x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

# Compile variational auto-encoder model
vae.compile(optimizer='rmsprop', loss=neg_variational_lower_bound)

# Train variational auto-encoder with MNIST images
vae.fit(x=x_train, 
         y=x_train,
         epochs=25,
         shuffle=True,
         batch_size=batch_size,
         validation_data=(x_test, x_test), verbose=2)
Train on 60000 samples, validate on 10000 samples
Epoch 1/25
 - 20s - loss: 195.2594 - val_loss: 174.2528
Epoch 2/25
 - 19s - loss: 171.3044 - val_loss: 166.7014
Epoch 3/25
 - 19s - loss: 164.3512 - val_loss: 161.6999
Epoch 4/25
 - 19s - loss: 160.1977 - val_loss: 161.6477
Epoch 5/25
 - 19s - loss: 157.7647 - val_loss: 156.3657
Epoch 6/25
 - 19s - loss: 156.1030 - val_loss: 155.2672
Epoch 7/25
 - 19s - loss: 154.8048 - val_loss: 155.3826
Epoch 8/25
 - 19s - loss: 153.7678 - val_loss: 152.8838
Epoch 9/25
 - 19s - loss: 152.8601 - val_loss: 153.5024
Epoch 10/25
 - 19s - loss: 152.0742 - val_loss: 152.3024
Epoch 11/25
 - 19s - loss: 151.3636 - val_loss: 152.1116
Epoch 12/25
 - 19s - loss: 150.7016 - val_loss: 152.2413
Epoch 13/25
 - 19s - loss: 150.1563 - val_loss: 151.3501
Epoch 14/25
 - 19s - loss: 149.6736 - val_loss: 149.6158
Epoch 15/25
 - 19s - loss: 149.1454 - val_loss: 149.2047
Epoch 16/25
 - 19s - loss: 148.7548 - val_loss: 149.0340
Epoch 17/25
 - 19s - loss: 148.3118 - val_loss: 148.8142
Epoch 18/25
 - 19s - loss: 147.9456 - val_loss: 149.1827
Epoch 19/25
 - 19s - loss: 147.6405 - val_loss: 151.5525
Epoch 20/25
 - 19s - loss: 147.2038 - val_loss: 148.2406
Epoch 21/25
 - 19s - loss: 146.8990 - val_loss: 150.3239
Epoch 22/25
 - 19s - loss: 146.6422 - val_loss: 147.5277
Epoch 23/25
 - 18s - loss: 146.3777 - val_loss: 148.2235
Epoch 24/25
 - 19s - loss: 146.0520 - val_loss: 147.7608
Epoch 25/25
 - 19s - loss: 145.8438 - val_loss: 147.9813
Out[6]:
<keras.callbacks.History at 0x7fbaf4e40da0>
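
Since the latent space is two-dimensional, the trained decoder can also be used as a generative model: decode a regular grid of latent vectors and tile the resulting digit probabilities into a single image. A sketch along these lines (grid size and latent range are arbitrary choices, and it assumes the training above has completed):

# Sketch: decode a grid of 2-D latent vectors into digit images.
n = 15                      # 15 x 15 grid of generated digits
digit_size = 28
grid_x = np.linspace(-3, 3, n)
grid_y = np.linspace(-3, 3, n)

figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        t_sample = np.array([[xi, yi]])
        digit = decoder.predict(t_sample)[0].reshape(digit_size, digit_size)
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.axis('off')
plt.show()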

References

[1] François Chollet. Deep Learning with Python. Manning Publications, 2017.