This notebook is based on

NOTE: when using adam(and nadam) for optimizer, loss tends to nan

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Input, Dense, Lambda, Reshape, Activation
from keras.activations import softmax
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
Using Theano backend.
Using gpu device 0: Tesla K40m (CNMeM is enabled with initial size: 95.0% of memory, cuDNN 5105)
/home/s1621631/.pyenv/versions/anaconda3-4.0.0/lib/python3.5/site-packages/theano/sandbox/cuda/__init__.py:600: UserWarning: Your cuDNN version is more recent than the one Theano officially supports. If you see any problems, try updating Theano or downgrading cuDNN to version 5.
  warnings.warn(warn)
In [2]:
from IPython.display import SVG
from keras.utils.visualize_util import model_to_dot, plot
from matplotlib import pyplot as plt
import seaborn as sns
%matplotlib inline
In [3]:
batch_size = 100
original_dim = 784
latent_dim = 10
categorical_dim = 20 # one-of-K vector
intermediate_dim = 392
nb_epoch = 30
inv_tau = 20
In [4]:
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
alpha = Dense(latent_dim*categorical_dim)(h)
reshape_alpha = Reshape((latent_dim, categorical_dim, ))(alpha)
pi = Activation('softmax')(reshape_alpha)
In [5]:
def gumbel_softmax(pi):
    gumbel_softmax_arg = (K.log(pi+K.epsilon()) 
                          - K.log(-K.log(K.random_uniform_variable((latent_dim, categorical_dim), 0., 1.))))*inv_tau
    y = softmax(K.reshape(gumbel_softmax_arg, (-1, latent_dim, categorical_dim)))
    return K.reshape(y, (-1, latent_dim*categorical_dim))
In [6]:
z = Lambda(gumbel_softmax, output_shape=(latent_dim*categorical_dim,))(pi)

decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

vae = Model(input=x, output=x_decoded_mean)
SVG(model_to_dot(vae, show_shapes=True).create(prog='dot', format='svg'))
Out[6]:
G 140717682314880 input_1: InputLayer input: output: (None, 784) (None, 784) 140717682314936 dense_1: Dense input: output: (None, 784) (None, 392) 140717682314880->140717682314936 140717682344232 dense_2: Dense input: output: (None, 392) (None, 200) 140717682314936->140717682344232 140717676960456 reshape_1: Reshape input: output: (None, 200) (None, 10, 20) 140717682344232->140717676960456 140717682346472 activation_1: Activation input: output: (None, 10, 20) (None, 10, 20) 140717676960456->140717682346472 140718312806048 lambda_1: Lambda input: output: (None, 10, 20) (None, 200) 140717682346472->140718312806048 140717676733888 dense_3: Dense input: output: (None, 200) (None, 392) 140718312806048->140717676733888 140717676286136 dense_4: Dense input: output: (None, 392) (None, 784) 140717676733888->140717676286136
In [7]:
def vae_loss(x, x_decoded_mean):
    xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = K.sum(pi*K.log(categorical_dim*pi), axis=(-1, -2))
    return xent_loss - kl_loss
