import numpy as np
import keras
from keras.layers import Dense, GlobalAveragePooling2D
# (notebook output) Using TensorFlow backend.
from sklearn.datasets import fetch_olivetti_faces
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('classic')
# Load the Olivetti faces dataset (grayscale 64x64 face images, one label per subject).
data_obj = fetch_olivetti_faces()
X = data_obj['data']       # flattened images, 4096 pixels per row (64*64)
labels = data_obj['target']  # subject id for each image
# Stratified 90/10 split keeps every subject represented in both sets.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, labels, stratify=labels, test_size=0.1, random_state=0)
# Standardize per pixel (zero mean, unit variance), fitting on the training set only.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)
# Sanity check: display the first (standardized) training face.
plt.imshow(X_train[0].reshape(64, 64), cmap='gray');
import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np
Variational Autoencoders (VAE) are a class of generative models.
Unsupervised generative models work with data likelihood, $P(x)$.
VAEs are Bayesian models which use Variational Inference - they work by approximating
$$ P(x) = \int{P(x|z)P(z) dz} $$For some random variable $Z$ corresponding to $P(z)$. $Z$ is also called latent variable.
VAEs use a variational distribution $Q(z)$ to approximate the intractable posterior $P(z|x)$ - by choosing a simple distribution on latent space - usually Gaussian.
The whole complexity of $P(x)$ is factored out as simple distribution $Q(z)$ and decoder from latent space that is a neural network - it defines $P(x|z)$.
Recall Kullback-Leibler divergence.
$$KL(Q, P(\cdot | X)) = \int Q(z) log \frac{Q(z)}{P(z|x)} dz$$Rewriting it further $$KL(Q, P(\cdot | X)) = \int Q(z) log Q(z)dz- \int Q(z) log P(z|x) dz$$
$$ = \int Q(z) log Q(z)dz - \int Q(z) log \frac{P(x, z)}{P(x)} dz$$$$ = \int Q(z) log Q(z)dz - \int Q(z) log P(x, z)dz + \int Q(z)log P(x)dz $$$$ = \int Q(z) (log Q(z) - log P(x, z))dz + log P(x)$$So
$$ \log P(x) = KL(Q, P(\cdot | X)) - \int Q(z) \log Q(z)dz + \int Q(z) \log P(x, z)dz $$Since the KL divergence is always nonnegative, dropping it gives
$$\log P(x) \geq \int Q(z) \log P(x, z)dz - \int Q(z) \log Q(z)dz $$The right-hand side is called the Variational Lower Bound (ELBO).
VAEs work by optimizing right hand of this inequality, thus approximately modelling $P(x)$.
def encode(input_img, latent_dim):
    """Convolutional encoder: map an image tensor to a flat code vector.

    Parameters
    ----------
    input_img : Keras tensor
        Input image batch, shape (batch, H, W, C).
    latent_dim : int
        Size of the produced code vector.

    Returns
    -------
    (code, shape_before_flattening)
        `code` is a (batch, latent_dim) tensor; `shape_before_flattening`
        is the conv feature-map shape the decoder needs to undo Flatten.
    """
    x = layers.Conv2D(32, 3,
                      padding='same', activation='relu')(input_img)
    # The single stride-2 conv halves spatial resolution (e.g. 64x64 -> 32x32).
    x = layers.Conv2D(32, 3,
                      padding='same', activation='relu',
                      strides=(2, 2))(x)
    x = layers.Conv2D(32, 3,
                      padding='same', activation='relu')(x)
    x = layers.Conv2D(16, 3,
                      padding='same', activation='relu')(x)
    shape_before_flattening = K.int_shape(x)
    x = layers.Flatten()(x)
    # BUG FIX: the original hard-coded Dense(16), silently ignoring the
    # `latent_dim` argument; use the parameter so callers control code size.
    code = layers.Dense(latent_dim, activation='relu')(x)
    return code, shape_before_flattening
def make_hidden_variable(latent_variable_input):
    """Reparameterization layers: produce a differentiable sample z ~ N(mu, sigma^2).

    NOTE(review): relies on the module-level global `latent_dim`
    (late-bound; defined further down the notebook before the first call).

    Returns
    -------
    (sampled_hidden, z_mean, z_log_var)
        The sampled latent tensor plus the mean and log-variance tensors
        needed by the KL term of the loss.
    """
    def sampling(args):
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                                  mean=0., stddev=1.)
        # BUG FIX: z_log_var is the log *variance*, so the standard deviation
        # is exp(0.5 * z_log_var). The original used exp(z_log_var), which is
        # inconsistent with vae_loss, where K.exp(z_log_var) is the variance.
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    z_mean = layers.Dense(latent_dim)(latent_variable_input)
    z_log_var = layers.Dense(latent_dim)(latent_variable_input)
    sampled_hidden = layers.Lambda(sampling)([z_mean, z_log_var])
    return sampled_hidden, z_mean, z_log_var
