ADAGE: Pan-cancer gene expression

Gregory Way 2017

This notebook trains a denoising autoencoder on pan-cancer gene expression data using Keras. It adapts the ADAGE (Analysis using Denoising Autoencoders of Gene Expression) framework published by Tan et al. 2015.

An ADAGE model learns a non-linear, reduced-dimensional representation of gene expression data by compressing the raw features through a smaller bottleneck layer. The model is trained by minimizing the reconstruction error between the input and its reconstruction.

The model trained in this notebook takes as input the expression of the 5000 most variably expressed genes (ranked by median absolute deviation), compresses it into a single 100-dimensional hidden layer, and decodes that layer back to the original 5000 dimensions. The encoding (compression) layer uses a relu activation and the decoding layer uses a sigmoid activation. The weights of both layers use Glorot uniform initialization. We include an l1 activity regularization term (see keras.regularizers.l1) to induce sparsity, as well as a dropout rate that sets the probability of zeroing out each input feature. Dropout is active only during training and provides the denoising aspect of the model; see keras.layers.Dropout for details.
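To make the architecture concrete, here is a minimal NumPy sketch of one forward pass, under stated assumptions: inverted dropout (kept inputs are rescaled by 1 / (1 - noise)), zero-initialized biases, and an l1 penalty summed over the batch. Because the regularizer is attached to the Dense layer, the sketch applies it to the values before the relu. The helpers `glorot_uniform` and `adage_forward` are hypothetical illustrations, not Keras APIs, and the sketch performs no training.

```python
import numpy as np

def glorot_uniform(fan_in, fan_out, rng):
    # Glorot (Xavier) uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def adage_forward(x, params, noise, sparsity, rng, training=True):
    """Sketch of one forward pass; returns (hidden, reconstruction, loss)."""
    w1, b1, w2, b2 = params
    if training and noise > 0:
        # Inverted dropout: zero each input with probability `noise`, rescale the rest
        keep = rng.random(x.shape) >= noise
        x_in = x * keep / (1.0 - noise)
    else:
        x_in = x  # dropout is a no-op at inference time
    pre_activation = x_in @ w1 + b1           # Dense encoding layer
    hidden = np.maximum(pre_activation, 0.0)  # relu
    recon = sigmoid(hidden @ w2 + b2)         # Dense decoding layer + sigmoid
    mse = np.mean((x - recon) ** 2)           # compare against the clean input
    # l1 activity penalty on the Dense output (assumed summed over the batch)
    l1_penalty = sparsity * np.sum(np.abs(pre_activation))
    return hidden, recon, mse + l1_penalty

# Toy usage: 8 samples, 20 "genes", 4 latent nodes
rng = np.random.default_rng(0)
n_genes, n_latent = 20, 4
params = (glorot_uniform(n_genes, n_latent, rng), np.zeros(n_latent),
          glorot_uniform(n_latent, n_genes, rng), np.zeros(n_genes))
x = rng.random((8, n_genes))
hidden, recon, loss = adage_forward(x, params, noise=0.05, sparsity=0.0, rng=rng)
print(hidden.shape, recon.shape)  # (8, 4) (8, 20)
```

The reconstruction is compared with the clean input rather than the corrupted one; that contrast is what makes the autoencoder "denoising".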

We train the autoencoder with the Adadelta optimizer and MSE reconstruction loss.
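Adadelta adapts a per-parameter step size from running averages of squared gradients and squared updates. The NumPy sketch below mirrors that update rule on a toy quadratic, assuming rho = 0.95 and eps = 1e-8 (which we believe are the Keras 2.0 defaults) and using the sweep's learning rate of 1.1 as a global multiplier. The function `adadelta_minimize` is a hypothetical helper, not the Keras implementation.

```python
import numpy as np

def adadelta_minimize(grad_fn, p0, lr=1.1, rho=0.95, eps=1e-8, steps=500):
    """Minimize a function with Adadelta updates, given its gradient `grad_fn`."""
    p = np.array(p0, dtype=float)
    acc_grad = np.zeros_like(p)    # running average of squared gradients
    acc_update = np.zeros_like(p)  # running average of squared updates
    for _ in range(steps):
        g = grad_fn(p)
        acc_grad = rho * acc_grad + (1 - rho) * g ** 2
        # Per-parameter step: RMS of past updates divided by RMS of gradients
        update = g * np.sqrt(acc_update + eps) / np.sqrt(acc_grad + eps)
        p = p - lr * update  # `lr` scales every step
        acc_update = rho * acc_update + (1 - rho) * update ** 2
    return p

# Toy usage: minimize f(p) = sum((p - target)^2), whose gradient is 2 * (p - target)
target = np.array([1.0, -2.0])
start = np.array([3.0, 0.0])
result = adadelta_minimize(lambda p: 2 * (p - target), start)
```

Because the running average of squared updates starts at zero, the first updates are very small; the learning rate rescales every step but does not remove this warm-up behavior.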

The pan-cancer ADAGE model is similar to Tybalt, but does not constrain the latent features to follow a Gaussian distribution. Whether the features learned by the VAE offer any benefits over ADAGE features is an open research question. Because the VAE is a generative model, it permits easy generation of synthetic data. We also hypothesize that the VAE learns a manifold that can be interpolated to extract meaningful biological knowledge.

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import pydot
import graphviz
from keras.utils import plot_model
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

from keras.layers import Input, Dense, Dropout, Activation
from keras.layers.noise import GaussianDropout
from keras.models import Model
from keras.regularizers import l1
from keras import optimizers
import keras
Using TensorFlow backend.
In [2]:
print(keras.__version__)
2.0.6
In [3]:
%matplotlib inline
plt.style.use('seaborn-notebook')
In [4]:
sns.set(style="white", color_codes=True)
sns.set_context("paper", rc={"font.size":14,"axes.titlesize":15,"axes.labelsize":20,
                             'xtick.labelsize':14, 'ytick.labelsize':14})
In [5]:
# Load RNAseq data
rnaseq_file = os.path.join('data', 'pancan_scaled_zeroone_rnaseq.tsv.gz')
rnaseq_df = pd.read_table(rnaseq_file, index_col=0)
print(rnaseq_df.shape)
rnaseq_df.head(2)
(10459, 5000)
Out[5]:
RPS4Y1 XIST KRT5 AGR2 CEACAM5 KRT6A KRT14 CEACAM6 DDX3Y KDM5D ... FAM129A C8orf48 CDK5R1 FAM81A C13orf18 GDPD3 SMAGP C2orf85 POU5F1B CHST2
TCGA-02-0047-01 0.678296 0.289910 0.034230 0.0 0.0 0.084731 0.031863 0.037709 0.746797 0.687833 ... 0.440610 0.428782 0.732819 0.634340 0.580662 0.294313 0.458134 0.478219 0.168263 0.638497
TCGA-02-0055-01 0.200633 0.654917 0.181993 0.0 0.0 0.100606 0.050011 0.092586 0.103725 0.140642 ... 0.620658 0.363207 0.592269 0.602755 0.610192 0.374569 0.722420 0.271356 0.160465 0.602560

2 rows × 5000 columns
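The loaded matrix already contains the 5000 selected genes and, judging by the file name `pancan_scaled_zeroone_rnaseq.tsv.gz`, values scaled to [0, 1]; the preprocessing itself is not part of this notebook. The sketch below assumes per-gene median absolute deviation (MAD) ranking followed by per-gene min-max scaling. Both helpers are hypothetical. A [0, 1] range matters here because the sigmoid decoder can only output values in (0, 1).

```python
import numpy as np

def select_top_mad_genes(expr, gene_names, n_top):
    """Keep the n_top genes (columns) with the largest median absolute deviation."""
    median = np.median(expr, axis=0)
    mad = np.median(np.abs(expr - median), axis=0)
    order = np.argsort(mad)[::-1][:n_top]  # highest MAD first
    return expr[:, order], [gene_names[i] for i in order]

def minmax_scale(expr):
    """Rescale each gene (column) to [0, 1]."""
    lo = expr.min(axis=0)
    hi = expr.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)  # avoid dividing by zero for constant genes
    return (expr - lo) / span

# Toy usage: 5 samples x 3 genes; gene "B" varies most, "A" is constant
expr = np.array([[1.0, 0.0, 2.0],
                 [1.0, 5.0, 2.5],
                 [1.0, 9.0, 3.0],
                 [1.0, 2.0, 3.5],
                 [1.0, 7.0, 4.0]])
top_expr, top_genes = select_top_mad_genes(expr, ["A", "B", "C"], n_top=2)
scaled = minmax_scale(top_expr)
print(top_genes)  # ['B', 'C']
```

Min-max scaling maps each gene's minimum to exactly 0, which would be consistent with the 0.0 entries visible in the table above.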

In [6]:
np.random.seed(123)
In [7]:
# Split 10% test set randomly
test_set_percent = 0.1
rnaseq_test_df = rnaseq_df.sample(frac=test_set_percent)
rnaseq_train_df = rnaseq_df.drop(rnaseq_test_df.index)
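The split sizes can be checked by hand: `sample(frac=0.1)` draws round(0.1 × 10459) = 1046 rows, leaving 9413 for training, which matches the "Train on 9413 samples, validate on 1046 samples" line in the training log. Note that `drop(rnaseq_test_df.index)` assumes unique sample IDs. The sketch below is an alternative index-based split using NumPy's `default_rng`; it uses a different random generator, so it does not reproduce the same rows as the cell above.

```python
import numpy as np

n_samples = 10459          # rnaseq_df.shape[0]
test_set_percent = 0.1
# pandas' sample(frac=...) draws round(frac * n) rows
n_test = round(test_set_percent * n_samples)  # round(1045.9) = 1046
n_train = n_samples - n_test                  # 9413

# Index-based split: positions, not labels, so duplicate IDs cannot leak
rng = np.random.default_rng(123)
test_idx = rng.choice(n_samples, size=n_test, replace=False)
train_idx = np.setdiff1d(np.arange(n_samples), test_idx)
print(n_train, n_test)  # 9413 1046
```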

Parameter Sweep Results

We previously performed a parameter sweep over a grid of candidate hyperparameter values. Based on this sweep, we determined that the optimal ADAGE parameters are as follows (an optimal sparsity of 0 means the l1 activity penalty is effectively disabled):

Parameter       Optimal Setting
Learning Rate   1.1
Sparsity        0
Noise           0.05
Epochs          100
Batch Size      50
In [8]:
num_features = rnaseq_df.shape[1]
encoding_dim = 100
sparsity = 0
noise = 0.05
epochs = 100
batch_size = 50
learning_rate = 1.1
In [9]:
# Build the Keras graph
input_rnaseq = Input(shape=(num_features, ))
encoded_rnaseq = Dropout(noise)(input_rnaseq)
encoded_rnaseq_2 = Dense(encoding_dim,
                         activity_regularizer=l1(sparsity))(encoded_rnaseq)
activation = Activation('relu')(encoded_rnaseq_2)
decoded_rnaseq = Dense(num_features, activation='sigmoid')(activation)

autoencoder = Model(input_rnaseq, decoded_rnaseq)
In [10]:
autoencoder.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 5000)              0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 5000)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 100)               500100    
_________________________________________________________________
activation_1 (Activation)    (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 5000)              505000    
=================================================================
Total params: 1,005,100
Trainable params: 1,005,100
Non-trainable params: 0
_________________________________________________________________
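The parameter counts in the summary follow directly from the layer sizes: a Dense layer holds an input-by-output weight matrix plus one bias per output unit, while the Dropout and Activation layers have no weights. A quick check of the arithmetic (the helper `dense_params` is hypothetical):

```python
def dense_params(n_in, n_out):
    # A Dense layer has an (n_in x n_out) weight matrix plus one bias per output unit
    return n_in * n_out + n_out

num_features, encoding_dim = 5000, 100
encoder_params = dense_params(num_features, encoding_dim)  # dense_1
decoder_params = dense_params(encoding_dim, num_features)  # dense_2
print(encoder_params, decoder_params, encoder_params + decoder_params)
# 500100 505000 1005100
```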
In [11]:
# Visualize the layer connections of the ADAGE model
output_model_file = os.path.join('figures', 'adage_architecture.png')
plot_model(autoencoder, to_file=output_model_file)

SVG(model_to_dot(autoencoder).create(prog='dot', format='svg'))
Out[11]:
[Figure: ADAGE model graph, input_1 (InputLayer) -> dropout_1 (Dropout) -> dense_1 (Dense) -> activation_1 (Activation) -> dense_2 (Dense)]
In [12]:
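# Separate out the encoder and decoder models.
# Note: the encoder output is dense_1 before the relu Activation layer,
# so encoded values can be negative.
encoder = Model(input_rnaseq, encoded_rnaseq_2)

# The decoder wraps only the final Dense (sigmoid) layer
encoded_input = Input(shape=(encoding_dim, ))
decoder_layer = autoencoder.layers[-1]
decoder = Model(encoded_input, decoder_layer(encoded_input))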
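As written, `encoder` returns the output of `dense_1` before the relu `Activation` layer, and `decoder` wraps only the final sigmoid `Dense` layer. Chaining them therefore skips the relu, so `decoder.predict(encoder.predict(x))` need not equal `autoencoder.predict(x)`, and encoded values can be negative (as the negative node activity sums later in this notebook show). The toy NumPy example below, with hand-picked hypothetical weights, illustrates the difference.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical toy weights: 2 genes -> 2 latent nodes -> 2 genes
x = np.array([[1.0, 0.0],
              [0.0, 1.0]])
w1, b1 = np.array([[1.0, -1.0], [-1.0, 1.0]]), np.zeros(2)
w2, b2 = np.eye(2), np.zeros(2)

pre_activation = x @ w1 + b1                                  # like `encoder` (dense_1 output)
full_path = sigmoid(np.maximum(pre_activation, 0) @ w2 + b2)  # like `autoencoder` (relu applied)
split_path = sigmoid(pre_activation @ w2 + b2)                # like decoder(encoder(x)), relu skipped

print(np.allclose(full_path, split_path))  # False
```

If post-relu features are wanted, one option is to build the encoder as `Model(input_rnaseq, activation)` so the relu is included; the outputs shown in this notebook were produced with the pre-activation encoder.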
In [13]:
# Compile the autoencoder to prepare for training
adadelta = optimizers.Adadelta(lr=learning_rate)
autoencoder.compile(optimizer=adadelta, loss='mse')
In [14]:
%%time
hist = autoencoder.fit(np.array(rnaseq_train_df), np.array(rnaseq_train_df),
                       shuffle=True,
                       epochs=epochs,
                       batch_size=batch_size,
                       validation_data=(np.array(rnaseq_test_df), np.array(rnaseq_test_df)))
Train on 9413 samples, validate on 1046 samples
Epoch 1/100
9413/9413 [==============================] - 4s - loss: 0.0570 - val_loss: 0.0365
Epoch 2/100
9413/9413 [==============================] - 4s - loss: 0.0358 - val_loss: 0.0359
Epoch 3/100
9413/9413 [==============================] - 5s - loss: 0.0355 - val_loss: 0.0357
Epoch 4/100
9413/9413 [==============================] - 6s - loss: 0.0352 - val_loss: 0.0353
Epoch 5/100
9413/9413 [==============================] - 5s - loss: 0.0348 - val_loss: 0.0348
Epoch 6/100
9413/9413 [==============================] - 4s - loss: 0.0342 - val_loss: 0.0341
Epoch 7/100
9413/9413 [==============================] - 5s - loss: 0.0334 - val_loss: 0.0331
Epoch 8/100
9413/9413 [==============================] - 4s - loss: 0.0323 - val_loss: 0.0321
Epoch 9/100
9413/9413 [==============================] - 4s - loss: 0.0313 - val_loss: 0.0311
Epoch 10/100
9413/9413 [==============================] - 4s - loss: 0.0305 - val_loss: 0.0303
Epoch 11/100
9413/9413 [==============================] - 4s - loss: 0.0298 - val_loss: 0.0296
Epoch 12/100
9413/9413 [==============================] - 4s - loss: 0.0292 - val_loss: 0.0291
Epoch 13/100
9413/9413 [==============================] - 4s - loss: 0.0287 - val_loss: 0.0285
Epoch 14/100
9413/9413 [==============================] - 4s - loss: 0.0281 - val_loss: 0.0280
Epoch 15/100
9413/9413 [==============================] - 4s - loss: 0.0276 - val_loss: 0.0274
Epoch 16/100
9413/9413 [==============================] - 4s - loss: 0.0270 - val_loss: 0.0268
Epoch 17/100
9413/9413 [==============================] - 4s - loss: 0.0264 - val_loss: 0.0262
Epoch 18/100
9413/9413 [==============================] - 4s - loss: 0.0258 - val_loss: 0.0257
Epoch 19/100
9413/9413 [==============================] - 4s - loss: 0.0253 - val_loss: 0.0252
Epoch 20/100
9413/9413 [==============================] - 4s - loss: 0.0248 - val_loss: 0.0247
Epoch 21/100
9413/9413 [==============================] - 4s - loss: 0.0244 - val_loss: 0.0243
Epoch 22/100
9413/9413 [==============================] - 4s - loss: 0.0240 - val_loss: 0.0239
Epoch 23/100
9413/9413 [==============================] - 4s - loss: 0.0237 - val_loss: 0.0236
Epoch 24/100
9413/9413 [==============================] - 4s - loss: 0.0233 - val_loss: 0.0233
Epoch 25/100
9413/9413 [==============================] - 4s - loss: 0.0230 - val_loss: 0.0230
Epoch 26/100
9413/9413 [==============================] - 4s - loss: 0.0227 - val_loss: 0.0227
Epoch 27/100
9413/9413 [==============================] - 4s - loss: 0.0225 - val_loss: 0.0224
Epoch 28/100
9413/9413 [==============================] - 5s - loss: 0.0222 - val_loss: 0.0222
Epoch 29/100
9413/9413 [==============================] - 5s - loss: 0.0219 - val_loss: 0.0219
Epoch 30/100
9413/9413 [==============================] - 5s - loss: 0.0217 - val_loss: 0.0217
Epoch 31/100
9413/9413 [==============================] - 5s - loss: 0.0215 - val_loss: 0.0214
Epoch 32/100
9413/9413 [==============================] - 6s - loss: 0.0212 - val_loss: 0.0212
Epoch 33/100
9413/9413 [==============================] - 6s - loss: 0.0210 - val_loss: 0.0209
Epoch 34/100
9413/9413 [==============================] - 6s - loss: 0.0208 - val_loss: 0.0207
Epoch 35/100
9413/9413 [==============================] - 6s - loss: 0.0205 - val_loss: 0.0205
Epoch 36/100
9413/9413 [==============================] - 6s - loss: 0.0203 - val_loss: 0.0203
Epoch 37/100
9413/9413 [==============================] - 6s - loss: 0.0201 - val_loss: 0.0201
Epoch 38/100
9413/9413 [==============================] - 6s - loss: 0.0199 - val_loss: 0.0199
Epoch 39/100
9413/9413 [==============================] - 6s - loss: 0.0197 - val_loss: 0.0197
Epoch 40/100
9413/9413 [==============================] - 6s - loss: 0.0196 - val_loss: 0.0196
Epoch 41/100
9413/9413 [==============================] - 6s - loss: 0.0194 - val_loss: 0.0194
Epoch 42/100
9413/9413 [==============================] - 6s - loss: 0.0192 - val_loss: 0.0192
Epoch 43/100
9413/9413 [==============================] - 6s - loss: 0.0191 - val_loss: 0.0191
Epoch 44/100
9413/9413 [==============================] - 6s - loss: 0.0189 - val_loss: 0.0189
Epoch 45/100
9413/9413 [==============================] - 6s - loss: 0.0188 - val_loss: 0.0188
Epoch 46/100
9413/9413 [==============================] - 6s - loss: 0.0186 - val_loss: 0.0186
Epoch 47/100
9413/9413 [==============================] - 6s - loss: 0.0185 - val_loss: 0.0185
Epoch 48/100
9413/9413 [==============================] - 6s - loss: 0.0183 - val_loss: 0.0184
Epoch 49/100
9413/9413 [==============================] - 6s - loss: 0.0182 - val_loss: 0.0182
Epoch 50/100
9413/9413 [==============================] - 6s - loss: 0.0181 - val_loss: 0.0181
Epoch 51/100
9413/9413 [==============================] - 6s - loss: 0.0179 - val_loss: 0.0180
Epoch 52/100
9413/9413 [==============================] - 6s - loss: 0.0178 - val_loss: 0.0179
Epoch 53/100
9413/9413 [==============================] - 4s - loss: 0.0177 - val_loss: 0.0177
Epoch 54/100
9413/9413 [==============================] - 5s - loss: 0.0176 - val_loss: 0.0176
Epoch 55/100
9413/9413 [==============================] - 4s - loss: 0.0175 - val_loss: 0.0175
Epoch 56/100
9413/9413 [==============================] - 6s - loss: 0.0174 - val_loss: 0.0174
Epoch 57/100
9413/9413 [==============================] - 6s - loss: 0.0173 - val_loss: 0.0173
Epoch 58/100
9413/9413 [==============================] - 5s - loss: 0.0172 - val_loss: 0.0172
Epoch 59/100
9413/9413 [==============================] - 6s - loss: 0.0171 - val_loss: 0.0171
Epoch 60/100
9413/9413 [==============================] - 5s - loss: 0.0170 - val_loss: 0.0170
Epoch 61/100
9413/9413 [==============================] - 4s - loss: 0.0169 - val_loss: 0.0169
Epoch 62/100
9413/9413 [==============================] - 4s - loss: 0.0168 - val_loss: 0.0169
Epoch 63/100
9413/9413 [==============================] - 4s - loss: 0.0167 - val_loss: 0.0168
Epoch 64/100
9413/9413 [==============================] - 4s - loss: 0.0166 - val_loss: 0.0167
Epoch 65/100
9413/9413 [==============================] - 4s - loss: 0.0166 - val_loss: 0.0166
Epoch 66/100
9413/9413 [==============================] - 4s - loss: 0.0165 - val_loss: 0.0165
Epoch 67/100
9413/9413 [==============================] - 4s - loss: 0.0164 - val_loss: 0.0165
Epoch 68/100
9413/9413 [==============================] - 4s - loss: 0.0164 - val_loss: 0.0164
Epoch 69/100
9413/9413 [==============================] - 4s - loss: 0.0163 - val_loss: 0.0163
Epoch 70/100
9413/9413 [==============================] - 4s - loss: 0.0162 - val_loss: 0.0163
Epoch 71/100
9413/9413 [==============================] - 4s - loss: 0.0161 - val_loss: 0.0162
Epoch 72/100
9413/9413 [==============================] - 4s - loss: 0.0161 - val_loss: 0.0161
Epoch 73/100
9413/9413 [==============================] - 4s - loss: 0.0160 - val_loss: 0.0161
Epoch 74/100
9413/9413 [==============================] - 4s - loss: 0.0160 - val_loss: 0.0160
Epoch 75/100
9413/9413 [==============================] - 5s - loss: 0.0159 - val_loss: 0.0160
Epoch 76/100
9413/9413 [==============================] - 5s - loss: 0.0158 - val_loss: 0.0159
Epoch 77/100
9413/9413 [==============================] - 4s - loss: 0.0158 - val_loss: 0.0158
Epoch 78/100
9413/9413 [==============================] - 4s - loss: 0.0157 - val_loss: 0.0158
Epoch 79/100
9413/9413 [==============================] - 4s - loss: 0.0157 - val_loss: 0.0157
Epoch 80/100
9413/9413 [==============================] - 4s - loss: 0.0156 - val_loss: 0.0157
Epoch 81/100
9413/9413 [==============================] - 4s - loss: 0.0156 - val_loss: 0.0156
Epoch 82/100
9413/9413 [==============================] - 4s - loss: 0.0155 - val_loss: 0.0156
Epoch 83/100
9413/9413 [==============================] - 4s - loss: 0.0155 - val_loss: 0.0155
Epoch 84/100
9413/9413 [==============================] - 4s - loss: 0.0154 - val_loss: 0.0155
Epoch 85/100
9413/9413 [==============================] - 4s - loss: 0.0154 - val_loss: 0.0154
Epoch 86/100
9413/9413 [==============================] - 4s - loss: 0.0153 - val_loss: 0.0154
Epoch 87/100
9413/9413 [==============================] - 4s - loss: 0.0153 - val_loss: 0.0153
Epoch 88/100
9413/9413 [==============================] - 4s - loss: 0.0152 - val_loss: 0.0153
Epoch 89/100
9413/9413 [==============================] - 4s - loss: 0.0152 - val_loss: 0.0153
Epoch 90/100
9413/9413 [==============================] - 4s - loss: 0.0151 - val_loss: 0.0152
Epoch 91/100
9413/9413 [==============================] - 4s - loss: 0.0151 - val_loss: 0.0152
Epoch 92/100
9413/9413 [==============================] - 5s - loss: 0.0151 - val_loss: 0.0151
Epoch 93/100
9413/9413 [==============================] - 5s - loss: 0.0150 - val_loss: 0.0151
Epoch 94/100
9413/9413 [==============================] - 5s - loss: 0.0150 - val_loss: 0.0150
Epoch 95/100
9413/9413 [==============================] - 5s - loss: 0.0149 - val_loss: 0.0150
Epoch 96/100
9413/9413 [==============================] - 5s - loss: 0.0149 - val_loss: 0.0150
Epoch 97/100
9413/9413 [==============================] - 5s - loss: 0.0149 - val_loss: 0.0149
Epoch 98/100
9413/9413 [==============================] - 5s - loss: 0.0148 - val_loss: 0.0149
Epoch 99/100
9413/9413 [==============================] - 5s - loss: 0.0148 - val_loss: 0.0148
Epoch 100/100
9413/9413 [==============================] - 5s - loss: 0.0147 - val_loss: 0.0148
CPU times: user 17min 56s, sys: 2min 47s, total: 20min 43s
Wall time: 8min 46s
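With 9413 training samples and a batch size of 50, each epoch consists of 188 full mini-batches plus a final partial batch of 13 samples. Assuming Keras keeps that final partial batch (its default behavior in `fit`), the run above performs 18,900 Adadelta updates in total:

```python
import math

n_train, batch_size, epochs = 9413, 50, 100
steps_per_epoch = math.ceil(n_train / batch_size)  # 188 full batches + 1 batch of 13
total_updates = steps_per_epoch * epochs
print(steps_per_epoch, total_updates)  # 189 18900
```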
In [15]:
# Visualize training performance
history_df = pd.DataFrame(hist.history)
hist_plot_file = os.path.join('figures', 'adage_training.png')
ax = history_df.plot()
ax.set_xlabel('Epochs')
ax.set_ylabel('ADAGE Reconstruction Loss')
fig = ax.get_figure()
fig.savefig(hist_plot_file)

Save Model Outputs

In [16]:
# Encode rnaseq into the hidden/latent representation - and save output
encoded_samples = encoder.predict(np.array(rnaseq_df))
encoded_rnaseq_df = pd.DataFrame(encoded_samples, index=rnaseq_df.index)

# Name the column axis (columns are latent features, not samples, despite this label)
encoded_rnaseq_df.columns.name = 'sample_id'
encoded_rnaseq_df.columns = encoded_rnaseq_df.columns + 1
encoded_file = os.path.join('data', 'encoded_adage_features.tsv')
encoded_rnaseq_df.to_csv(encoded_file, sep='\t')
In [17]:
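# Output weight matrix of gene contributions per node.
# get_weights()[0] is the dense_1 kernel with shape (num_features, encoding_dim);
# transpose it so that rows are encodings and columns are genes.
weight_file = os.path.join('results', 'adage_gene_weights.tsv')

weight_matrix = pd.DataFrame(autoencoder.get_weights()[0], index=rnaseq_df.columns,
                             columns=range(1, encoding_dim + 1)).T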
weight_matrix.index.name = 'encodings'
weight_matrix.to_csv(weight_file, sep='\t')
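`autoencoder.get_weights()` returns the kernel and bias of each Dense layer in order (Dropout and Activation contribute nothing), so index 0 is the encoder kernel of shape (5000, 100). After transposing, each row of `weight_matrix` holds one node's weight for every gene. One common way to interpret a node is to rank genes by the magnitude of their weight; the sketch below does this on a hypothetical toy kernel, and `top_weighted_genes` is an illustrative helper, not part of the original code.

```python
import numpy as np

# Hypothetical toy kernel: 6 genes x 2 latent nodes (Keras Dense kernels are (n_in, n_out))
genes = ["G1", "G2", "G3", "G4", "G5", "G6"]
kernel = np.array([[0.9, -0.1],
                   [0.1,  0.8],
                   [-0.7, 0.0],
                   [0.2, -0.6],
                   [0.0,  0.3],
                   [0.4,  0.1]])
weights_by_node = kernel.T  # rows = nodes, columns = genes, like weight_matrix above

def top_weighted_genes(weights_by_node, genes, node, n=2):
    """Return the n genes with the largest absolute weight for one node."""
    row = weights_by_node[node]
    order = np.argsort(np.abs(row))[::-1][:n]  # largest absolute weight first
    return [(genes[i], float(row[i])) for i in order]

print(top_weighted_genes(weights_by_node, genes, node=0))  # [('G1', 0.9), ('G3', -0.7)]
```

Ranking by absolute value keeps strongly negative weights, which are as informative about a node as strongly positive ones.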

Observe Reconstruction

In [18]:
# Reconstruct input RNAseq
decoded_samples = decoder.predict(encoded_samples)

reconstructed_df = pd.DataFrame(decoded_samples, index=rnaseq_df.index,
                                columns=rnaseq_df.columns)
In [19]:
reconstruction_fidelity = rnaseq_df - reconstructed_df

# Per-gene signed mean error and mean absolute error ('gene abs(sum)') across samples
gene_mean = reconstruction_fidelity.mean(axis=0)
gene_abssum = reconstruction_fidelity.abs().sum(axis=0).divide(rnaseq_df.shape[0])
gene_summary = pd.DataFrame([gene_mean, gene_abssum], index=['gene mean', 'gene abs(sum)']).T
gene_summary.sort_values(by='gene abs(sum)', ascending=False).head()
Out[19]:
gene mean gene abs(sum)
PPAN-P2RY11 -0.001976 0.239137
GSTT1 0.023419 0.232442
GSTM1 0.005316 0.219905
EIF1AY -0.048686 0.209455
DDX3Y -0.034089 0.207383
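'gene mean' is the signed mean of observed minus reconstructed values, so over- and under-estimates can cancel; 'gene abs(sum)' divides the summed absolute differences by the number of samples, which makes it the per-gene mean absolute error. This is why PPAN-P2RY11 can have a near-zero mean yet the largest absolute error. A toy NumPy check with hypothetical values:

```python
import numpy as np

# Hypothetical observed and reconstructed values: 3 samples x 2 genes
observed = np.array([[0.2, 0.9],
                     [0.4, 0.1],
                     [0.6, 0.5]])
reconstructed = np.array([[0.3, 0.7],
                          [0.4, 0.4],
                          [0.5, 0.5]])
diff = observed - reconstructed                      # same sign convention as reconstruction_fidelity
gene_mean = diff.mean(axis=0)                        # signed mean error per gene
gene_mae = np.abs(diff).sum(axis=0) / diff.shape[0]  # 'gene abs(sum)' = mean absolute error
```

For the first toy gene the errors of -0.1 and +0.1 cancel to a mean of 0, while its mean absolute error is 0.2 / 3.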
In [20]:
# Per-gene mean reconstruction error vs. mean absolute reconstruction error
reconstruct_fig_file = os.path.join('figures', 'adage_gene_reconstruction.png')
g = sns.jointplot('gene mean', 'gene abs(sum)', data=gene_summary, stat_func=None);
g.savefig(reconstruct_fig_file)
In [21]:
# Which nodes are the most and least activated?
sum_node_activity = encoded_rnaseq_df.sum(axis=0).sort_values(ascending=False)

# Top 5 most active nodes
print(sum_node_activity.head(5))

# Bottom 5 least active nodes
sum_node_activity.tail(5)
sample_id
32    116126.578125
16    111204.562500
80     97674.945312
66     96646.546875
40     95180.679688
dtype: float32
Out[21]:
sample_id
25   -11347.204102
24   -12289.624023
87   -12757.212891
41   -13314.549805
49   -16946.775391
dtype: float32
In [22]:
# Histogram of node activity for all 100 latent features
sum_node_activity.hist()
plt.xlabel('Activation Sum')
plt.ylabel('Count');
In [23]:
# Histogram of sample activity for all 10,459 samples
encoded_rnaseq_df.sum(axis=1).hist()
plt.xlabel('Sample Total Activation')
plt.ylabel('Count');
In [24]:
# Example of node activation distribution for the first two latent features
plt.figure(figsize=(6, 6))
plt.scatter(encoded_rnaseq_df.iloc[:, 0], encoded_rnaseq_df.iloc[:, 1])
plt.xlabel('Latent Feature 1')
plt.ylabel('Latent Feature 2');