In [8]:
vae.compile(optimizer='adadelta', loss=vae_loss)
In [9]:
# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
In [10]:
vae.fit(x_train, x_train,
        shuffle=True,
        nb_epoch=nb_epoch,
        batch_size=batch_size,
        validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 1s - loss: 161.4843 - val_loss: 135.5699
Epoch 2/30
60000/60000 [==============================] - 1s - loss: 133.3479 - val_loss: 131.0892
Epoch 3/30
60000/60000 [==============================] - 1s - loss: 130.5253 - val_loss: 129.3895
Epoch 4/30
60000/60000 [==============================] - 1s - loss: 128.9574 - val_loss: 128.1032
Epoch 5/30
60000/60000 [==============================] - 1s - loss: 127.9940 - val_loss: 126.8326
Epoch 6/30
60000/60000 [==============================] - 1s - loss: 127.0901 - val_loss: 126.5246
Epoch 7/30
60000/60000 [==============================] - 1s - loss: 126.3401 - val_loss: 125.8059
Epoch 8/30
60000/60000 [==============================] - 1s - loss: 125.7193 - val_loss: 125.0025
Epoch 9/30
60000/60000 [==============================] - 1s - loss: 125.1603 - val_loss: 124.7052
Epoch 10/30
60000/60000 [==============================] - 1s - loss: 124.7295 - val_loss: 124.2207
Epoch 11/30
60000/60000 [==============================] - 1s - loss: 124.3751 - val_loss: 123.8167
Epoch 12/30
60000/60000 [==============================] - 1s - loss: 123.9529 - val_loss: 123.6260
Epoch 13/30
60000/60000 [==============================] - 1s - loss: 123.7401 - val_loss: 123.5448
Epoch 14/30
60000/60000 [==============================] - 1s - loss: 123.5216 - val_loss: 123.6845
Epoch 15/30
60000/60000 [==============================] - 1s - loss: 123.3635 - val_loss: 123.3390
Epoch 16/30
60000/60000 [==============================] - 1s - loss: 123.2087 - val_loss: 122.8693
Epoch 17/30
60000/60000 [==============================] - 1s - loss: 122.9511 - val_loss: 122.7534
Epoch 18/30
60000/60000 [==============================] - 1s - loss: 122.7091 - val_loss: 122.4586
Epoch 19/30
60000/60000 [==============================] - 1s - loss: 122.5145 - val_loss: 122.2344
Epoch 20/30
60000/60000 [==============================] - 1s - loss: 122.4052 - val_loss: 122.3109
Epoch 21/30
60000/60000 [==============================] - 1s - loss: 122.2146 - val_loss: 122.2642
Epoch 22/30
60000/60000 [==============================] - 1s - loss: 122.1971 - val_loss: 122.6890
Epoch 23/30
60000/60000 [==============================] - 1s - loss: 121.9808 - val_loss: 122.1780
Epoch 24/30
60000/60000 [==============================] - 1s - loss: 121.9976 - val_loss: 122.1861
Epoch 25/30
60000/60000 [==============================] - 1s - loss: 121.8329 - val_loss: 121.8738
Epoch 26/30
60000/60000 [==============================] - 1s - loss: 121.8908 - val_loss: 122.0048
Epoch 27/30
60000/60000 [==============================] - 1s - loss: 121.8052 - val_loss: 121.7829
Epoch 28/30
60000/60000 [==============================] - 1s - loss: 121.6288 - val_loss: 121.6008
Epoch 29/30
60000/60000 [==============================] - 1s - loss: 121.6143 - val_loss: 122.2590
Epoch 30/30
60000/60000 [==============================] - 1s - loss: 121.5304 - val_loss: 121.5872
Out[10]:
<keras.callbacks.History at 0x7ffb550ddb38>
In [11]:
# build a model to project inputs on the latent space
encoder = Model(x, pi)
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
In [12]:
plt.figure(figsize=(10, 20))

n = 10
im_index = 1

for i in range(1, n+1):
    plt.subplot(n, 2, im_index)
    plt.axis('off')
    plt.imshow(x_test[i-1].reshape((28,28)))
    im_index += 1

    plt.subplot(n, 2, im_index)
    ax = sns.heatmap(x_test_encoded[i-1])
    im_index += 1
plt.savefig('gumbel_vae_mnist_encode.pdf')
plt.show()
In [13]:
# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim*categorical_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)
In [14]:
n = 14
digit_size = 28
figure = np.zeros((digit_size*n, digit_size*n))
# random sampling
for i in range(n):
    for j in range(n):
        data = np.zeros(categorical_dim*latent_dim)
        for index, nonzero_index in enumerate(np.random.randint(0, categorical_dim, latent_dim)):
            data[nonzero_index+index*categorical_dim] += 1
                        
        z_sample = np.array([data])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.axis('off')
plt.imshow(figure)
plt.savefig('gumbel_vae_mnist_decode.pdf')
plt.show()