import matplotlib.pyplot as plt
import numpy as np
import scprep
import tensorflow as tf

import m_phate
import m_phate.data
import m_phate.train
import m_phate.utils
Using TensorFlow backend.
Let's configure the session to limit the fraction of the GPU that TensorFlow may use (here 20%), as a courtesy to anyone else sharing the machine.
# build a session config that caps this notebook's share of the GPU at 20%,
# then open an interactive session with it
session_config = m_phate.train.build_config(limit_gpu_fraction=0.2)
sess = tf.InteractiveSession(config=session_config)
Now we load MNIST in preprocessed form.
# preprocessed MNIST, returned as (x_train, x_test, y_train, y_test);
# labels are one-hot (columns are indexed per digit further below)
mnist = m_phate.data.load_mnist()
x_train, x_test, y_train, y_test = mnist
# input layer: one (presumably flattened) preprocessed image per row
inputs = tf.keras.layers.Input(
    shape=(x_train.shape[1],), dtype='float32', name='inputs')
# three dense hidden layers of 64 units. Dense has no activation here, so
# h1/h2/h3 hold pre-activation values; leaky ReLU is applied only when a
# layer's output is fed forward.
h1 = tf.keras.layers.Dense(64, name='h1')(inputs)
h2 = tf.keras.layers.Dense(64, name='h2')(tf.nn.leaky_relu(h1))
h3 = tf.keras.layers.Dense(64, name='h3')(tf.nn.leaky_relu(h2))
# output layer: unnormalized class scores (softmax is applied inside the loss)
logits = tf.keras.layers.Dense(10, name='outputs')(tf.nn.leaky_relu(h3))
# true labels (one-hot)
labels = tf.keras.layers.Input(
    shape=(y_train.shape[1],), dtype='float32', name='labels')
# cross entropy loss
loss_op = tf.losses.softmax_cross_entropy(labels, logits)
# categorical accuracy of exactly the examples in the current feed.
# tf.metrics.accuracy is a *streaming* metric: its update op returns the
# running accuracy over every batch seen since the local variables were
# initialized (training and validation feeds mixed together), so the
# per-epoch numbers it produced were not epoch or validation accuracies.
correct_predictions = tf.equal(tf.argmax(labels, axis=1), tf.argmax(logits, axis=1))
accuracy_op = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
# build the optimizer
optimizer = tf.train.AdamOptimizer(1e-5)
train_op = optimizer.minimize(loss_op)
# initialize variables
_ = sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
WARNING:tensorflow:From /usr/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer. WARNING:tensorflow:From /usr/lib/python3.7/site-packages/tensorflow/python/ops/losses/losses_impl.py:209: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.cast instead.
Here we randomly select ten examples of every digit from the test set — the hidden activations of these images will form the data tensor we visualize.
np.random.seed(42)
# for each digit 0-9 (in order), draw 10 distinct test-set indices whose
# one-hot label marks that digit, then join them into one index array
trace_idx = np.concatenate([
    np.random.choice(np.flatnonzero(y_test[:, digit] == 1), 10, replace=False)
    for digit in range(10)
])
# the images whose hidden activations are recorded during training
x_trace = x_test[trace_idx]
# create a dictionary in which to store trace, loss, accuracy etc
history = {'trace': [], 'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}
# train!
batch_size = 128
epochs = 200
verbose = False
n_train = x_train.shape[0]
# number of full minibatches per epoch; the trailing partial batch is skipped
n_batches = n_train // batch_size
# The loop variable is not called `epoch`, because that name is used later
# for the per-point epoch array that colors the M-PHATE plot.
for epoch_num in range(epochs):
    # randomize batch order: a random permutation of all training indices
    batch_idx = np.random.choice(n_train, n_train, replace=False)
    # run minibatch training
    batch_losses = []
    batch_accs = []
    for batch in range(n_batches):
        # get one batch worth of data
        idx = batch_idx[batch * batch_size:(batch + 1) * batch_size]
        x_batch = x_train[idx]
        y_batch = y_train[idx]
        # gradient step
        _, batch_loss, batch_acc = sess.run([train_op, loss_op, accuracy_op],
                                            feed_dict={inputs: x_batch, labels: y_batch})
        # store loss and accuracy
        batch_losses.append(batch_loss)
        batch_accs.append(batch_acc)
    # mean loss and accuracy for the epoch
    epoch_loss = np.mean(batch_losses)
    epoch_acc = np.mean(batch_accs)
    history['loss'].append(epoch_loss)
    history['accuracy'].append(epoch_acc)
    # validation loss and accuracy on the full test set
    epoch_val_loss, epoch_val_acc = sess.run([loss_op, accuracy_op],
                                             feed_dict={inputs: x_test, labels: y_test})
    history['val_loss'].append(epoch_val_loss)
    history['val_accuracy'].append(epoch_val_acc)
    if verbose:
        # report 1-based epoch numbers ("Epoch 1/200" ... "Epoch 200/200")
        print("Epoch {}/{}:\tloss={:.3f}\tval_loss={:.3f}\tacc={:.3f}\tval_acc={:.3f}".format(
            epoch_num + 1, epochs, epoch_loss, epoch_val_loss, epoch_acc, epoch_val_acc))
    # store the activations of the trace data
    layer_traces = sess.run([h1, h2, h3], feed_dict={inputs: x_trace})
    # each layer's output is (n_trace_examples, 64); transpose and stack into a
    # single (n_neurons, n_trace_examples) matrix: h1's units, then h2's, then h3's
    epoch_trace = np.concatenate([layer_trace.T for layer_trace in layer_traces])
    history['trace'].append(epoch_trace)
There are a lot of metadata features we might be interested in visualizing, so we show you all of them here.
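The most important element here is the first: the $n_{\text{epochs}} \times n_{\text{neurons}} \times n_{\text{examples}}$ tensor. We often refer to the flattened tensor, an $(n_{\text{epochs}} \cdot n_{\text{neurons}}) \times n_{\text{examples}}$ matrix, because its rows correspond one-to-one with the rows returned by M-PHATE: an $(n_{\text{epochs}} \cdot n_{\text{neurons}}) \times n_{\text{dim}}$ matrix to be visualized.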
# the actual data tensor: n_epochs x n_neurons x n_examples, with the units of
# h1, h2 and h3 stacked in that order along the neuron axis
trace_data = np.array(history['trace'])
n_epochs = trace_data.shape[0]
n_neurons = trace_data.shape[1]
# The metadata arrays below have one entry per row of the flattened trace and
# assume epoch-major ordering (all neurons of epoch 0, then epoch 1, ...).
# the train loss for each element of the flattened trace
loss = np.repeat(history['loss'], n_neurons)
# the validation loss for each element of the flattened trace
val_loss = np.repeat(history['val_loss'], n_neurons)
# the train accuracy for each element of the flattened trace
accuracy = np.repeat(history['accuracy'], n_neurons)
# the validation accuracy for each element of the flattened trace
val_accuracy = np.repeat(history['val_accuracy'], n_neurons)
# the unique neuron id for each element of the flattened trace
neuron_ids = np.tile(np.arange(n_neurons), n_epochs)
# the hidden layer that each element of the flattened trace belongs to
layer_ids = np.tile(np.concatenate([np.repeat(i, int(layer.shape[1]))
                                    for i, layer in enumerate([h1, h2, h3])]),
                    n_epochs)
# the current epoch for each element of the flattened trace
epoch = np.repeat(np.arange(n_epochs), n_neurons)
# the label of each digit we selected: this should be the same as `np.repeat(np.arange(10), 10)`
digit_ids = y_test.argmax(1)[trace_idx]
digit_labels = np.unique(digit_ids)
# normalize the data
trace_data_norm = m_phate.utils.normalize(trace_data)
# mean absolute (normalized) activity of each unit at each epoch over the trace
# examples of each digit label; axis 0 follows `digit_labels`
digit_activity = np.array([np.mean(np.abs(trace_data_norm[:, :, digit_ids == digit]), axis=2)
                           for digit in digit_labels])
# the digit label with the highest mean activity for each element of the
# flattened trace (argmax gives a position in `digit_labels`, mapped back to the label)
most_active_digit = digit_labels[np.argmax(digit_activity, axis=0)].flatten()
# embed the trace tensor with M-PHATE: one output row per (epoch, neuron) pair
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(trace_data)
# draw the same embedding three times, colored by epoch, layer, and the
# digit each unit responds to most
plt.rc('font', size=14)
fig, (ax1, ax2, ax3) = plt.subplots(
    1, 3, figsize=(18, 6), sharex='all', sharey='all')
panels = [
    (ax1, epoch, 'Epoch'),
    (ax2, layer_ids, 'Layer'),
    (ax3, most_active_digit, 'Most active digit'),
]
for panel_ax, panel_colors, panel_title in panels:
    scprep.plot.scatter2d(m_phate_data, c=panel_colors, ax=panel_ax,
                          title=panel_title, ticks=False, label_prefix="M-PHATE")
plt.tight_layout()
Calculating M-PHATE... Calculating multislice kernel... Calculated multislice kernel in 10.64 seconds. Calculating graph and diffusion operator... Calculating landmark operator... Calculating SVD... Calculated SVD in 8.75 seconds. Calculating KMeans... Calculated KMeans in 64.01 seconds. Calculated landmark operator in 77.25 seconds. Calculated graph and diffusion operator in 77.88 seconds. Running PHATE on precomputed affinity matrix with 38400 cells. Calculating optimal t... Automatically selected t = 37 Calculated optimal t in 37.82 seconds. Calculating diffusion potential... Calculated diffusion potential in 20.41 seconds. Calculating metric MDS... Calculated metric MDS in 65.52 seconds. Calculated M-PHATE in 212.33 seconds.
Here we see that the network begins with all hidden units looking approximately identical, in the top right. As training progresses (see the left-hand plot), the hidden units develop into a heterogeneous spread. If we color these units by hidden layer (center), we see that the units do not separate by layer; that is, hidden units in all three layers play similar roles. We can also color each unit by its most active digit: for every digit label, we average the absolute value of the unit's activations over the trace examples with that label, and report the label with the highest average. There is some structure here, in that units close to one another tend to share the same most active digit. If you have looked at the autoencoder notebook, you'll notice that the digit structure is much clearer here — this is because the network is being trained to separate digits by class label.
Let's separate out the layers and see whether there are any differences that we can't see when they are all plotted together.
# one panel per hidden layer, colored by each unit's most active digit;
# only the last panel draws the legend
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for layer_num, ax in enumerate(axes.flatten()):
    in_layer = layer_ids == layer_num
    scprep.plot.scatter2d(m_phate_data[in_layer], c=most_active_digit[in_layer], ax=ax,
                          title="Layer {}".format(layer_num + 1),
                          ticks=False, label_prefix="M-PHATE",
                          legend=(layer_num == 2), legend_anchor=(1, 1),
                          legend_title="Most active digit")
plt.tight_layout()
It looks like layer 1 is substantially less heterogeneous than the following two layers. This makes sense: layer 1 forms simpler features, whereas layers 2 and 3 form more complex, abstract features that can differ substantially from one another. Let's also color this plot by each unit's activity on each digit.
# a 2x5 grid with one panel per digit label, each colored by every unit's
# activity on that digit's trace examples
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for digit_label, ax in enumerate(axes.flatten()):
    activity_for_digit = digit_activity[digit_label].flatten()
    scprep.plot.scatter2d(m_phate_data, c=activity_for_digit, ax=ax, title=digit_label,
                          ticks=False, label_prefix="M-PHATE", legend=False)
plt.tight_layout()
We can see bands of hidden units that seem to respond most strongly to certain digits, which broadly agrees with what we saw in the 'most active digit' plot. Namely, the group of hidden units slightly separated above the others responds primarily to 8, while a band at the bottom of the plot responds to 0.
Finally, we can also visualize the network loss on this plot to see how the evolution corresponds with improvements in performance.
# color the embedding by training loss (left) and validation loss (right),
# both on a square-root color scale
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
loss_panels = ((ax1, loss, 'Training loss'),
               (ax2, val_loss, 'Validation loss'))
for loss_ax, loss_values, loss_title in loss_panels:
    scprep.plot.scatter2d(m_phate_data, c=loss_values, ax=loss_ax, ticks=False,
                          cmap_scale='sqrt', title=loss_title, label_prefix="M-PHATE")
plt.tight_layout()
Just running M-PHATE on one network doesn't really tell us much. Let's build another network with the same architecture, but with dropout applied to the hidden units.
# input layer: one (presumably flattened) preprocessed image per row
inputs = tf.keras.layers.Input(
    shape=(x_train.shape[1],), dtype='float32', name='inputs')
# dropout probability applied to each hidden layer's activations before they
# are fed forward. It defaults to 0.0 (no dropout) when not fed, so an
# evaluation run that forgets to feed it behaves like inference instead of
# failing; feed a nonzero rate for training steps.
dropout_rate = tf.placeholder_with_default(
    tf.constant(0.0, dtype=tf.float32), shape=[], name='dropout_rate')
# three dense hidden layers of 64 units. h1/h2/h3 are pre-activation values;
# leaky ReLU and dropout are applied only on the way into the next layer.
h1 = tf.keras.layers.Dense(64, name='h1')(inputs)
h2 = tf.keras.layers.Dense(64, name='h2')(tf.nn.dropout(tf.nn.leaky_relu(h1), rate=dropout_rate))
h3 = tf.keras.layers.Dense(64, name='h3')(tf.nn.dropout(tf.nn.leaky_relu(h2), rate=dropout_rate))
# output layer: unnormalized class scores (softmax is applied inside the loss)
logits = tf.keras.layers.Dense(10, name='outputs')(tf.nn.dropout(tf.nn.leaky_relu(h3), rate=dropout_rate))
# true labels (one-hot)
labels = tf.keras.layers.Input(
    shape=(y_train.shape[1],), dtype='float32', name='labels')
# cross entropy loss
loss_op = tf.losses.softmax_cross_entropy(labels, logits)
# categorical accuracy of exactly the examples in the current feed (the
# streaming tf.metrics.accuracy update op would return a running accuracy
# over every feed since initialization, mixing training and validation)
correct_predictions = tf.equal(tf.argmax(labels, axis=1), tf.argmax(logits, axis=1))
accuracy_op = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
# build the optimizer
optimizer = tf.train.AdamOptimizer(1e-5)
train_op = optimizer.minimize(loss_op)
# initialize variables
_ = sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
# create a dictionary in which to store trace, loss, accuracy etc
history = {'trace': [], 'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}
# train!
batch_size = 128
epochs = 200
verbose = False
n_train = x_train.shape[0]
# number of full minibatches per epoch; the trailing partial batch is skipped
n_batches = n_train // batch_size
# The loop variable is deliberately not called `epoch`: that name holds the
# per-point epoch array used to color the vanilla network's M-PHATE plot,
# and reusing it here would overwrite that array with an int.
for epoch_num in range(epochs):
    # randomize batch order: a random permutation of all training indices
    batch_idx = np.random.choice(n_train, n_train, replace=False)
    # run minibatch training
    batch_losses = []
    batch_accs = []
    for batch in range(n_batches):
        # get one batch worth of data
        idx = batch_idx[batch * batch_size:(batch + 1) * batch_size]
        x_batch = x_train[idx]
        y_batch = y_train[idx]
        # gradient step, with dropout active
        _, batch_loss, batch_acc = sess.run([train_op, loss_op, accuracy_op],
                                            feed_dict={inputs: x_batch, labels: y_batch,
                                                       dropout_rate: 0.5})
        # store loss and accuracy
        batch_losses.append(batch_loss)
        batch_accs.append(batch_acc)
    # mean loss and accuracy for the epoch
    epoch_loss = np.mean(batch_losses)
    epoch_acc = np.mean(batch_accs)
    history['loss'].append(epoch_loss)
    history['accuracy'].append(epoch_acc)
    # validation loss and accuracy on the full test set, dropout disabled
    epoch_val_loss, epoch_val_acc = sess.run([loss_op, accuracy_op],
                                             feed_dict={inputs: x_test, labels: y_test,
                                                        dropout_rate: 0.0})
    history['val_loss'].append(epoch_val_loss)
    history['val_accuracy'].append(epoch_val_acc)
    if verbose:
        # report 1-based epoch numbers ("Epoch 1/200" ... "Epoch 200/200")
        print("Epoch {}/{}:\tloss={:.3f}\tval_loss={:.3f}\tacc={:.3f}\tval_acc={:.3f}".format(
            epoch_num + 1, epochs, epoch_loss, epoch_val_loss, epoch_acc, epoch_val_acc))
    # store the activations of the trace data, dropout disabled
    layer_traces = sess.run([h1, h2, h3],
                            feed_dict={inputs: x_trace, dropout_rate: 0.0})
    # stack into a single (n_neurons, n_trace_examples) matrix: h1, h2, h3
    epoch_trace = np.concatenate([layer_trace.T for layer_trace in layer_traces])
    history['trace'].append(epoch_trace)
# the actual data tensor for the dropout network: n_epochs x n_neurons x n_examples
dropout_trace_data = np.array(history['trace'])
# normalize the data
dropout_trace_data_norm = m_phate.utils.normalize(dropout_trace_data)
# mean absolute (normalized) activity of each unit at each epoch over the trace
# examples of each digit label; axis 0 follows `dropout_digit_labels`
# (assumes normalize preserves the tensor's shape — it is indexed as 3-D here)
dropout_digit_labels = np.unique(digit_ids)
dropout_digit_activity = np.array([
    np.mean(np.abs(dropout_trace_data_norm[:, :, digit_ids == digit]), axis=2)
    for digit in dropout_digit_labels])
# the digit label with the highest mean activity for each element of the
# flattened trace (argmax position mapped back to the actual label)
dropout_most_active_digit = dropout_digit_labels[
    np.argmax(dropout_digit_activity, axis=0)].flatten()
# calculate M-PHATE
m_phate_dropout_op = m_phate.M_PHATE()
m_phate_dropout_data = m_phate_dropout_op.fit_transform(dropout_trace_data)
Calculating M-PHATE... Calculating multislice kernel... Calculated multislice kernel in 13.84 seconds. Calculating graph and diffusion operator... Calculating landmark operator... Calculating SVD... Calculated SVD in 7.92 seconds. Calculating KMeans... Calculated KMeans in 95.22 seconds. Calculated landmark operator in 107.31 seconds. Calculated graph and diffusion operator in 107.84 seconds. Running PHATE on precomputed affinity matrix with 38400 cells. Calculating optimal t... Automatically selected t = 49 Calculated optimal t in 36.36 seconds. Calculating diffusion potential... Calculated diffusion potential in 17.62 seconds. Calculating metric MDS... Calculated metric MDS in 223.95 seconds. Calculated M-PHATE in 399.65 seconds.
# side-by-side embeddings of the vanilla and dropout networks, each colored by
# its units' most active digit; only the right panel sets a legend title
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 6))
comparison_plot_kwargs = dict(ticks=False, label_prefix="M-PHATE", legend_anchor=(1, 1))
scprep.plot.scatter2d(m_phate_data, c=most_active_digit, ax=ax1,
                      title='Vanilla network', **comparison_plot_kwargs)
scprep.plot.scatter2d(m_phate_dropout_data, c=dropout_most_active_digit, ax=ax2,
                      title='Dropout network', legend_title='Most active\ndigit',
                      **comparison_plot_kwargs)
plt.tight_layout()
Interestingly, the dropout network has spread its units out much more broadly than the vanilla network. We see this trend repeated across other types of regularization: a larger spread across the M-PHATE plot corresponds to better generalization performance. For details, read our preprint on arXiv.
Just for fun, we can plot M-PHATE in 3D. Sometimes you can see structure more clearly this way.
# do it in 3D! Changing n_components on the already-fitted operator and
# calling transform() again reruns only the final embedding step (the log
# below shows just the metric MDS being recomputed). Note this overwrites
# the 2D `m_phate_dropout_data` used by the comparison plot above.
n_dims_3d = 3
m_phate_dropout_op.set_params(n_components=n_dims_3d)
m_phate_dropout_data = m_phate_dropout_op.transform()
Calculating metric MDS... Calculated metric MDS in 385.89 seconds.
# static 3D scatter of the dropout embedding, colored by most active digit
scatter3d_kwargs = dict(legend_title='Most active digit', ticks=False,
                        label_prefix="M-PHATE", figsize=(8, 8),
                        legend_anchor=(1.5, 0.8))
scprep.plot.scatter3d(m_phate_dropout_data, c=dropout_most_active_digit,
                      **scatter3d_kwargs)
<matplotlib.axes._subplots.Axes3DSubplot at 0x7f37420e8a90>
# rotating 3D scatter of the dropout embedding. To save the animation as a
# gif or mp4, also pass e.g. filename='dropout3d.gif', fps=25, rotation_speed=45
scprep.plot.rotate_scatter3d(
    m_phate_dropout_data,
    c=dropout_most_active_digit,
    title='Most active digit',
    figsize=(8, 8),
    ticks=False,
    label_prefix="M-PHATE",
    legend_anchor=(1.05, 0.8),
)