Gregory Way 2017
This script is an extension of tybalt_vae.ipynb. See that script for more details about the base model. The original Tybalt model compressed 5,000 input genes into 100 latent features in a single layer. Here, I train two alternative Tybalt models with different architectures. Both architectures have two hidden layers:

- Model A: 5,000 genes → 100-dimensional hidden layer → 100 latent features
- Model B: 5,000 genes → 300-dimensional hidden layer → 100 latent features

This notebook trains both models. The optimal hyperparameters were selected through a grid search for each model independently. Much of this script is inspired by the Keras variational_autoencoder.py example.

For both models, the script will output:

- The compressed latent feature matrix
- A visualization of training performance
- Saved encoder and decoder models
- Combined decoder gene weights
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.layers import Input, Dense, Lambda, Layer, Activation
from keras.layers.normalization import BatchNormalization
from keras.models import Model, Sequential
from keras import backend as K
from keras import metrics, optimizers
from keras.callbacks import Callback
import keras
import pydot
import graphviz
from keras.utils import plot_model
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
Using TensorFlow backend.
print(keras.__version__)
tf.__version__
2.0.5
'1.2.1'
%matplotlib inline
plt.style.use('seaborn-notebook')
np.random.seed(123)
The following functions and classes facilitate connections between layers and expose the custom hyperparameters.
# Function for the reparameterization trick to make the model differentiable
def sampling(args):
    # Function with args required for the Keras Lambda function
    z_mean, z_log_var = args

    # Draw epsilon of the same shape from a standard normal distribution
    epsilon = K.random_normal(shape=tf.shape(z_mean), mean=0.,
                              stddev=epsilon_std)

    # The latent vector is non-deterministic and differentiable
    # with respect to z_mean and z_log_var
    z = z_mean + K.exp(z_log_var / 2) * epsilon
    return z
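As an aside, here is a minimal NumPy sketch (illustrative only, not part of the model; the `*_demo` names are hypothetical) of what the reparameterization computes: a draw from N(z_mean, sigma^2) is expressed as a deterministic, differentiable transform of a standard normal draw.

import numpy as np

rng = np.random.RandomState(123)
z_mean_demo, z_log_var_demo = 1.5, np.log(0.25)  # target distribution N(1.5, 0.25)

# z = mu + sigma * epsilon, with epsilon ~ N(0, 1) and sigma = exp(log_var / 2)
epsilon_demo = rng.normal(loc=0.0, scale=1.0, size=100000)
z_demo = z_mean_demo + np.exp(z_log_var_demo / 2) * epsilon_demo

print(z_demo.mean(), z_demo.var())  # approximately 1.5 and 0.25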
class CustomVariationalLayer(Layer):
    """
    Define a custom layer that adds the VAE loss (reconstruction + KL) during training
    """
    def __init__(self, var_layer, mean_layer, **kwargs):
        # https://keras.io/layers/writing-your-own-keras-layers/
        self.is_placeholder = True
        self.var_layer = var_layer
        self.mean_layer = mean_layer
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x_input, x_decoded):
        reconstruction_loss = original_dim * metrics.binary_crossentropy(x_input, x_decoded)
        kl_loss = -0.5 * K.sum(1 + self.var_layer - K.square(self.mean_layer) -
                               K.exp(self.var_layer), axis=-1)
        return K.mean(reconstruction_loss + (K.get_value(beta) * kl_loss))

    def call(self, inputs):
        x = inputs[0]
        x_decoded = inputs[1]
        loss = self.vae_loss(x, x_decoded)
        self.add_loss(loss, inputs=inputs)
        # We won't actually use the output
        return x
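The `kl_loss` above is the closed-form KL divergence between the encoded Gaussian N(mu, sigma^2) and the standard normal prior. As a sanity check of that formula for a single latent dimension, a standalone Monte Carlo sketch (all names here are hypothetical, not part of the notebook):

import numpy as np

rng = np.random.RandomState(123)
mu, log_var = 0.8, np.log(0.5)

# Closed form used in vae_loss: -0.5 * (1 + log_var - mu^2 - exp(log_var))
kl_closed = -0.5 * (1 + log_var - mu ** 2 - np.exp(log_var))

# Monte Carlo estimate of KL(q || p) with q = N(mu, sigma^2) and p = N(0, 1)
sigma = np.exp(log_var / 2)
z = rng.normal(mu, sigma, size=500000)
log_q = -0.5 * (np.log(2 * np.pi) + log_var + (z - mu) ** 2 / sigma ** 2)
log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)
kl_monte_carlo = np.mean(log_q - log_p)

print(kl_closed, kl_monte_carlo)  # the two values should closely agree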
The following KL warm-up callback is modified code from https://github.com/fchollet/keras/issues/2595
class WarmUpCallback(Callback):
    def __init__(self, beta, kappa):
        self.beta = beta
        self.kappa = kappa

    # Behavior on each epoch
    def on_epoch_end(self, epoch, logs={}):
        if K.get_value(self.beta) <= 1:
            K.set_value(self.beta, K.get_value(self.beta) + self.kappa)
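Replaying the callback's update rule outside of Keras makes the warm-up schedule concrete: with kappa = 1.0 (used for Model A below) the KL weight ramps essentially immediately, while with kappa = 0.01 (Model B) it ramps gradually over roughly 100 epochs. A standalone sketch (`warmup_schedule` is a hypothetical helper, not part of the notebook):

def warmup_schedule(kappa, epochs, beta=0.0):
    # Mirrors WarmUpCallback.on_epoch_end: increment beta until it exceeds 1
    history = []
    for _ in range(epochs):
        if beta <= 1:
            beta += kappa
        history.append(beta)
    return history

print(warmup_schedule(kappa=1.0, epochs=4))   # [1.0, 2.0, 2.0, 2.0]
print(warmup_schedule(kappa=0.01, epochs=4))  # approximately [0.01, 0.02, 0.03, 0.04]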
The following class implements a Tybalt model with the given input hyperparameters. Currently, only two-hidden-layer models are supported.
class Tybalt():
    """
    Facilitates the training and output of a Tybalt model trained on TCGA RNAseq gene expression data
    """
    def __init__(self, original_dim, hidden_dim, latent_dim,
                 batch_size, epochs, learning_rate, kappa, beta):
        self.original_dim = original_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.kappa = kappa
        self.beta = beta

    def build_encoder_layer(self):
        # Input placeholder for RNAseq data with a specific input size
        self.rnaseq_input = Input(shape=(self.original_dim, ))

        # The input layer is compressed into a mean and log variance vector of size `latent_dim`.
        # Each layer is initialized with glorot uniform weights, and each step (dense connections,
        # batch normalization, and relu activation) is applied separately.
        # Both vectors of length `latent_dim` are connected to the hidden encoding of the rnaseq input tensor
        hidden_dense_linear = Dense(self.hidden_dim, kernel_initializer='glorot_uniform')(self.rnaseq_input)
        hidden_dense_batchnorm = BatchNormalization()(hidden_dense_linear)
        hidden_encoded = Activation('relu')(hidden_dense_batchnorm)

        z_mean_dense_linear = Dense(self.latent_dim, kernel_initializer='glorot_uniform')(hidden_encoded)
        z_mean_dense_batchnorm = BatchNormalization()(z_mean_dense_linear)
        self.z_mean_encoded = Activation('relu')(z_mean_dense_batchnorm)

        z_log_var_dense_linear = Dense(self.latent_dim, kernel_initializer='glorot_uniform')(hidden_encoded)
        z_log_var_dense_batchnorm = BatchNormalization()(z_log_var_dense_linear)
        self.z_log_var_encoded = Activation('relu')(z_log_var_dense_batchnorm)

        # Return the encoded and randomly sampled z vector
        # Takes two keras layers as input to the custom sampling function layer with a `latent_dim` output
        self.z = Lambda(sampling, output_shape=(self.latent_dim, ))([self.z_mean_encoded, self.z_log_var_encoded])

    def build_decoder_layer(self):
        # The decoding layer is much simpler: glorot uniform initialized layers with a sigmoid output activation
        self.decoder_model = Sequential()
        self.decoder_model.add(Dense(self.hidden_dim, activation='relu', input_dim=self.latent_dim))
        self.decoder_model.add(Dense(self.original_dim, activation='sigmoid'))
        self.rnaseq_reconstruct = self.decoder_model(self.z)

    def compile_vae(self):
        adam = optimizers.Adam(lr=self.learning_rate)
        vae_layer = CustomVariationalLayer(self.z_log_var_encoded,
                                           self.z_mean_encoded)([self.rnaseq_input, self.rnaseq_reconstruct])
        self.vae = Model(self.rnaseq_input, vae_layer)
        self.vae.compile(optimizer=adam, loss=None, loss_weights=[self.beta])

    def get_summary(self):
        self.vae.summary()

    def visualize_architecture(self, output_file):
        # Visualize the connections of the custom VAE model
        plot_model(self.vae, to_file=output_file)
        SVG(model_to_dot(self.vae).create(prog='dot', format='svg'))

    def train_vae(self):
        self.hist = self.vae.fit(np.array(rnaseq_train_df),
                                 shuffle=True,
                                 epochs=self.epochs,
                                 batch_size=self.batch_size,
                                 validation_data=(np.array(rnaseq_test_df), np.array(rnaseq_test_df)),
                                 callbacks=[WarmUpCallback(self.beta, self.kappa)])

    def visualize_training(self, output_file):
        # Visualize training performance
        history_df = pd.DataFrame(self.hist.history)
        ax = history_df.plot()
        ax.set_xlabel('Epochs')
        ax.set_ylabel('VAE Loss')
        fig = ax.get_figure()
        fig.savefig(output_file)

    def compress(self, df):
        # Model to compress input
        self.encoder = Model(self.rnaseq_input, self.z_mean_encoded)

        # Encode rnaseq into the hidden/latent representation and save the output
        encoded_df = self.encoder.predict_on_batch(df)
        encoded_df = pd.DataFrame(encoded_df, columns=range(1, self.latent_dim + 1),
                                  index=rnaseq_df.index)
        return encoded_df

    def get_decoder_weights(self):
        # Build a generator that can sample from the learned distribution
        decoder_input = Input(shape=(self.latent_dim, ))  # can generate from any sampled z vector
        _x_decoded_mean = self.decoder_model(decoder_input)
        self.decoder = Model(decoder_input, _x_decoded_mean)
        weights = []
        for layer in self.decoder.layers:
            weights.append(layer.get_weights())
        return weights

    def predict(self, df):
        return self.decoder.predict(np.array(df))

    def save_models(self, encoder_file, decoder_file):
        self.encoder.save(encoder_file)
        self.decoder.save(decoder_file)
rnaseq_file = os.path.join('data', 'pancan_scaled_zeroone_rnaseq.tsv.gz')
rnaseq_df = pd.read_table(rnaseq_file, index_col=0)
print(rnaseq_df.shape)
rnaseq_df.head(2)
(10459, 5000)
| | RPS4Y1 | XIST | KRT5 | AGR2 | CEACAM5 | KRT6A | KRT14 | CEACAM6 | DDX3Y | KDM5D | ... | FAM129A | C8orf48 | CDK5R1 | FAM81A | C13orf18 | GDPD3 | SMAGP | C2orf85 | POU5F1B | CHST2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TCGA-02-0047-01 | 0.678296 | 0.289910 | 0.034230 | 0.0 | 0.0 | 0.084731 | 0.031863 | 0.037709 | 0.746797 | 0.687833 | ... | 0.440610 | 0.428782 | 0.732819 | 0.634340 | 0.580662 | 0.294313 | 0.458134 | 0.478219 | 0.168263 | 0.638497 |
| TCGA-02-0055-01 | 0.200633 | 0.654917 | 0.181993 | 0.0 | 0.0 | 0.100606 | 0.050011 | 0.092586 | 0.103725 | 0.140642 | ... | 0.620658 | 0.363207 | 0.592269 | 0.602755 | 0.610192 | 0.374569 | 0.722420 | 0.271356 | 0.160465 | 0.602560 |

2 rows × 5000 columns
# Split 10% test set randomly
test_set_percent = 0.1
rnaseq_test_df = rnaseq_df.sample(frac=test_set_percent)
rnaseq_train_df = rnaseq_df.drop(rnaseq_test_df.index)
The hyperparameters provided below were determined through previous independent parameter searches.
# Set common hyperparameters
original_dim = rnaseq_df.shape[1]
latent_dim = 100
beta = K.variable(0)
epsilon_std = 1.0
# Model A (100 hidden layer size)
model_a_hidden_dim = 100
model_a_batch_size = 100
model_a_epochs = 100
model_a_learning_rate = 0.001
model_a_kappa = 1.0

# Model B (300 hidden layer size)
model_b_hidden_dim = 300
model_b_batch_size = 50
model_b_epochs = 100
model_b_learning_rate = 0.0005
model_b_kappa = 0.01
model_a = Tybalt(original_dim=original_dim,
                 hidden_dim=model_a_hidden_dim,
                 latent_dim=latent_dim,
                 batch_size=model_a_batch_size,
                 epochs=model_a_epochs,
                 learning_rate=model_a_learning_rate,
                 kappa=model_a_kappa,
                 beta=beta)
# Compile Model A
model_a.build_encoder_layer()
model_a.build_decoder_layer()
model_a.compile_vae()
model_a.get_summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 5000)          0
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 100)           500100      input_1[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 100)           400         dense_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 100)           0           batch_normalization_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 100)           10100       activation_1[0][0]
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 100)           10100       activation_1[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 100)           400         dense_2[0][0]
____________________________________________________________________________________________________
batch_normalization_3 (BatchNorm (None, 100)           400         dense_3[0][0]
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 100)           0           batch_normalization_2[0][0]
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 100)           0           batch_normalization_3[0][0]
____________________________________________________________________________________________________
lambda_1 (Lambda)                (None, 100)           0           activation_2[0][0]
                                                                   activation_3[0][0]
____________________________________________________________________________________________________
sequential_1 (Sequential)        (None, 5000)          515100      lambda_1[0][0]
____________________________________________________________________________________________________
custom_variational_layer_1 (Cust [(None, 5000), (None, 0           input_1[0][0]
                                                                   sequential_1[1][0]
====================================================================================================
Total params: 1,036,600
Trainable params: 1,036,000
Non-trainable params: 600
____________________________________________________________________________________________________
/home/gway/anaconda3/envs/vae_pancancer/lib/python2.7/site-packages/ipykernel_launcher.py:52: UserWarning: Output "custom_variational_layer_1" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "custom_variational_layer_1" during training.
model_architecture_file = os.path.join('figures', 'twohidden_vae_architecture.png')
model_a.visualize_architecture(model_architecture_file)
%%time
model_a.train_vae()
Train on 9413 samples, validate on 1046 samples
Epoch 1/100   9413/9413 - 1s - loss: 2957.9805 - val_loss: 2950.8238
Epoch 2/100   9413/9413 - 1s - loss: 2797.2792 - val_loss: 2805.8403
Epoch 3/100   9413/9413 - 1s - loss: 2763.9802 - val_loss: 2762.6554
...
Epoch 98/100  9413/9413 - 1s - loss: 2655.2495 - val_loss: 2659.8988
Epoch 99/100  9413/9413 - 1s - loss: 2655.1405 - val_loss: 2661.1199
Epoch 100/100 9413/9413 - 1s - loss: 2655.0629 - val_loss: 2659.8728
CPU times: user 2min 14s, sys: 6.39 s, total: 2min 20s
Wall time: 1min 53s
model_a_training_file = os.path.join('figures', 'twohidden_100hidden_training.pdf')
model_a.visualize_training(model_a_training_file)
model_a_compression = model_a.compress(rnaseq_df)
model_a_file = os.path.join('data', 'encoded_rnaseq_twohidden_100model.tsv.gz')
model_a_compression.to_csv(model_a_file, sep='\t', compression='gzip')
model_a_compression.head(2)
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TCGA-02-0047-01 | 1.804567 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 | 2.082371 | 0.000000 | 0.0 | ... | 5.628737 | 0.882983 | 0.0 | 0.000000 | 1.976136 | 1.912838 | 3.621609 | 0.000000 | 1.947124 | 1.840908 |
| TCGA-02-0055-01 | 0.635178 | 0.0 | 1.591518 | 0.029515 | 1.855888 | 0.0 | 0.0 | 4.964176 | 1.741375 | 0.0 | ... | 1.160538 | 0.000000 | 0.0 | 1.639663 | 0.000000 | 0.000000 | 4.046312 | 0.304179 | 6.382465 | 0.919127 |

2 rows × 100 columns
model_a_weights = model_a.get_decoder_weights()
encoder_model_a_file = os.path.join('models', 'encoder_twohidden100_vae.hdf5')
decoder_model_a_file = os.path.join('models', 'decoder_twohidden100_vae.hdf5')
model_a.save_models(encoder_model_a_file, decoder_model_a_file)
model_b = Tybalt(original_dim=original_dim,
                 hidden_dim=model_b_hidden_dim,
                 latent_dim=latent_dim,
                 batch_size=model_b_batch_size,
                 epochs=model_b_epochs,
                 learning_rate=model_b_learning_rate,
                 kappa=model_b_kappa,
                 beta=beta)
# Compile Model B
model_b.build_encoder_layer()
model_b.build_decoder_layer()
model_b.compile_vae()
model_b.get_summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_3 (InputLayer)             (None, 5000)          0
____________________________________________________________________________________________________
dense_6 (Dense)                  (None, 300)           1500300     input_3[0][0]
____________________________________________________________________________________________________
batch_normalization_4 (BatchNorm (None, 300)           1200        dense_6[0][0]
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 300)           0           batch_normalization_4[0][0]
____________________________________________________________________________________________________
dense_7 (Dense)                  (None, 100)           30100       activation_4[0][0]
____________________________________________________________________________________________________
dense_8 (Dense)                  (None, 100)           30100       activation_4[0][0]
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 100)           400         dense_7[0][0]
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 100)           400         dense_8[0][0]
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 100)           0           batch_normalization_5[0][0]
____________________________________________________________________________________________________
activation_6 (Activation)        (None, 100)           0           batch_normalization_6[0][0]
____________________________________________________________________________________________________
lambda_2 (Lambda)                (None, 100)           0           activation_5[0][0]
                                                                   activation_6[0][0]
____________________________________________________________________________________________________
sequential_2 (Sequential)        (None, 5000)          1535300     lambda_2[0][0]
____________________________________________________________________________________________________
custom_variational_layer_2 (Cust [(None, 5000), (None, 0           input_3[0][0]
                                                                   sequential_2[1][0]
====================================================================================================
Total params: 3,097,800
Trainable params: 3,096,800
Non-trainable params: 1,000
____________________________________________________________________________________________________
/home/gway/anaconda3/envs/vae_pancancer/lib/python2.7/site-packages/ipykernel_launcher.py:52: UserWarning: Output "custom_variational_layer_2" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "custom_variational_layer_2" during training.
%%time
model_b.train_vae()
Train on 9413 samples, validate on 1046 samples
Epoch 1/100   9413/9413 - 2s - loss: 2976.0963 - val_loss: 2909.1827
Epoch 2/100   9413/9413 - 1s - loss: 2859.1518 - val_loss: 2852.9602
Epoch 3/100   9413/9413 - 1s - loss: 2835.7822 - val_loss: 2827.6078
...
Epoch 98/100  9413/9413 - 1s - loss: 2754.5996 - val_loss: 2756.8905
Epoch 99/100  9413/9413 - 1s - loss: 2755.0214 - val_loss: 2755.1875
Epoch 100/100 9413/9413 - 1s - loss: 2754.6306 - val_loss: 2755.2204
CPU times: user 3min 51s, sys: 16.1 s, total: 4min 7s
Wall time: 3min 12s
model_b_training_file = os.path.join('figures', 'twohidden_300hidden_training.pdf')
model_b.visualize_training(model_b_training_file)
model_b_compression = model_b.compress(rnaseq_df)
model_b_file = os.path.join('data', 'encoded_rnaseq_twohidden_300model.tsv.gz')
model_b_compression.to_csv(model_b_file, sep='\t', compression='gzip')
model_b_compression.head(2)
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TCGA-02-0047-01 | 0.000000 | 0.0 | 1.250155 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.95250 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| TCGA-02-0055-01 | 0.556497 | 0.0 | 0.056864 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.13956 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |

2 rows × 100 columns
model_b_weights = model_b.get_decoder_weights()
encoder_model_b_file = os.path.join('models', 'encoder_twohidden300_vae.hdf5')
decoder_model_b_file = os.path.join('models', 'decoder_twohidden300_vae.hdf5')
model_b.save_models(encoder_model_b_file, decoder_model_b_file)
In a two hidden layer model, the decoder learns two sets of weights between the latent features and the reconstructed samples: the first set connects the latent features to the hidden layer, and the second set connects the hidden layer to the reconstructed gene expression values. The two weight matrices can be combined by matrix multiplication, which provides a direct mapping from each latent feature to each gene. It is likely that the two layers learn different biological features, but this combined mapping is currently the easiest to analyze and interpret.
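As a small NumPy sketch of this point (illustrative only; the decoder's ReLU and sigmoid activations and bias terms are ignored, so the product is a linear approximation of the decoder mapping), composing the two weight matrices collapses them into a single latent-feature-by-gene matrix, with sizes mirroring Model A:

import numpy as np

rng = np.random.RandomState(123)
latent_to_hidden = rng.randn(100, 100)  # latent_dim x hidden_dim decoder weights
hidden_to_gene = rng.randn(100, 5000)   # hidden_dim x original_dim decoder weights

# Composing the two linear maps yields one latent_dim x original_dim matrix,
# i.e., a single weight per (latent feature, gene) pair
combined = latent_to_hidden.dot(hidden_to_gene)
print(combined.shape)  # (100, 5000)

# For any latent vector z, applying the two layers in sequence (without
# nonlinearities or biases) equals applying the combined matrix once
z = rng.randn(1, 100)
print(np.allclose(z.dot(latent_to_hidden).dot(hidden_to_gene), z.dot(combined)))  # True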
def extract_weights(weights, weight_file):
    # Multiply the two decoder weight matrices together to obtain a single
    # (latent feature x gene) representation of gene weights
    intermediate_weight_df = pd.DataFrame(weights[1][0])  # latent_dim x hidden_dim
    hidden_weight_df = pd.DataFrame(weights[1][2])        # hidden_dim x original_dim
    abstracted_weight_df = intermediate_weight_df.dot(hidden_weight_df)
    abstracted_weight_df.index = range(1, 101)  # label the 100 latent features
    abstracted_weight_df.columns = rnaseq_df.columns
    abstracted_weight_df.to_csv(weight_file, sep='\t')
# Model A
model_a_weight_file = os.path.join('data', 'tybalt_gene_weights_twohidden100.tsv')
extract_weights(model_a_weights, model_a_weight_file)
# Model B
model_b_weight_file = os.path.join('data', 'tybalt_gene_weights_twohidden300.tsv')
extract_weights(model_b_weights, model_b_weight_file)