For the tqdm progress bar to work correctly, run the following before launching this notebook:
$ jupyter nbextension enable --py --sys-prefix widgetsnbextension
Also, for the GIF to appear, launch Jupyter as follows:
$ jupyter notebook --NotebookApp.iopub_data_rate_limit=100000000
Or simply run from project root:
$ make jupyter
import numpy as np
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import matplotlib.pyplot as plt
from JSAnimation.IPython_display import display_animation
from sklearn.metrics import accuracy_score, confusion_matrix
import env
from boltzmann_machines.dbm import DBM
from boltzmann_machines.rbm import GaussianRBM, MultinomialRBM
from boltzmann_machines.utils import (progress_bar, RNG,
im_plot, im_reshape, im_gif, tick_params, plot_confusion_matrix)
from boltzmann_machines.utils.dataset import (load_cifar10, plot_cifar10, get_cifar10_labels,
im_unflatten, im_rescale)
%matplotlib inline
%load_ext autoreload
%autoreload 2
# Load the CIFAR-10 training split from the local data directory.
X, y = load_cifar10(mode='train', path='../data/')

# Shuffle images and labels in place, each with a fresh RNG built from the
# same seed. NOTE(review): this presumably yields the same permutation for
# both arrays, keeping images and labels aligned -- confirm against RNG.
for arr in (X, y):
    RNG(seed=42).shuffle(arr)

print('{0} {1}'.format(X.shape, y.shape))
(50000, 3072) (50000,)
# Preview the first 1000 shuffled training images (plot_cifar10 is asked for
# 10 samples per class) and write the figure to cifar10.png.
plt.figure(figsize=(10, 10))
preview_images = im_unflatten(X[:1000])
preview_labels = y[:1000]
plot_cifar10(preview_images, preview_labels, samples_per_class=10)
plt.savefig('cifar10.png', dpi=196, bbox_inches='tight')
# Load the array stored in X_s.npy and report its shape and dtype.
X_s = np.load('../data/X_s.npy')
print('{0} {1}'.format(X_s.shape, X_s.dtype))
(49000, 3072) float32
# Same kind of preview as for the raw images, but drawn from X_s after
# im_rescale and titled 'CIFAR-10 smoothed'.
# NOTE(review): X_s has 49000 rows while y has 50000; pairing X_s[:1000]
# with y[:1000] assumes X_s rows follow the same shuffled order as y --
# confirm against the script that produced X_s.npy.
plt.figure(figsize=(10, 10))
smoothed_preview = im_rescale(X_s[:1000])
plot_cifar10(smoothed_preview, y[:1000],
             samples_per_class=10, title='CIFAR-10 smoothed')
plt.savefig('cifar10_smoothed.png', dpi=196, bbox_inches='tight')
# Load the saved Gaussian RBM and pull out its weight matrix 'W'.
grbm = GaussianRBM.load_model('../models/grbm_cifar_naive/')
grbm_params = grbm.get_tf_params(scope='weights')
grbm_W = grbm_params['W']
print(grbm_W.shape)

# Transpose a copy of the weights and rescale them for display.
# (Presumably each row of the transposed matrix is one hidden unit's
# filter -- confirm against im_plot.)
W = im_rescale(grbm_W.copy().T)

# Draw the filters and save the figure.
fig = plt.figure(figsize=(10, 10))
im_plot(W, title='First 100 filters extracted by Gaussian RBM')
plt.savefig('dbm_cifar_naive_grbm.png', dpi=196, bbox_inches='tight')
(3072, 5000)
# Load the saved Multinomial RBM and pull out its weight matrix 'W'.
mrbm = MultinomialRBM.load_model('../models/mrbm_cifar_naive/')
mrbm_params = mrbm.get_tf_params(scope='weights')
mrbm_W = mrbm_params['W']
print(mrbm_W.shape)

# Multiply the first-layer weights by the second-layer weights so the
# second-layer filters are expressed in the first layer's input space,
# then transpose and rescale for display.
W = im_rescale(grbm_W.dot(mrbm_W).T)

# Draw the filters and save the figure.
fig = plt.figure(figsize=(10, 10))
im_plot(W, title='First 100 filters extracted by Multinomial RBM')
plt.savefig('dbm_cifar_naive_mrbm.png', dpi=196, bbox_inches='tight')
(5000, 1000)
# Load the saved DBM and hand it the two RBMs loaded earlier.
dbm = DBM.load_model('../models/dbm_cifar_naive/')
# The original author flagged this call with '!!!'. NOTE(review): presumably
# the DBM cannot be used correctly without its RBMs attached -- confirm.
dbm.load_rbms([grbm, mrbm])

# Fetch both layers' weight matrices and report their shapes.
weights = dbm.get_tf_params('weights')
W1, W2 = weights['W'], weights['W_1']
print('{0} {1}'.format(W1.shape, W2.shape))
(3072, 5000) (5000, 1000)
# Transpose and rescale the DBM's first-layer weights for display.
W = im_rescale(W1.T)

