# use Matplotlib (don't ask)
import matplotlib.pyplot as plt
import daft
import seaborn as sns
from matplotlib import rc
%matplotlib inline
Download the MNIST dataset from the Keras site.
from keras.datasets import mnist
import numpy as np

# Load MNIST, scale pixel values to [0, 1], and flatten each
# 28x28 image into a 784-dimensional row vector.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
# Fixed: Python 2 print statements -> Python 3 print() calls.
print(x_train.shape)
print(x_test.shape)
Using TensorFlow backend.
(60000, 784) (10000, 784)
from keras.layers import Input, Dense, Lambda, concatenate
from keras.models import Model
from keras import metrics
from keras import backend as K
Train a VAE on labeled (supervised) data, and reconstruct images of a specified category from the latent-variable distribution.
from keras.utils import to_categorical
# One-hot encode the labels; they are the conditioning input of the CVAE.
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
n_x = x_train.shape[1]  # 784 input pixels
n_y = y_train.shape[1]  # 10 label classes
# Hyperparameters.
batch_size = 100
original_dim = 28*28      # decoder output width; same value as n_x
latent_dim = 20           # size of the latent code z
intermediate_dim = 512    # hidden-layer width for encoder and decoder
epochs = 150
categorical_dim = 10      # number of digit classes
epsilon_std = 1.0         # std of the reparameterization noise
# --- Encoder: map (image, label) to the latent Gaussian parameters ---
x = Input(shape=(n_x,))      # flattened 28x28 image
label = Input(shape=(n_y,))  # one-hot class label
# Fixed: renamed `input` -> `xy` so the builtin input() is not shadowed.
xy = concatenate([x, label], axis=-1)
h = Dense(intermediate_dim, activation='relu', activity_regularizer='l2')(xy)
z_mean = Dense(latent_dim)(h)     # mean of q(z|x, y)
z_log_var = Dense(latent_dim)(h)  # log-variance of q(z|x, y)
def normal_sampling(args):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)).

    Args:
        args: pair of tensors (z_mean, z_log_var), each (batch, latent_dim).
    Returns:
        A (batch, latent_dim) tensor of sampled latent codes.

    Generalized: uses the dynamic batch dimension K.shape(z_mean)[0]
    instead of the fixed global `batch_size`, so the model also works
    when predict() is called with a different batch size.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon
# Sample z with the reparameterization trick, then condition the decoder
# by concatenating the class label onto z.
z = Lambda(normal_sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
zc = concatenate([z, label] , axis=-1)
# Decoder layers are instantiated once so their weights are shared between
# the training model (cvae) and the stand-alone generator below.
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(zc)
x_decoded_mean = decoder_mean(h_decoded)
# end-to-end autoencoder
cvae = Model([x, label], x_decoded_mean)
# encoder, from inputs to latent space
encoder = Model([x, label], z_mean)
# generator, from latent space to reconstructed inputs
decoder_input = Input(shape=(latent_dim+categorical_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)
# NOTE(review): this re-defines normal_sampling identically to the earlier
# cell (a notebook artifact). It is harmless but redundant; kept so the
# cell still runs stand-alone.
def normal_sampling(args):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)).

    Generalized: uses the dynamic batch dimension K.shape(z_mean)[0]
    instead of the fixed global `batch_size`, so sampling works for
    any batch size at predict time.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon
def cvae_loss(x, x_decoded_mean):
    """CVAE loss: pixel-wise binary cross-entropy reconstruction term
    (scaled by the 784 output dimensions) plus the KL divergence of
    q(z|x, y) from the standard normal prior.

    NOTE(review): relies on the module-level tensors z_mean / z_log_var
    captured from the encoder graph defined above.
    """
    recon_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(recon_loss + kl_loss )
# Compile and train end-to-end; x_train serves as both the model input
# and the reconstruction target.
cvae.compile(optimizer='rmsprop', loss=cvae_loss)
history = cvae.fit([x_train, y_train], x_train,
                   shuffle=True,
                   epochs=epochs,
                   batch_size=batch_size,
                   validation_data=([x_test, y_test], x_test))