def make_decoder(z, shape_before_flattening):
    """Build a standalone decoder Model mapping latent vectors to images.

    `z` is used only to read the latent dimensionality; the returned Model
    has its own Input layer and can be called independently for generation.
    """
    code_in = layers.Input(K.int_shape(z)[1:])
    feature_shape = shape_before_flattening[1:]
    # Project the code back up to the flattened conv-feature size...
    h = layers.Dense(np.prod(feature_shape),
                     activation='relu')(code_in)
    # ...and restore the spatial feature-map layout.
    h = layers.Reshape(feature_shape)(h)
    # Transposed conv doubles resolution, undoing the encoder's stride-2 conv.
    h = layers.Conv2DTranspose(16, 3,
                               padding='same',
                               activation='relu',
                               strides=(2, 2))(h)
    # Final single-channel sigmoid map: the reconstructed image.
    out = layers.Conv2D(1, 3,
                        padding='same',
                        activation='sigmoid')(h)
    return Model(code_in, out)
def vae_loss(x, x_reconstructed, z_mean, z_log_var, kl_beta, reconstruction_loss):
    """Total VAE loss: reconstruction term plus beta-weighted KL term.

    Parameters
    ----------
    x, x_reconstructed : Keras tensors
        Original and reconstructed image batches (same shape).
    z_mean, z_log_var : Keras tensors
        Latent Gaussian parameters, shape (batch, latent_dim).
    kl_beta : float
        Weight of the KL regularizer (the 1/2 factor is folded in here).
    reconstruction_loss : callable
        Elementwise Keras metric, e.g. mean_squared_error.
    """
    original_dim = np.prod(K.int_shape(x)[1:])
    x = K.flatten(x)
    x_reconstructed = K.flatten(x_reconstructed)
    reconstruction_loss = reconstruction_loss(x, x_reconstructed)
    # KL(N(mu, sigma^2) || N(0, 1)) = -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2).
    # BUG FIX: the original summed `latent_dim + z_log_var - ...`, adding the
    # constant latent_dim (32) per dimension instead of 1, grossly mis-scaling
    # the KL term (and silently depending on the global `latent_dim`).
    kl_loss = - kl_beta / original_dim * K.sum(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(reconstruction_loss + kl_loss)
def make_vae(input_shape, latent_dim, kl_beta=0.5, reconstruction_loss=keras.metrics.mean_squared_error):
    """Assemble the full VAE and its standalone decoder.

    Parameters
    ----------
    input_shape : tuple
        Image shape, e.g. (64, 64, 1).
    latent_dim : int
        Dimensionality of the latent space.
    kl_beta : float
        Weight of the KL term passed to vae_loss.
    reconstruction_loss : callable
        Elementwise Keras metric used for the reconstruction term.

    Returns
    -------
    (vae, decoder) : the end-to-end Model (with its loss attached via
        add_loss, so it compiles with loss=None) and the decoder Model.
    """
    input_img = keras.Input(shape=input_shape)
    code, shape_before_flattening = encode(input_img, latent_dim)
    sampled_hidden, z_mean, z_log_var = make_hidden_variable(code)
    decoder = make_decoder(sampled_hidden, shape_before_flattening)
    x_reconstructed = decoder(sampled_hidden)
    vae = Model(input_img, x_reconstructed)
    # Attach the composite loss directly to the graph; no targets are
    # needed at fit() time.
    vae.add_loss(vae_loss(input_img, x_reconstructed, z_mean, z_log_var, kl_beta=kl_beta, reconstruction_loss=reconstruction_loss))
    return vae, decoder
# Model / training configuration.
img_shape = (64, 64, 1)
batch_size = 16
latent_dim = 32
vae, decoder = make_vae(img_shape, latent_dim)
# BUG FIX: the original passed beta_1=0.001, which all but disables Adam's
# first-moment averaging (default beta_1=0.9). The intended argument was
# almost certainly the learning rate.
optimizer = keras.optimizers.Adam(lr=0.001)
# The loss was attached via add_loss inside make_vae, hence loss=None here.
vae.compile(optimizer=optimizer, loss=None)
vae.summary()
# Reshape flat 4096-pixel rows into (64, 64, 1) images for the conv net.
X_train = X_train.reshape(-1, *img_shape)
X_test = X_test.reshape(-1, *img_shape)
__________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_11 (InputLayer) (None, 64, 64, 1) 0 __________________________________________________________________________________________________ conv2d_26 (Conv2D) (None, 64, 64, 32) 320 input_11[0][0] __________________________________________________________________________________________________ conv2d_27 (Conv2D) (None, 32, 32, 32) 9248 conv2d_26[0][0] __________________________________________________________________________________________________ conv2d_28 (Conv2D) (None, 32, 32, 32) 9248 conv2d_27[0][0] __________________________________________________________________________________________________ conv2d_29 (Conv2D) (None, 32, 32, 16) 4624 conv2d_28[0][0] __________________________________________________________________________________________________ flatten_6 (Flatten) (None, 16384) 0 conv2d_29[0][0] __________________________________________________________________________________________________ dense_21 (Dense) (None, 16) 262160 flatten_6[0][0] __________________________________________________________________________________________________ dense_22 (Dense) (None, 32) 544 dense_21[0][0] __________________________________________________________________________________________________ dense_23 (Dense) (None, 32) 544 dense_21[0][0] __________________________________________________________________________________________________ lambda_6 (Lambda) (None, 32) 0 dense_22[0][0] dense_23[0][0] __________________________________________________________________________________________________ model_9 (Model) (None, 64, 64, 1) 543137 lambda_6[0][0] ================================================================================================== Total params: 829,825 Trainable params: 829,825 Non-trainable params: 0 
__________________________________________________________________________________________________
# Stop once validation loss has failed to improve for 5 consecutive epochs.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=5)
# y=None / validation targets None: the VAE loss was attached with
# add_loss, so fit() needs no labels.
history = vae.fit(x=X_train, y=None,
                  shuffle=True,
                  epochs=100,
                  batch_size=8,
                  validation_data=(X_test, None),
                  callbacks=[early_stopping])
Train on 360 samples, validate on 40 samples Epoch 1/100 360/360 [==============================] - 2s 4ms/step - loss: 13.1456 - val_loss: 0.7914 Epoch 2/100 360/360 [==============================] - 1s 2ms/step - loss: 0.8095 - val_loss: 0.7528 Epoch 3/100 360/360 [==============================] - 1s 2ms/step - loss: 0.7778 - val_loss: 0.7371 Epoch 4/100 360/360 [==============================] - 1s 2ms/step - loss: 0.7613 - val_loss: 0.7233 Epoch 5/100 360/360 [==============================] - 1s 2ms/step - loss: 0.7551 - val_loss: 0.7100 Epoch 6/100 360/360 [==============================] - 1s 2ms/step - loss: 0.7352 - val_loss: 0.6946 Epoch 7/100 360/360 [==============================] - 1s 2ms/step - loss: 0.7218 - val_loss: 0.6816 Epoch 8/100 360/360 [==============================] - 1s 2ms/step - loss: 0.7115 - val_loss: 0.6683 Epoch 9/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6999 - val_loss: 0.6676 Epoch 10/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6977 - val_loss: 0.6593 Epoch 11/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6880 - val_loss: 0.6531 Epoch 12/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6849 - val_loss: 0.6564 Epoch 13/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6758 - val_loss: 0.6520 Epoch 14/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6753 - val_loss: 0.6456 Epoch 15/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6704 - val_loss: 0.6531 Epoch 16/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6612 - val_loss: 0.6373 Epoch 17/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6568 - val_loss: 0.6290 Epoch 18/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6563 - val_loss: 0.6381 Epoch 19/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6476 - val_loss: 0.6301 Epoch 20/100 360/360 
[==============================] - 1s 2ms/step - loss: 0.6433 - val_loss: 0.6261 Epoch 21/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6389 - val_loss: 0.6211 Epoch 22/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6360 - val_loss: 0.6196 Epoch 23/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6325 - val_loss: 0.6170 Epoch 24/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6287 - val_loss: 0.6179 Epoch 25/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6266 - val_loss: 0.6152 Epoch 26/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6246 - val_loss: 0.6158 Epoch 27/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6210 - val_loss: 0.6117 Epoch 28/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6201 - val_loss: 0.6154 Epoch 29/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6165 - val_loss: 0.6098 Epoch 30/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6180 - val_loss: 0.6214 Epoch 31/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6154 - val_loss: 0.6086 Epoch 32/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6116 - val_loss: 0.6163 Epoch 33/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6082 - val_loss: 0.6090 Epoch 34/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6072 - val_loss: 0.6071 Epoch 35/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6083 - val_loss: 0.6137 Epoch 36/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6064 - val_loss: 0.6092 Epoch 37/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6038 - val_loss: 0.6037 Epoch 38/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6003 - val_loss: 0.6022 Epoch 39/100 360/360 [==============================] - 1s 2ms/step - loss: 0.6018 
- val_loss: 0.6077 Epoch 40/100 360/360 [==============================] - 1s 2ms/step - loss: 0.5985 - val_loss: 0.6081 Epoch 41/100 360/360 [==============================] - 1s 2ms/step - loss: 0.5982 - val_loss: 0.6057 Epoch 42/100 360/360 [==============================] - 1s 2ms/step - loss: 0.5957 - val_loss: 0.6094 Epoch 43/100 360/360 [==============================] - 1s 2ms/step - loss: 0.5989 - val_loss: 0.6027
# Plot the training curves; epoch 0 is skipped because its large initial
# loss would flatten the rest of the curve.
plt.plot(history.history['loss'][1:], label='loss')
plt.plot(history.history['val_loss'][1:], label='val_loss')
plt.legend()
plt.show()
# Generate a new face: sample a latent vector from the standard normal
# prior, decode it, and undo the pixel standardization before display.
v = np.random.randn(1,latent_dim)
random_image = decoder.predict(v)
plt.imshow(scaler.inverse_transform(random_image[:1, :, :, 0].reshape(1, 4096))[0].reshape((64, 64)), cmap='gray')
<matplotlib.image.AxesImage at 0x7f927407bbe0>