This notebook implements a categorical VAE on MNIST in Keras, using the Gumbel-Softmax reparameterization trick.
NOTE: with the Adam (and Nadam) optimizers the loss tends to go to NaN, so Adadelta is used below.
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Reshape, Activation
from keras.activations import softmax
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
Using Theano backend.
from IPython.display import SVG
from keras.utils.visualize_util import model_to_dot
import seaborn as sns
%matplotlib inline
batch_size = 100
original_dim = 784       # 28 x 28 MNIST pixels
latent_dim = 10          # number of categorical latent variables
categorical_dim = 20     # classes per variable (one-of-K vector)
intermediate_dim = 392
nb_epoch = 30
inv_tau = 20             # inverse temperature: tau = 1 / inv_tau = 0.05
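The temperature is held fixed here (tau = 1/inv_tau = 0.05). The Gumbel-Softmax paper instead anneals it from a high value toward a small one during training; below is a minimal sketch of how that could be wired up with a backend variable and a callback (tau, initial_tau, min_tau, anneal_rate and TemperatureAnnealer are illustrative names, not part of this notebook):

from keras import backend as K
from keras.callbacks import Callback

initial_tau, min_tau, anneal_rate = 1.0, 0.1, 0.1
tau = K.variable(initial_tau, name='temperature')  # gumbel_softmax would divide by this

class TemperatureAnnealer(Callback):
    def on_epoch_end(self, epoch, logs={}):
        # exponentially decay the temperature, clipped from below
        K.set_value(tau, max(initial_tau * np.exp(-anneal_rate * epoch), min_tau))

# usage: vae.fit(..., callbacks=[TemperatureAnnealer()])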
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
# unnormalized logits for latent_dim independent categorical distributions
alpha = Dense(latent_dim * categorical_dim)(h)
reshape_alpha = Reshape((latent_dim, categorical_dim))(alpha)
# softmax over the last axis gives the class probabilities pi
pi = Activation('softmax')(reshape_alpha)
def gumbel_softmax(pi):
    # Gumbel(0, 1) noise g = -log(-log(u)), u ~ Uniform(0, 1), resampled on
    # every forward pass (K.random_uniform_variable would freeze the noise
    # at graph-construction time)
    u = K.random_uniform(K.shape(pi), 0., 1.)
    gumbel_softmax_arg = (K.log(pi + K.epsilon())
                          - K.log(-K.log(u + K.epsilon()))) * inv_tau
    y = softmax(K.reshape(gumbel_softmax_arg, (-1, latent_dim, categorical_dim)))
    return K.reshape(y, (-1, latent_dim * categorical_dim))
z = Lambda(gumbel_softmax, output_shape=(latent_dim*categorical_dim,))(pi)
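As a quick NumPy sanity check (hypothetical, independent of the model): the underlying Gumbel-max trick argmax(log(pi) + g), with g = -log(-log(u)) and u ~ Uniform(0, 1), samples exactly from Categorical(pi), and the softmax relaxation above collapses toward a one-hot vector at this inverse temperature:

rng = np.random.RandomState(0)
p = np.array([0.1, 0.2, 0.7])
g = -np.log(-np.log(rng.uniform(size=(100000, 3))))  # Gumbel(0, 1) noise
print(np.bincount(np.argmax(np.log(p) + g, axis=1)) / 100000.)  # approx. [0.1, 0.2, 0.7]
y = np.exp((np.log(p) + g[0]) * inv_tau)  # same inverse temperature as the model
print(y / y.sum())                        # nearly one-hot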
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)
vae = Model(input=x, output=x_decoded_mean)
SVG(model_to_dot(vae, show_shapes=True).create(prog='dot', format='svg'))
def vae_loss(x, x_decoded_mean):
    xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
    # KL(q(z|x) || uniform prior) = sum_i sum_k pi * log(categorical_dim * pi);
    # the epsilon guards against log(0) -> NaN (see the optimizer note above)
    kl_loss = K.sum(pi * K.log(categorical_dim * pi + K.epsilon()), axis=(-1, -2))
    # minimize -ELBO = reconstruction loss + KL
    return xent_loss + kl_loss
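For reference, kl_loss is KL(q(z|x) || Uniform(categorical_dim)) summed over the latent variables: it is 0 when q is uniform and approaches its maximum log(categorical_dim) ≈ 3.0 as q becomes one-hot. A small NumPy illustration (not part of the model):

q = np.full(categorical_dim, 1. / categorical_dim)  # uniform posterior
print(np.sum(q * np.log(categorical_dim * q)))      # 0.0
q = np.full(categorical_dim, 1e-9)
q[0] = 1. - 1e-9 * (categorical_dim - 1)            # near one-hot posterior
print(np.sum(q * np.log(categorical_dim * q)))      # approx. log(20) = 3.0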
vae.compile(optimizer='adadelta', loss=vae_loss)
# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
vae.fit(x_train, x_train,
        shuffle=True,
        nb_epoch=nb_epoch,
        batch_size=batch_size,
        validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 1s - loss: 161.4843 - val_loss: 135.5699
Epoch 2/30
60000/60000 [==============================] - 1s - loss: 133.3479 - val_loss: 131.0892
Epoch 3/30
60000/60000 [==============================] - 1s - loss: 130.5253 - val_loss: 129.3895
...
Epoch 29/30
60000/60000 [==============================] - 1s - loss: 121.6143 - val_loss: 122.2590
Epoch 30/30
60000/60000 [==============================] - 1s - loss: 121.5304 - val_loss: 121.5872
# build a model to project inputs on the latent space
encoder = Model(x, pi)
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
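pi has shape (samples, latent_dim, categorical_dim), so each x_test_encoded[i] is a 10x20 probability matrix; that is what the heatmaps below visualize. To read off hard discrete codes instead, an argmax over the categorical axis suffices (illustrative, not used below):

codes = x_test_encoded.argmax(axis=-1)  # shape (10000, latent_dim)
print(codes[0])                         # ten integers in [0, categorical_dim)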
plt.figure(figsize=(10, 20))
n = 10
im_index = 1
for i in range(1, n + 1):
    # left column: the input digit
    plt.subplot(n, 2, im_index)
    plt.axis('off')
    plt.imshow(x_test[i - 1].reshape((28, 28)))
    im_index += 1
    # right column: its latent_dim x categorical_dim posterior probabilities
    plt.subplot(n, 2, im_index)
    ax = sns.heatmap(x_test_encoded[i - 1])
    im_index += 1
plt.savefig('gumbel_vae_mnist_encode.pdf')
plt.show()
# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim*categorical_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)
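The nested loop below assembles each random latent code entry by entry; an equivalent, more compact way to draw latent_dim stacked one-hot vectors is np.eye (a sketch, not used below):

idx = np.random.randint(0, categorical_dim, latent_dim)
data = np.eye(categorical_dim)[idx].flatten()  # one-hot block per latent variable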
n = 14
digit_size = 28
figure = np.zeros((digit_size*n, digit_size*n))
# random sampling
for i in range(n):
    for j in range(n):
        # draw one random one-hot vector per latent variable
        data = np.zeros(categorical_dim * latent_dim)
        for index, nonzero_index in enumerate(np.random.randint(0, categorical_dim, latent_dim)):
            data[nonzero_index + index * categorical_dim] = 1
        z_sample = np.array([data])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.axis('off')
plt.imshow(figure)
plt.savefig('gumbel_vae_mnist_decode.pdf')
plt.show()