# Draw the filters with a larger title font and save the figure.
fig = plt.figure(figsize=(10, 10))
im_plot(W,
        title='First 100 filters of DBM after joint training (1st layer)',
        title_params={'fontsize': 20})
plt.savefig('dbm_cifar_naive_W1_joint.png', dpi=196, bbox_inches='tight')
# Express the second-layer weights in the first layer's input space via
# W1 . W2, then transpose and rescale for display.
W = im_rescale(W1.dot(W2).T)

# Draw the filters with a larger title font and save the figure.
fig = plt.figure(figsize=(10, 10))
im_plot(W,
        title='First 100 filters of DBM after joint training (2nd layer)',
        title_params={'fontsize': 20})
plt.savefig('dbm_cifar_naive_W2_joint.png', dpi=196, bbox_inches='tight')
# Load the mean/std arrays saved for X_s; later cells pass them to
# im_rescale when displaying generated samples.
X_s_mean = np.load('../data/X_s_mean.npy')
X_s_std = np.load('../data/X_s_std.npy')

# Draw visible-layer samples from the DBM using 10 Gibbs steps.
V = dbm.sample_v(n_gibbs_steps=10)
print(V.shape)
(100, 3072)
# Rescale the generated samples with the stored X_s mean/std for display.
V = im_rescale(V, mean=X_s_mean, std=X_s_std)

# Draw the samples and save the figure.
fig = plt.figure(figsize=(10, 10))
im_plot(V, title='Samples generated by DBM after training')
plt.savefig('dbm_cifar_naive_samples.png', dpi=196, bbox_inches='tight')
(Note that the new samples overwrite the previously saved particles on disk.)
# Start the frame list with the already-rescaled samples, arranged by
# im_reshape with n_width=10 and n_height=10.
V = im_reshape(V, n_width=10, n_height=10)
samples = [V]

# Run 10 further rounds of 50 Gibbs steps each, saving the model every round,
# and keep one rescaled, reshaped frame per round.
for _round in progress_bar(range(10)):
    V = dbm.sample_v(n_gibbs_steps=50, save_model=True)
    V = im_reshape(im_rescale(V, mean=X_s_mean, std=X_s_std),
                   n_width=10, n_height=10)
    samples.append(V)
A Jupyter Widget
def _gibbs_steps_title(i):
    """Return the animation title for frame ``i`` (50 Gibbs steps per frame)."""
    # NOTE(review): frame 0 is labelled "0 Gibbs steps" although it came from
    # an earlier sample_v(n_gibbs_steps=10) call; the count appears to be
    # relative to that initial sample -- confirm.
    return 'Samples generated by DBM after {0} Gibbs steps'.format(50 * i)


# Set up a blank 320x320 RGB canvas for the animation to draw into.
fig = plt.figure(figsize=(6, 6), tight_layout=True)
blank_frame = np.zeros((320, 320, 3), dtype='uint8')
im = plt.imshow(blank_frame, animated=True, vmin=0, vmax=255)
im.axes.tick_params(**tick_params())

# Build the animation from the collected frames (im_gif is also given the
# output file name) and show it inline.
anim = im_gif(samples, im, fig, fname='dbm_cifar_naive_samples.gif',
              title_func=_gibbs_steps_title,
              title_params={'fontsize': 15, 'y': 1.02},
              anim_params={'interval': 500},
              save_params={'dpi': 144})
display_animation(anim)
# Load the stored predictions, test labels and fine-tuned weights for the
# G-RBM, then report the accuracy of the predictions.
y_pred = np.load('../data/grbm_naive_y_pred.npy')
y_test = np.load('../data/grbm_naive_y_test.npy')
W_finetuned = np.load('../data/grbm_naive_W_finetuned.npy')

test_accuracy = accuracy_score(y_test, y_pred)
print(test_accuracy)
0.5978
Notice how infrequently an animal is mistaken for a non-animal (or vice versa), and how frequently an animal is mistaken for another animal (and likewise for non-animals). The confusion matrix therefore has a block structure.
# Compute the confusion matrix of the stored predictions, plot it with the
# CIFAR-10 class labels and integer cell formatting, and save the figure.
C = confusion_matrix(y_test, y_pred)
fig = plt.figure(figsize=(10, 8))
class_names = get_cifar10_labels()
ax = plot_confusion_matrix(C, labels=class_names,
                           labels_fontsize=12, fmt='d')
plt.title('Confusion matrix for fine-tuned G-RBM\n', fontsize=20, y=0.97)
plt.savefig('dbm_cifar_naive_grbm_confusion_matrix.png',
            dpi=144, bbox_inches='tight')
# Transpose and rescale the fine-tuned G-RBM weights for display.
W = im_rescale(W_finetuned.T)

# Draw the filters with a larger title font and save the figure.
fig = plt.figure(figsize=(10, 10))
im_plot(W,
        title='First 100 filters of Gaussian RBM after fine-tuning',
        title_params={'fontsize': 20})
plt.savefig('dbm_cifar_naive_grbm_finetuned.png', dpi=196, bbox_inches='tight')