Train on 60000 samples, validate on 10000 samples Epoch 1/150 60000/60000 [==============================] - 24s - loss: 165.8565 - val_loss: 136.2912 Epoch 2/150 60000/60000 [==============================] - 24s - loss: 131.1071 - val_loss: 127.8767 Epoch 3/150 60000/60000 [==============================] - 24s - loss: 122.3535 - val_loss: 118.4269 Epoch 4/150 60000/60000 [==============================] - 24s - loss: 117.8729 - val_loss: 115.6204 Epoch 5/150 60000/60000 [==============================] - 25s - loss: 115.1012 - val_loss: 113.2396 Epoch 6/150 60000/60000 [==============================] - 24s - loss: 113.2657 - val_loss: 111.4246 Epoch 7/150 60000/60000 [==============================] - 24s - loss: 112.0224 - val_loss: 110.4420 Epoch 8/150 60000/60000 [==============================] - 24s - loss: 111.0929 - val_loss: 109.7930 Epoch 9/150 60000/60000 [==============================] - 24s - loss: 110.3146 - val_loss: 109.3293 Epoch 10/150 60000/60000 [==============================] - 24s - loss: 109.7219 - val_loss: 110.9341 Epoch 11/150 60000/60000 [==============================] - 24s - loss: 109.2341 - val_loss: 108.1405 Epoch 12/150 60000/60000 [==============================] - 24s - loss: 108.7829 - val_loss: 110.1502 Epoch 13/150 60000/60000 [==============================] - 24s - loss: 108.3712 - val_loss: 108.1395 Epoch 14/150 60000/60000 [==============================] - 24s - loss: 108.0780 - val_loss: 107.5192 Epoch 15/150 60000/60000 [==============================] - 25s - loss: 107.7901 - val_loss: 107.9414 Epoch 16/150 60000/60000 [==============================] - 24s - loss: 107.4975 - val_loss: 107.7586 Epoch 17/150 60000/60000 [==============================] - 24s - loss: 107.2607 - val_loss: 107.4573 Epoch 18/150 60000/60000 [==============================] - 25s - loss: 107.0163 - val_loss: 106.9288 Epoch 19/150 60000/60000 [==============================] - 24s - loss: 106.8545 - val_loss: 106.8551 Epoch 20/150 
60000/60000 [==============================] - 25s - loss: 106.6248 - val_loss: 107.1314 Epoch 21/150 60000/60000 [==============================] - 25s - loss: 106.4819 - val_loss: 106.2840 Epoch 22/150 60000/60000 [==============================] - 24s - loss: 106.3219 - val_loss: 106.1901 Epoch 23/150 60000/60000 [==============================] - 25s - loss: 106.1955 - val_loss: 106.9743 Epoch 24/150 60000/60000 [==============================] - 25s - loss: 106.0411 - val_loss: 106.3594 Epoch 25/150 60000/60000 [==============================] - 25s - loss: 105.9730 - val_loss: 106.3271 Epoch 26/150 60000/60000 [==============================] - 25s - loss: 105.8230 - val_loss: 105.8715 Epoch 27/150 60000/60000 [==============================] - 25s - loss: 105.7171 - val_loss: 106.9231 Epoch 28/150 60000/60000 [==============================] - 25s - loss: 105.6226 - val_loss: 105.6183 Epoch 29/150 60000/60000 [==============================] - 25s - loss: 105.5065 - val_loss: 106.4121 Epoch 30/150 60000/60000 [==============================] - 25s - loss: 105.4069 - val_loss: 106.3822 Epoch 31/150 60000/60000 [==============================] - 25s - loss: 105.3469 - val_loss: 107.0293 Epoch 32/150 60000/60000 [==============================] - 25s - loss: 105.2754 - val_loss: 105.4671 Epoch 33/150 60000/60000 [==============================] - 25s - loss: 105.2036 - val_loss: 106.9695 Epoch 34/150 60000/60000 [==============================] - 25s - loss: 105.1155 - val_loss: 106.5850 Epoch 35/150 60000/60000 [==============================] - 25s - loss: 105.0779 - val_loss: 105.9295 Epoch 36/150 60000/60000 [==============================] - 25s - loss: 105.0115 - val_loss: 105.0917 Epoch 37/150 60000/60000 [==============================] - 25s - loss: 104.9092 - val_loss: 105.2324 Epoch 38/150 60000/60000 [==============================] - 25s - loss: 104.8821 - val_loss: 106.8289 Epoch 39/150 60000/60000 [==============================] - 28s - loss: 
104.8055 - val_loss: 105.7805 Epoch 40/150 60000/60000 [==============================] - 27s - loss: 104.7281 - val_loss: 107.0707 Epoch 41/150 60000/60000 [==============================] - 27s - loss: 104.7096 - val_loss: 106.3260 Epoch 42/150 60000/60000 [==============================] - 27s - loss: 104.6703 - val_loss: 105.2415 Epoch 43/150 60000/60000 [==============================] - 27s - loss: 104.6576 - val_loss: 106.2871 Epoch 44/150 60000/60000 [==============================] - 27s - loss: 104.5640 - val_loss: 105.1583 Epoch 45/150 60000/60000 [==============================] - 27s - loss: 104.5195 - val_loss: 104.9134 Epoch 46/150 60000/60000 [==============================] - 27s - loss: 104.4651 - val_loss: 104.9704 Epoch 47/150 60000/60000 [==============================] - 27s - loss: 104.4870 - val_loss: 105.6294 Epoch 48/150 60000/60000 [==============================] - 27s - loss: 104.4287 - val_loss: 104.9503 Epoch 49/150 60000/60000 [==============================] - 27s - loss: 104.3906 - val_loss: 105.0339 Epoch 50/150 60000/60000 [==============================] - 27s - loss: 104.3558 - val_loss: 104.7701 Epoch 51/150 60000/60000 [==============================] - 27s - loss: 104.2872 - val_loss: 104.9319 Epoch 52/150 60000/60000 [==============================] - 27s - loss: 104.2768 - val_loss: 104.8593 Epoch 53/150 60000/60000 [==============================] - 28s - loss: 104.2370 - val_loss: 105.0190 Epoch 54/150 60000/60000 [==============================] - 28s - loss: 104.2514 - val_loss: 104.6350 Epoch 55/150 60000/60000 [==============================] - 28s - loss: 104.2048 - val_loss: 104.5828 Epoch 56/150 60000/60000 [==============================] - 27s - loss: 104.1685 - val_loss: 106.3544 Epoch 57/150 60000/60000 [==============================] - 28s - loss: 104.1285 - val_loss: 105.2946 Epoch 58/150 60000/60000 [==============================] - 28s - loss: 104.1076 - val_loss: 104.7409 Epoch 59/150 60000/60000 
[==============================] - 28s - loss: 104.0862 - val_loss: 105.1635 Epoch 60/150 60000/60000 [==============================] - 28s - loss: 104.0981 - val_loss: 105.8210 Epoch 61/150 60000/60000 [==============================] - 26s - loss: 104.0834 - val_loss: 106.4702 Epoch 62/150 60000/60000 [==============================] - 29s - loss: 104.0361 - val_loss: 104.6543 Epoch 63/150 60000/60000 [==============================] - 30s - loss: 104.0235 - val_loss: 105.1734 Epoch 64/150 60000/60000 [==============================] - 32s - loss: 104.0113 - val_loss: 104.8485 Epoch 65/150 60000/60000 [==============================] - 28s - loss: 103.9925 - val_loss: 105.0373 Epoch 66/150 60000/60000 [==============================] - 28s - loss: 103.9722 - val_loss: 105.0731 Epoch 67/150 60000/60000 [==============================] - 30s - loss: 103.9224 - val_loss: 104.5775 Epoch 68/150 60000/60000 [==============================] - 27s - loss: 103.9400 - val_loss: 105.3824 Epoch 69/150 60000/60000 [==============================] - 28s - loss: 103.9063 - val_loss: 104.1367 Epoch 70/150 60000/60000 [==============================] - 28s - loss: 103.8601 - val_loss: 104.7196 Epoch 71/150 60000/60000 [==============================] - 27s - loss: 103.8630 - val_loss: 105.3893 Epoch 72/150 60000/60000 [==============================] - 30s - loss: 103.8386 - val_loss: 104.1043 Epoch 73/150 60000/60000 [==============================] - 26s - loss: 103.8655 - val_loss: 104.6404 Epoch 74/150 60000/60000 [==============================] - 26s - loss: 103.8300 - val_loss: 105.1221 Epoch 75/150 60000/60000 [==============================] - 26s - loss: 103.7700 - val_loss: 104.9232 Epoch 76/150 60000/60000 [==============================] - 27s - loss: 103.8180 - val_loss: 104.8616 Epoch 77/150 60000/60000 [==============================] - 27s - loss: 103.8065 - val_loss: 105.6343 Epoch 78/150 60000/60000 [==============================] - 26s - loss: 103.7868 - 
val_loss: 104.4001 Epoch 79/150 60000/60000 [==============================] - 26s - loss: 103.7582 - val_loss: 104.0777 Epoch 80/150 60000/60000 [==============================] - 26s - loss: 103.7350 - val_loss: 104.8156 Epoch 81/150 60000/60000 [==============================] - 27s - loss: 103.7227 - val_loss: 104.4828 Epoch 82/150 60000/60000 [==============================] - 28s - loss: 103.7349 - val_loss: 105.3242 Epoch 83/150 60000/60000 [==============================] - 28s - loss: 103.7419 - val_loss: 104.7343 Epoch 84/150 60000/60000 [==============================] - 27s - loss: 103.6938 - val_loss: 104.8098 Epoch 85/150 60000/60000 [==============================] - 33s - loss: 103.7256 - val_loss: 104.1060 Epoch 86/150 60000/60000 [==============================] - 27s - loss: 103.6841 - val_loss: 106.4758 Epoch 87/150 60000/60000 [==============================] - 27s - loss: 103.7247 - val_loss: 104.0544 Epoch 88/150 60000/60000 [==============================] - 27s - loss: 103.6802 - val_loss: 104.6736 Epoch 89/150 60000/60000 [==============================] - 27s - loss: 103.6713 - val_loss: 104.3651 Epoch 90/150 60000/60000 [==============================] - 28s - loss: 103.6658 - val_loss: 106.5210 Epoch 91/150 60000/60000 [==============================] - 27s - loss: 103.6660 - val_loss: 104.9781 Epoch 92/150 60000/60000 [==============================] - 29s - loss: 103.6225 - val_loss: 104.5751 Epoch 93/150 60000/60000 [==============================] - 27s - loss: 103.6092 - val_loss: 104.8122 Epoch 94/150 60000/60000 [==============================] - 28s - loss: 103.5607 - val_loss: 104.4559 Epoch 95/150 60000/60000 [==============================] - 27s - loss: 103.6147 - val_loss: 104.2416 Epoch 96/150 60000/60000 [==============================] - 28s - loss: 103.5956 - val_loss: 105.1973 Epoch 97/150 60000/60000 [==============================] - 28s - loss: 103.5939 - val_loss: 104.3007 Epoch 98/150 60000/60000 
[==============================] - 28s - loss: 103.5896 - val_loss: 104.9037 Epoch 99/150 60000/60000 [==============================] - 28s - loss: 103.5909 - val_loss: 104.2742 Epoch 100/150 60000/60000 [==============================] - 30s - loss: 103.5558 - val_loss: 104.3250 Epoch 101/150 60000/60000 [==============================] - 28s - loss: 103.5829 - val_loss: 104.7833 Epoch 102/150 60000/60000 [==============================] - 28s - loss: 103.5245 - val_loss: 104.8388 Epoch 103/150 60000/60000 [==============================] - 29s - loss: 103.5307 - val_loss: 104.4639 Epoch 104/150 60000/60000 [==============================] - 27s - loss: 103.5586 - val_loss: 104.0766 Epoch 105/150 60000/60000 [==============================] - 27s - loss: 103.5152 - val_loss: 104.7443 Epoch 106/150 60000/60000 [==============================] - 27s - loss: 103.5671 - val_loss: 104.7494 Epoch 107/150 60000/60000 [==============================] - 27s - loss: 103.5247 - val_loss: 104.4805 Epoch 108/150 60000/60000 [==============================] - 28s - loss: 103.4976 - val_loss: 105.2901 Epoch 109/150 60000/60000 [==============================] - 29s - loss: 103.4886 - val_loss: 104.1706 Epoch 110/150 60000/60000 [==============================] - 27s - loss: 103.5402 - val_loss: 104.5410 Epoch 111/150 60000/60000 [==============================] - 26s - loss: 103.4912 - val_loss: 104.2916 Epoch 112/150 60000/60000 [==============================] - 27s - loss: 103.4970 - val_loss: 104.5997 Epoch 113/150 60000/60000 [==============================] - 27s - loss: 103.5499 - val_loss: 104.8315 Epoch 114/150 60000/60000 [==============================] - 27s - loss: 103.4839 - val_loss: 104.8615 Epoch 115/150 60000/60000 [==============================] - 29s - loss: 103.5217 - val_loss: 104.3486 Epoch 116/150 60000/60000 [==============================] - 28s - loss: 103.4984 - val_loss: 104.9692 Epoch 117/150 60000/60000 [==============================] - 27s - 
loss: 103.4813 - val_loss: 104.1119 Epoch 118/150 60000/60000 [==============================] - 27s - loss: 103.4638 - val_loss: 104.0927 Epoch 119/150 60000/60000 [==============================] - 27s - loss: 103.4890 - val_loss: 104.4658 Epoch 120/150 60000/60000 [==============================] - 27s - loss: 103.4696 - val_loss: 104.8914 Epoch 121/150 60000/60000 [==============================] - 27s - loss: 103.4476 - val_loss: 105.1364 Epoch 122/150 60000/60000 [==============================] - 28s - loss: 103.4505 - val_loss: 104.5902 Epoch 123/150 60000/60000 [==============================] - 27s - loss: 103.4835 - val_loss: 104.4458 Epoch 124/150 60000/60000 [==============================] - 27s - loss: 103.4723 - val_loss: 105.2298 Epoch 125/150 60000/60000 [==============================] - 27s - loss: 103.4812 - val_loss: 105.0418 Epoch 126/150 60000/60000 [==============================] - 30s - loss: 103.4712 - val_loss: 104.1334 Epoch 127/150 60000/60000 [==============================] - 28s - loss: 103.4644 - val_loss: 104.1777 Epoch 128/150 60000/60000 [==============================] - 28s - loss: 103.4313 - val_loss: 104.3415 Epoch 129/150 60000/60000 [==============================] - 27s - loss: 103.4533 - val_loss: 104.7715 Epoch 130/150 60000/60000 [==============================] - 28s - loss: 103.4102 - val_loss: 104.7580 Epoch 131/150 60000/60000 [==============================] - 28s - loss: 103.4565 - val_loss: 104.8503 Epoch 132/150 60000/60000 [==============================] - 28s - loss: 103.4663 - val_loss: 104.1593 Epoch 133/150 60000/60000 [==============================] - 28s - loss: 103.4445 - val_loss: 104.5614 Epoch 134/150 60000/60000 [==============================] - 28s - loss: 103.4559 - val_loss: 105.2657 Epoch 135/150 60000/60000 [==============================] - 28s - loss: 103.4265 - val_loss: 104.3437 Epoch 136/150 60000/60000 [==============================] - 28s - loss: 103.4208 - val_loss: 104.4579 Epoch 
137/150 60000/60000 [==============================] - 28s - loss: 103.4111 - val_loss: 105.5145 Epoch 138/150 60000/60000 [==============================] - 28s - loss: 103.4125 - val_loss: 104.9128 Epoch 139/150 60000/60000 [==============================] - 29s - loss: 103.3964 - val_loss: 104.1655 Epoch 140/150 60000/60000 [==============================] - 28s - loss: 103.4433 - val_loss: 105.8056 Epoch 141/150 60000/60000 [==============================] - 28s - loss: 103.3969 - val_loss: 105.6866 Epoch 142/150 60000/60000 [==============================] - 28s - loss: 103.3769 - val_loss: 104.4300 Epoch 143/150 60000/60000 [==============================] - 34s - loss: 103.3967 - val_loss: 104.4618 Epoch 144/150 60000/60000 [==============================] - 48s - loss: 103.3570 - val_loss: 104.1320 Epoch 145/150 60000/60000 [==============================] - 35s - loss: 103.4004 - val_loss: 104.5835 Epoch 146/150 60000/60000 [==============================] - 34s - loss: 103.3988 - val_loss: 104.8078 Epoch 147/150 60000/60000 [==============================] - 37s - loss: 103.4044 - val_loss: 105.5785 Epoch 148/150 60000/60000 [==============================] - 33s - loss: 103.4488 - val_loss: 104.9559 Epoch 149/150 60000/60000 [==============================] - 34s - loss: 103.4197 - val_loss: 104.0929 Epoch 150/150 60000/60000 [==============================] - 31s - loss: 103.3924 - val_loss: 104.3723
# Project the test set into latent space and color points by digit class.
# NOTE(review): latent_dim is 20, but only the first two latent
# dimensions are shown in this 2-D scatter plot.
x_test_encoded = encoder.predict([x_test, y_test], batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=np.argmax(y_test, axis=1), cmap="plasma")
plt.colorbar()
plt.show()
# For the first ten test digits, show the image (left column) next to a
# bar chart of its 20-dimensional latent code (right column).
plt.figure(figsize=(8, 16))
n = 10
im_index = 1
indies = range(latent_dim)
for i in range(n):
    # Left panel: the digit itself.
    plt.subplot(n, 2, im_index)
    plt.axis('off')
    plt.imshow(x_test[i].reshape((28, 28)))
    # Right panel: its latent activations.
    plt.subplot(n, 2, im_index + 1)
    plt.bar(indies, x_test_encoded[i])
    im_index += 2
