For the tqdm progress bar to work correctly, run the following before launching this notebook:
$ jupyter nbextension enable --py --sys-prefix widgetsnbextension
Also, for the GIF to appear (its output exceeds Jupyter's default IOPub data rate limit), launch Jupyter as follows:
$ jupyter notebook --NotebookApp.iopub_data_rate_limit=100000000
Or simply run the following from the project root:
$ make jupyter
import numpy as np
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from JSAnimation.IPython_display import display_animation
from sklearn.metrics import accuracy_score, confusion_matrix
import env
from boltzmann_machines import DBM
from boltzmann_machines.rbm import BernoulliRBM
from boltzmann_machines.utils import (progress_bar, Stopwatch,
im_plot, im_reshape, im_gif, tick_params,
plot_confusion_matrix)
from boltzmann_machines.utils.dataset import load_mnist
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload imported modules (mode 2 = all modules) before executing each cell,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Load MNIST train/test splits and scale pixel intensities into [0, 1]
# (assumes the loader returns raw 0-255 intensities -- TODO confirm).
# Out-of-place division is used instead of `/=`: NumPy rejects in-place true
# division of an integer array with a casting error, whereas `X / 255.` works
# for any numeric dtype and gives the same result for float input.
DATA_PATH = '../data/'
X, y = load_mnist(mode='train', path=DATA_PATH)
X = X / 255.
X_test, y_test = load_mnist(mode='test', path=DATA_PATH)
X_test = X_test / 255.

# Sanity checks: every image has a label, and train/test share the feature dimension.
assert X.shape[0] == y.shape[0]
assert X_test.shape[0] == y_test.shape[0]
assert X.shape[1] == X_test.shape[1]

# Portable across Python 2 and 3; prints the same single line the print statement did.
print('{0} {1} {2} {3}'.format(X.shape, y.shape, X_test.shape, y_test.shape))
(60000, 784) (60000,) (10000, 784) (10000,)
# Show the first 100 training images, each reshaped to 28x28, and save the figure to disk.
fig = plt.figure(figsize=(10, 10))
im_plot(X[:100],
        shape=(28, 28),
        title='Training examples',
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('mnist.png',
            bbox_inches='tight',
            dpi=196)
# Load the first pre-trained RBM and visualize its weights as 28x28 images.
# W is transposed before plotting, so each plotted image is presumably one hidden
# unit's incoming weights over the 784 pixels -- TODO confirm W's orientation.
rbm1 = BernoulliRBM.load_model('../models/dbm_mnist_rbm1/')
rbm1_W = rbm1.get_tf_params(scope='weights')['W']

fig = plt.figure(figsize=(10, 10))
im_plot(rbm1_W.T,
        shape=(28, 28),
        title='First 100 filters extracted by RBM #1',
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('dbm_mnist_rbm1.png', bbox_inches='tight', dpi=196);
# Load the second pre-trained RBM. Its weights are not over pixels, so to view them
# as 28x28 images they are projected back to pixel space through RBM #1's weights.
# This is a plain matrix product, i.e. a linear approximation that ignores the
# nonlinearity between the layers.
rbm2 = BernoulliRBM.load_model('../models/dbm_mnist_rbm2/')
rbm2_W = rbm2.get_tf_params(scope='weights')['W']
U = rbm1_W.dot(rbm2_W)

fig = plt.figure(figsize=(10, 10))
im_plot(U.T,
        shape=(28, 28),
        title='First 100 (high-level) filters extracted by RBM #2',
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('dbm_mnist_rbm2.png', bbox_inches='tight', dpi=196);
# Load the jointly trained DBM and attach the two pre-trained RBMs to it.
# NOTE(review): attaching the RBMs via load_rbms() appears to be required after
# load_model() (it was flagged as important) -- TODO document what fails without it.
dbm = DBM.load_model('../models/dbm_mnist/')
dbm.load_rbms([rbm1, rbm2])

# 'W' holds the first-layer weights after joint training; plot them like RBM #1's filters.
W1_joint = dbm.get_tf_params(scope='weights')['W']

fig = plt.figure(figsize=(10, 10))
im_plot(W1_joint.T,
        shape=(28, 28),
        title='First 100 filters of DBM after joint training (1st layer)',
        title_params={'fontsize': 20},
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('dbm_mnist_W1_joint.png', bbox_inches='tight', dpi=196);
# 'W_1' holds the second-layer weights after joint training. As with RBM #2, project
# them to pixel space through the first-layer weights so they can be drawn as 28x28 images.
W2_joint = dbm.get_tf_params(scope='weights')['W_1']
U_joint = W1_joint.dot(W2_joint)

fig = plt.figure(figsize=(10, 10))
im_plot(U_joint.T,
        shape=(28, 28),
        title='First 100 filters of DBM after joint training (2nd layer)',
        title_params={'fontsize': 20},
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('dbm_mnist_W2_joint.png', bbox_inches='tight', dpi=196);
# Draw visible samples from the DBM after a long Gibbs chain and plot them.
# Only the sampling call is timed; Stopwatch(verbose=True) prints the elapsed time.
N_GIBBS_STEPS = 1337  # single source of truth for both the sampler and the title
with Stopwatch(verbose=True):
    V = dbm.sample_v(n_gibbs_steps=N_GIBBS_STEPS)

fig = plt.figure(figsize=(10, 10))
im_plot(V, shape=(28, 28),
        title='Samples generated by DBM after {0} Gibbs steps'.format(N_GIBBS_STEPS),
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('dbm_mnist_samples.png', dpi=196, bbox_inches='tight');
Elapsed time: 14.141 sec
(note that each new batch of samples overwrites the previous particles saved on disk)
# Run the sampler in chunks of Gibbs steps and keep a 28x28-reshaped snapshot after
# each chunk; the snapshots are animated in the next cell.
# WARNING: save_model=True writes the sampler state back to the model directory on
# every call, overwriting the previously saved particles. It is presumably what lets
# each call continue the chain from the previous one -- TODO confirm.
N_FRAMES = 100
GIBBS_STEPS_PER_FRAME = 100
samples = []
for _ in progress_bar(range(N_FRAMES)):
    V = dbm.sample_v(n_gibbs_steps=GIBBS_STEPS_PER_FRAME, save_model=True)
    samples.append(im_reshape(V, shape=(28, 28)))
A Jupyter Widget
# Animate the collected snapshots and save them as 'dbm_samples.gif'.
# The 280x280 zero image is a placeholder for the animated frames; its size presumably
# matches what im_reshape produces -- TODO confirm. vmin/vmax pin the gray scale to [0, 1].
fig = plt.figure(figsize=(6, 6), tight_layout=True)
im = plt.imshow(np.zeros((280, 280)), cmap=plt.cm.gray, animated=True, vmin=0., vmax=1.)
im.axes.tick_params(**tick_params())

# samples[i] was recorded after (i + 1) calls of 100 Gibbs steps each (100 must match
# n_gibbs_steps in the sampling loop), so frame i is labelled 100 * (i + 1) rather than
# 100 * i, which mislabelled the first frame as "0 Gibbs steps".
# Assumes title_func receives the 0-based frame index.
anim = im_gif(samples, im, fig, fname='dbm_samples.gif',
              title_func=lambda i: 'Samples generated by DBM after {0} Gibbs steps'.format(100 * (i + 1)),
              title_params={'fontsize': 15, 'y': 1.02}, save_params={'dpi': 100})
display_animation(anim)