Gregory Way 2017
This script trains a denoising autoencoder on cancer gene expression data using Keras. It adapts the framework of the ADAGE (Analysis using Denoising Autoencoders of Gene Expression) model published by Tan et al. 2015.
An ADAGE model learns a non-linear, reduced-dimensional representation of gene expression data by bottlenecking the raw features into a smaller set. The model is trained by minimizing the reconstruction error between the input and its reconstruction.
The specific model trained in this notebook compresses the gene expression input (the 5,000 most variably expressed genes by median absolute deviation) into a single hidden layer of length 100. The hidden layer is then decoded back to the original 5,000 dimensions. The encoding (compression) layer has a relu activation and the decoding layer has a sigmoid activation. The weights of each layer are glorot uniform initialized. We include an l1 regularization term (see keras.regularizers.l1 for more details) to induce sparsity in the model, as well as a term controlling the probability of input feature dropout. The dropout is active only during training and provides the denoising aspect of the model (see keras.layers.Dropout for more details).
We train the autoencoder with the Adadelta optimizer and a mean squared error (MSE) reconstruction loss.
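Concretely, during training each input vector $x \in [0, 1]^{5000}$ is corrupted by dropout with probability $p$, and the network minimizes the mean squared error between the original input and its reconstruction. As a notational sketch (with $W_e, b_e$ and $W_d, b_d$ denoting encoder and decoder parameters):

$$\mathcal{L}(x) = \frac{1}{5000}\left\lVert x - \mathrm{sigmoid}\left(W_d\,\mathrm{relu}(W_e \tilde{x} + b_e) + b_d\right)\right\rVert_2^2 + \lambda \lVert W_e \tilde{x} + b_e \rVert_1, \qquad \tilde{x} = \mathrm{dropout}_p(x)$$

where $\lambda$ is the l1 sparsity weight applied to the encoding layer's activity (set to 0 by the parameter sweep below).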
The pan-cancer ADAGE model is similar to Tybalt, but does not constrain the learned features to match a Gaussian distribution. Whether the VAE-learned features provide any additional benefit over ADAGE features is an active research question. The VAE is a generative model and therefore permits easy generation of synthetic data. Additionally, we hypothesize that the VAE learns a manifold that can be interpolated to extract meaningful biological knowledge.
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pydot
import graphviz
from keras.utils import plot_model
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.layers import Input, Dense, Dropout, Activation
from keras.layers.noise import GaussianDropout
from keras.models import Model
from keras.regularizers import l1
from keras import optimizers
import keras
Using TensorFlow backend.
print(keras.__version__)
2.0.6
%matplotlib inline
plt.style.use('seaborn-notebook')
sns.set(style="white", color_codes=True)
sns.set_context("paper", rc={"font.size": 14, "axes.titlesize": 15, "axes.labelsize": 20,
                             "xtick.labelsize": 14, "ytick.labelsize": 14})
# Load RNAseq data
rnaseq_file = os.path.join('data', 'pancan_scaled_zeroone_rnaseq.tsv.gz')
rnaseq_df = pd.read_table(rnaseq_file, index_col=0)
print(rnaseq_df.shape)
rnaseq_df.head(2)
(10459, 5000)
| | RPS4Y1 | XIST | KRT5 | AGR2 | CEACAM5 | KRT6A | KRT14 | CEACAM6 | DDX3Y | KDM5D | ... | FAM129A | C8orf48 | CDK5R1 | FAM81A | C13orf18 | GDPD3 | SMAGP | C2orf85 | POU5F1B | CHST2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TCGA-02-0047-01 | 0.678296 | 0.289910 | 0.034230 | 0.0 | 0.0 | 0.084731 | 0.031863 | 0.037709 | 0.746797 | 0.687833 | ... | 0.440610 | 0.428782 | 0.732819 | 0.634340 | 0.580662 | 0.294313 | 0.458134 | 0.478219 | 0.168263 | 0.638497 |
| TCGA-02-0055-01 | 0.200633 | 0.654917 | 0.181993 | 0.0 | 0.0 | 0.100606 | 0.050011 | 0.092586 | 0.103725 | 0.140642 | ... | 0.620658 | 0.363207 | 0.592269 | 0.602755 | 0.610192 | 0.374569 | 0.722420 | 0.271356 | 0.160465 | 0.602560 |

2 rows × 5000 columns
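The sigmoid output layer assumes input values in [0, 1]; as the file name indicates, the expression matrix was zero-one scaled beforehand. A minimal sanity check of that assumption:

# Confirm all expression values fall in the [0, 1] range expected by the sigmoid decoder
assert rnaseq_df.min().min() >= 0.0
assert rnaseq_df.max().max() <= 1.0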
np.random.seed(123)
# Split 10% test set randomly
test_set_percent = 0.1
rnaseq_test_df = rnaseq_df.sample(frac=test_set_percent)
rnaseq_train_df = rnaseq_df.drop(rnaseq_test_df.index)
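A quick sanity check confirms the train and test sets are disjoint and jointly cover every sample:

# Train and test partition the full dataset with no overlap
assert len(rnaseq_train_df) + len(rnaseq_test_df) == len(rnaseq_df)
assert len(rnaseq_train_df.index.intersection(rnaseq_test_df.index)) == 0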
We previously performed a grid search over potential hyperparameter values. Based on this sweep, we determined that the optimal ADAGE parameters are:
| Parameter | Optimal Setting |
|---|---|
| Learning Rate | 1.1 |
| Sparsity | 0 |
| Noise | 0.05 |
| Epochs | 100 |
| Batch Size | 50 |
num_features = rnaseq_df.shape[1]
encoding_dim = 100
sparsity = 0
noise = 0.05
epochs = 100
batch_size = 50
learning_rate = 1.1
# Build the Keras graph
input_rnaseq = Input(shape=(num_features, ))

# Corrupt the input with dropout during training (the denoising step)
encoded_rnaseq = Dropout(noise)(input_rnaseq)

# Compress to 100 dimensions, with an l1 activity penalty to encourage sparsity
encoded_rnaseq_2 = Dense(encoding_dim,
                         activity_regularizer=l1(sparsity))(encoded_rnaseq)
activation = Activation('relu')(encoded_rnaseq_2)

# Decode back to the original 5000 gene dimensions
decoded_rnaseq = Dense(num_features, activation='sigmoid')(activation)

autoencoder = Model(input_rnaseq, decoded_rnaseq)
autoencoder.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 5000)              0
_________________________________________________________________
dropout_1 (Dropout)          (None, 5000)              0
_________________________________________________________________
dense_1 (Dense)              (None, 100)               500100
_________________________________________________________________
activation_1 (Activation)    (None, 100)               0
_________________________________________________________________
dense_2 (Dense)              (None, 5000)              505000
=================================================================
Total params: 1,005,100
Trainable params: 1,005,100
Non-trainable params: 0
_________________________________________________________________
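The parameter counts follow directly from the layer shapes: the encoder Dense layer contributes 5000 × 100 weights plus 100 biases (500,100 parameters), and the decoder contributes 100 × 5000 weights plus 5000 biases (505,000 parameters), for 1,005,100 in total.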
# Visualize the connections of the ADAGE autoencoder model
output_model_file = os.path.join('figures', 'adage_architecture.png')
plot_model(autoencoder, to_file=output_model_file)
SVG(model_to_dot(autoencoder).create(prog='dot', format='svg'))
# Separate out the encoder and decoder model
# Note: encoded_rnaseq_2 is the Dense output *before* the relu activation,
# so the encoder's outputs (and the node activities below) can be negative
encoder = Model(input_rnaseq, encoded_rnaseq_2)

encoded_input = Input(shape=(encoding_dim, ))
decoder_layer = autoencoder.layers[-1]
decoder = Model(encoded_input, decoder_layer(encoded_input))
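Because the encoder output is taken pre-activation, feeding relu-transformed encodings to the standalone decoder should reproduce the full autoencoder's inference path (dropout is inactive at predict time). A minimal sketch to verify the wiring:

# Sanity check: relu(encoder output) -> decoder matches the full model
test_input = np.array(rnaseq_test_df)
full_output = autoencoder.predict(test_input)
manual_output = decoder.predict(np.maximum(encoder.predict(test_input), 0))
assert np.allclose(full_output, manual_output, atol=1e-5)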
# Compile the autoencoder to prepare for training
adadelta = optimizers.Adadelta(lr=learning_rate)
autoencoder.compile(optimizer=adadelta, loss='mse')
%%time
hist = autoencoder.fit(np.array(rnaseq_train_df), np.array(rnaseq_train_df),
shuffle=True,
epochs=epochs,
batch_size=batch_size,
validation_data=(np.array(rnaseq_test_df), np.array(rnaseq_test_df)))
Train on 9413 samples, validate on 1046 samples
Epoch 1/100
9413/9413 [==============================] - 4s - loss: 0.0570 - val_loss: 0.0365
Epoch 2/100
9413/9413 [==============================] - 4s - loss: 0.0358 - val_loss: 0.0359
Epoch 3/100
9413/9413 [==============================] - 5s - loss: 0.0355 - val_loss: 0.0357
...
Epoch 98/100
9413/9413 [==============================] - 5s - loss: 0.0148 - val_loss: 0.0149
Epoch 99/100
9413/9413 [==============================] - 5s - loss: 0.0148 - val_loss: 0.0148
Epoch 100/100
9413/9413 [==============================] - 5s - loss: 0.0147 - val_loss: 0.0148
CPU times: user 17min 56s, sys: 2min 47s, total: 20min 43s
Wall time: 8min 46s
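To reuse the trained model later without the eight-minute retraining, the full model can be serialized to disk (a minimal sketch; the 'models' directory and file name are illustrative):

# Save architecture + weights together; reload with keras.models.load_model
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
autoencoder.save(os.path.join(model_dir, 'adage_autoencoder.h5'))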
# Visualize training performance
history_df = pd.DataFrame(hist.history)
hist_plot_file = os.path.join('figures', 'adage_training.png')
ax = history_df.plot()
ax.set_xlabel('Epochs')
ax.set_ylabel('ADAGE Reconstruction Loss')
fig = ax.get_figure()
fig.savefig(hist_plot_file)
# Encode rnaseq into the hidden/latent representation - and save output
encoded_samples = encoder.predict(np.array(rnaseq_df))
encoded_rnaseq_df = pd.DataFrame(encoded_samples, index=rnaseq_df.index)
encoded_rnaseq_df.columns.name = 'sample_id'
# Shift the column labels so latent features are numbered 1 through 100
encoded_rnaseq_df.columns = encoded_rnaseq_df.columns + 1
encoded_file = os.path.join('data', 'encoded_adage_features.tsv')
encoded_rnaseq_df.to_csv(encoded_file, sep='\t')
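A quick round-trip read confirms the saved file has the expected shape (100 latent features per sample):

# Reload the saved encodings and verify dimensions
check_df = pd.read_table(encoded_file, index_col=0)
assert check_df.shape == (rnaseq_df.shape[0], encoding_dim)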
# Output weight matrix of gene contributions per node
# (get_weights()[0] is the encoder Dense kernel: 5000 genes x 100 nodes)
weight_file = os.path.join('results', 'adage_gene_weights.tsv')
weight_matrix = pd.DataFrame(autoencoder.get_weights()[0], index=rnaseq_df.columns,
                             columns=range(1, 101)).T
weight_matrix.index.name = 'encodings'
weight_matrix.to_csv(weight_file, sep='\t')
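Each row of the weight matrix can be read as a gene signature for one latent feature. As an illustrative sketch, the ten most positively weighted genes for latent feature 1:

# Genes with the largest positive weights for latent feature 1
top_genes_node1 = weight_matrix.loc[1].sort_values(ascending=False)
print(top_genes_node1.head(10))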
# Reconstruct input RNAseq
# (note: encoded_samples are pre-relu, so this path omits the relu applied
# inside the full autoencoder)
decoded_samples = decoder.predict(encoded_samples)
reconstructed_df = pd.DataFrame(decoded_samples, index=rnaseq_df.index,
columns=rnaseq_df.columns)
reconstruction_fidelity = rnaseq_df - reconstructed_df
gene_mean = reconstruction_fidelity.mean(axis=0)
gene_abssum = reconstruction_fidelity.abs().sum(axis=0).divide(rnaseq_df.shape[0])
gene_summary = pd.DataFrame([gene_mean, gene_abssum], index=['gene mean', 'gene abs(sum)']).T
gene_summary.sort_values(by='gene abs(sum)', ascending=False).head()
| | gene mean | gene abs(sum) |
|---|---|---|
| PPAN-P2RY11 | -0.001976 | 0.239137 |
| GSTT1 | 0.023419 | 0.232442 |
| GSTM1 | 0.005316 | 0.219905 |
| EIF1AY | -0.048686 | 0.209455 |
| DDX3Y | -0.034089 | 0.207383 |
# Mean reconstruction difference vs. mean absolute reconstruction difference per gene
reconstruct_fig_file = os.path.join('figures', 'adage_gene_reconstruction.png')
g = sns.jointplot('gene mean', 'gene abs(sum)', data=gene_summary, stat_func=None);
g.savefig(reconstruct_fig_file)
# What are the most and least activated nodes?
sum_node_activity = encoded_rnaseq_df.sum(axis=0).sort_values(ascending=False)

# Top 5 most active nodes
print(sum_node_activity.head(5))

# Bottom 5 least active nodes (negative sums are possible because the
# encodings are taken before the relu activation)
sum_node_activity.tail(5)
sample_id
32    116126.578125
16    111204.562500
80     97674.945312
66     96646.546875
40     95180.679688
dtype: float32
sample_id
25   -11347.204102
24   -12289.624023
87   -12757.212891
41   -13314.549805
49   -16946.775391
dtype: float32
# Histogram of node activity for all 100 latent features
sum_node_activity.hist()
plt.xlabel('Activation Sum')
plt.ylabel('Count');
# Histogram of sample activity for all 10,459 samples
encoded_rnaseq_df.sum(axis=1).hist()
plt.xlabel('Sample Total Activation')
plt.ylabel('Count');
# Example of node activation distribution for the first two latent features
plt.figure(figsize=(6, 6))
plt.scatter(encoded_rnaseq_df.iloc[:, 0], encoded_rnaseq_df.iloc[:, 1])
plt.xlabel('Latent Feature 1')
plt.ylabel('Latent Feature 2');