plt.show()
def get_one_hot_vector(idx, dim=10):
    """Return a float vector of length `dim` that is 1.0 at `idx`, else 0.0."""
    return np.eye(dim)[idx]
# Encode a single test digit, then decode its latent code under each of
# the ten class labels: same latent "style", different digit identity.
z_predict = encoder.predict([x_test[3].reshape((1, 784)),
                             y_test[3].reshape((1, 10))])
n = 1
plt.figure(figsize=(8, 8))
for digit_class in range(10):
    plt.subplot(1, 10, n)
    n += 1
    plt.axis('off')
    conditioned = np.hstack((z_predict[0],
                             get_one_hot_vector(digit_class, categorical_dim)))
    decoded = generator.predict(conditioned[np.newaxis, :])
    plt.imshow(decoded[0].reshape(28, 28))
plt.show()
from keras.preprocessing import image

# Encode hand-drawn digit images from disk, then decode each latent code
# under every class label.
# Fixed: removed the unused `images` list (dead code) and the duplicated
# "x_data = x_data =" assignment; dropped a commented-out Python 2 print.
categories = [2, 3, 5, 7]
for i in categories:
    img_path = "images/%d.png" % i
    img = image.load_img(img_path, target_size=(28, 28))  # read and resize
    # Invert (MNIST is white-on-black), scale to [0, 1], keep one channel,
    # and flatten to the encoder's 784-wide input.
    x_data = ((255 - image.img_to_array(img)) / 255)[:, :, 0].reshape((1, 784))
    c_data = get_one_hot_vector(i, categorical_dim).reshape((1, 10))
    z_predict = encoder.predict([x_data, c_data])
    n = 1
    plt.figure(figsize=(8, 8))
    # First panel: the original drawing.
    plt.subplot(len(categories), 11, n); n += 1; plt.axis('off')
    plt.imshow(img)
    # Remaining ten panels: the latent code decoded as each digit class.
    # NOTE(review): the grid is declared len(categories) x 11 but `n`
    # restarts at 1 for every figure, so only the first row is filled;
    # a 1 x 11 grid would render identically — layout left as-is.
    for j in range(10):
        plt.subplot(len(categories), 11, n); n += 1; plt.axis('off')
        _c_sample = get_one_hot_vector(j, categorical_dim)
        latent_sample = np.hstack((z_predict[0], _c_sample))
        generated = generator.predict(np.array([latent_sample]))
        plt.imshow(generated[0].reshape(28, 28))
    plt.show()