import keras
from keras.models import Sequential, Model, load_model
import os
import pickle
import numpy as np
import pandas as pd
import scipy.sparse as sp
import scipy.io as spio
import matplotlib.pyplot as plt
from scrambler.models import *
from mnist_utils import load_mnist_data, load_mnist_predictor, animate_mnist_examples
Using TensorFlow backend.
#Load MNIST data
x_train, y_train, x_test, y_test = load_mnist_data()
#Load MNIST predictor
predictor_path = 'saved_models/mnist_binarized_cnn_10_digits.h5'
predictor = load_mnist_predictor(predictor_path)
n train samples = 11552
n test samples = 1888
x_train.shape = (11552, 28, 28, 2)
#Calculate background
pseudo_count = 1.0
x_mean = (np.sum(x_train, axis=0) + pseudo_count) / (x_train.shape[0] + 2. * pseudo_count)
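#Hedged sanity check: with one pseudo-count added per channel (hence the 2. * pseudo_count
#in the denominator), every pixel of x_mean should remain a proper 2-way channel distribution
assert np.allclose(np.sum(x_mean, axis=-1), 1.0)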
#Visualize background image distribution
f = plt.figure(figsize=(4, 4))
plt.imshow(x_mean[..., 1], cmap="Greys", vmin=0.0, vmax=1.0, aspect='equal')
plt.xticks([], [])
plt.yticks([], [])
plt.tight_layout()
plt.show()
#Calculate mean training set KL-divergence against background
x_train_clipped = np.clip(x_train, 1e-8, 1. - 1e-8)
kl_divs = np.sum(x_train_clipped * np.log2(x_train_clipped / x_mean[None, ...]), axis=-1)
x_mean_kl_divs = np.mean(kl_divs, axis=(1, 2))
x_mean_kl_div = np.mean(x_mean_kl_divs)
print("Mean KL Div against background (bits) = " + str(x_mean_kl_div))
Mean KL Div against background (bits) = 0.36095130083827087
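#Illustrative re-derivation for a single image (a sketch; numerically equivalent to the
#vectorized computation above, with x_mean broadcast against the clipped image):
kl_div_0 = np.mean(np.sum(x_train_clipped[0] * np.log2(x_train_clipped[0] / x_mean), axis=-1))
print("KL Div of training image 0 (bits) = " + str(kl_div_0))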
#Define image template
onehot_template = np.zeros(x_mean.shape)
#Build scrambler
#Scrambler network configuration
network_config = {
'n_groups' : 5,
'n_resblocks_per_group' : 4,
'n_channels' : 32,
'window_size' : 3,
'dilation_rates' : [1, 2, 4, 2, 1],
'drop_rate' : 0.0,
'norm_mode' : 'instance',
'mask_smoothing' : False,
'mask_smoothing_window_size' : 7,
'mask_smoothing_std' : 1.5,
'mask_drop_scales' : [1, 5],
'mask_min_drop_rate' : 0.0,
'mask_max_drop_rate' : 0.5,
'label_input' : False
}
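#Rough receptive-field estimate for this configuration (a sketch under the assumption that
#each residual block applies one window_size convolution at its group's dilation rate; the
#actual scrambler architecture may differ):
rf = 1
for d in network_config['dilation_rates'] :                  #one dilation rate per group
    for _ in range(network_config['n_resblocks_per_group']) :
        rf += (network_config['window_size'] - 1) * d        #each dilated conv widens the RF by (k-1)*d
print("Approx. receptive field = " + str(rf) + " pixels")    #81 under this assumption, covering a 28x28 image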
#Initialize scrambler
scrambler = Scrambler(
scrambler_mode='inclusion',
n_classes=10,
input_size_x=28,
input_size_y=28,
n_out_channels=2,
input_templates=[onehot_template],
input_backgrounds=[x_mean],
batch_size=32,
n_samples=32,
sample_mode='st',
zeropad_input=False,
mask_dropout=False,
network_config=network_config
)
#Train scrambler
n_epochs = 10
train_history = scrambler.train(
predictor,
x_train,
y_train,
x_test,
y_test,
n_epochs,
monitor_test_indices=np.arange(32).tolist(),
monitor_batch_freq_dict={0 : 1, 100 : 5, 361 : 10},
nll_mode='reconstruction',
predictor_task='classification',
entropy_mode='target',
entropy_bits=0.005,
entropy_weight=1200.
)
Train on 11552 samples, validate on 1888 samples
Epoch 1/10
11552/11552 [==============================] - 143s 12ms/step - loss: 1.8809 - nll_loss: 1.4860 - entropy_loss: 0.3949 - val_loss: 0.7043 - val_nll_loss: 0.5477 - val_entropy_loss: 0.1567
Epoch 2/10
11552/11552 [==============================] - 108s 9ms/step - loss: 1.0446 - nll_loss: 0.9215 - entropy_loss: 0.1231 - val_loss: 0.6175 - val_nll_loss: 0.4856 - val_entropy_loss: 0.1319
Epoch 3/10
11552/11552 [==============================] - 108s 9ms/step - loss: 0.9525 - nll_loss: 0.8293 - entropy_loss: 0.1232 - val_loss: 0.5618 - val_nll_loss: 0.4362 - val_entropy_loss: 0.1256
Epoch 4/10
11552/11552 [==============================] - 108s 9ms/step - loss: 0.9013 - nll_loss: 0.7788 - entropy_loss: 0.1225 - val_loss: 0.5944 - val_nll_loss: 0.5258 - val_entropy_loss: 0.0686
Epoch 5/10
11552/11552 [==============================] - 108s 9ms/step - loss: 0.8915 - nll_loss: 0.7723 - entropy_loss: 0.1192 - val_loss: 0.5517 - val_nll_loss: 0.4659 - val_entropy_loss: 0.0858
Epoch 6/10
11552/11552 [==============================] - 106s 9ms/step - loss: 0.8749 - nll_loss: 0.7592 - entropy_loss: 0.1157 - val_loss: 0.5414 - val_nll_loss: 0.4518 - val_entropy_loss: 0.0896
Epoch 7/10
11552/11552 [==============================] - 106s 9ms/step - loss: 0.8577 - nll_loss: 0.7436 - entropy_loss: 0.1141 - val_loss: 0.4959 - val_nll_loss: 0.3410 - val_entropy_loss: 0.1550
Epoch 8/10
11552/11552 [==============================] - 106s 9ms/step - loss: 0.8328 - nll_loss: 0.7212 - entropy_loss: 0.1116 - val_loss: 0.4752 - val_nll_loss: 0.2978 - val_entropy_loss: 0.1774
Epoch 9/10
11552/11552 [==============================] - 106s 9ms/step - loss: 0.7957 - nll_loss: 0.6845 - entropy_loss: 0.1112 - val_loss: 0.4885 - val_nll_loss: 0.4078 - val_entropy_loss: 0.0807
Epoch 10/10
11552/11552 [==============================] - 106s 9ms/step - loss: 0.7919 - nll_loss: 0.6830 - entropy_loss: 0.1089 - val_loss: 0.4540 - val_nll_loss: 0.3441 - val_entropy_loss: 0.1099
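#Quick look at the best validation epoch (a sketch; the total loss is the sum of the NLL
#and entropy terms, matching the 'loss' columns in the log above):
val_total_loss = np.array(train_history['val_nll_loss']) + np.array(train_history['val_entropy_loss'])
print("Best validation epoch = " + str(int(np.argmin(val_total_loss)) + 1))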
#Save scrambler checkpoint
save_dir = 'saved_models'
model_name = 'mnist_inclusion_scrambler_target_bits_0005_epochs_10'
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name + '.h5')
scrambler.save_model(model_path)
with open(os.path.join(save_dir, model_name + '_train_history.pickle'), 'wb') as history_file :
    pickle.dump({'train_history' : train_history}, history_file)
Saved scrambler model at saved_models/mnist_inclusion_scrambler_target_bits_0005_epochs_10.h5
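#To restore the training history in a later session (a minimal sketch mirroring the
#dump above):
with open(os.path.join(save_dir, model_name + '_train_history.pickle'), 'rb') as history_file :
    train_history = pickle.load(history_file)['train_history']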
#Load models
save_dir = 'saved_models'
model_name = 'mnist_inclusion_scrambler_target_bits_0005_epochs_10'
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name + '.h5')
scrambler.load_model(model_path)
Loaded scrambler model from saved_models/mnist_inclusion_scrambler_target_bits_0005_epochs_10.h5
#Plot training statistics (green = training, orange = test)
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(2 * 4, 3))
n_epochs_actual = len(train_history['nll_loss'])
ax1.plot(np.arange(1, n_epochs_actual + 1), train_history['nll_loss'], linewidth=3, color='green')
ax1.plot(np.arange(1, n_epochs_actual + 1), train_history['val_nll_loss'], linewidth=3, color='orange')
plt.sca(ax1)
plt.xlabel("Epochs", fontsize=14)
plt.ylabel("NLL", fontsize=14)
plt.xlim(1, n_epochs_actual)
plt.xticks([1, n_epochs_actual], [1, n_epochs_actual], fontsize=12)
plt.yticks(fontsize=12)
ax2.plot(np.arange(1, n_epochs_actual + 1), train_history['entropy_loss'], linewidth=3, color='green')
ax2.plot(np.arange(1, n_epochs_actual + 1), train_history['val_entropy_loss'], linewidth=3, color='orange')
plt.sca(ax2)
plt.xlabel("Epochs", fontsize=14)
plt.ylabel("Entropy Loss", fontsize=14)
plt.xlim(1, n_epochs_actual)
plt.xticks([1, n_epochs_actual], [1, n_epochs_actual], fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()
plt.show()
#Interpret the test set using the trained scrambler
image_test, sample_test, importance_scores_test = scrambler.interpret(x_test)
1888/1888 [==============================] - 4s 2ms/step
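#Hedged inspection of the interpretation outputs; the indexing below assumes image_test
#holds the scrambled reconstructions and importance_scores_test the per-pixel scores
print("image_test.shape = " + str(image_test.shape))
print("importance_scores_test.shape = " + str(importance_scores_test.shape))
print("Mean importance mass per image = " + str(np.mean(np.sum(importance_scores_test, axis=(1, 2, 3)))))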
#Visualize a few reconstructed images
plot_examples = np.arange(10).tolist()
save_examples = []
for test_ix in plot_examples :
print("Test image " + str(test_ix) + ":")
save_figs = False
if save_examples is not None and test_ix in save_examples :
save_figs = True
f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(1.5 * 4, 1.5))
ax1.imshow(x_test[test_ix, :, :, 1], cmap="Greys", vmin=0.0, vmax=1.0, aspect='equal')
plt.sca(ax1)
plt.xticks([], [])
plt.yticks([], [])
ax2.imshow(image_test[test_ix, :, :, 1], cmap="Greys", vmin=0.0, vmax=1.0, aspect='equal')
plt.sca(ax2)
plt.xticks([], [])
plt.yticks([], [])
ax3.imshow(importance_scores_test[test_ix, :, :, 0], cmap="hot", vmin=0.0, vmax=np.max(importance_scores_test[test_ix, :, :, 0]), aspect='equal')
plt.sca(ax3)
plt.xticks([], [])
plt.yticks([], [])
    ax4.imshow(x_test[test_ix, :, :, 1], cmap="Greys", vmin=0.0, vmax=1.0, aspect='equal')
    ax4.imshow(importance_scores_test[test_ix, :, :, 0], alpha=0.75, cmap="hot", vmin=0.0, vmax=np.max(importance_scores_test[test_ix, :, :, 0]), aspect='equal')
    plt.sca(ax4)
    plt.xticks([], [])
    plt.yticks([], [])
plt.tight_layout()
if save_figs :
plt.savefig(model_name + "_test_ix_" + str(test_ix) + ".png", transparent=True, dpi=300)
plt.savefig(model_name + "_test_ix_" + str(test_ix) + ".eps")
plt.show()
Test image 0:
Test image 1:
Test image 2:
Test image 3:
Test image 4:
Test image 5:
Test image 6:
Test image 7:
Test image 8:
Test image 9:
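#Optional follow-up (a sketch; assumes predictor is the Keras classifier loaded above and
#that image_test shares the input shape of x_test):
y_pred_orig = np.argmax(predictor.predict(x_test), axis=-1)
y_pred_scram = np.argmax(predictor.predict(image_test), axis=-1)
print("Prediction agreement (original vs. scrambled) = " + str(np.mean(y_pred_orig == y_pred_scram)))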