import keras
from keras.models import Sequential, Model, load_model
import os
import pickle
import numpy as np
import pandas as pd
import scipy.sparse as sp
import scipy.io as spio
import matplotlib.pyplot as plt
from scrambler.models import *
from scrambler.utils import OneHotEncoder, get_sequence_masks
from scrambler.visualizations import plot_dna_logo, plot_dna_importance_scores
from optimus5_utils import load_optimus5_data, load_optimus5_predictor, animate_optimus5_examples
Using TensorFlow backend.
#Load Optimus-5 data and predictor
encoder = OneHotEncoder(seq_length=50, channel_map={'A' : 0, 'C' : 1, 'G' : 2, 'T' : 3})
train_data_path = 'bottom5KIFuAUGTop5KIFuAUG.csv'
test_data_path = 'randomSampleTestingAllAUGtypes.csv'
x_train, y_train, x_test, y_test = load_optimus5_data(train_data_path, test_data_path)
predictor_path = 'saved_models/optimusRetrainedMain.hdf5'
predictor = load_optimus5_predictor(predictor_path)
x_train.shape = (15008, 1, 50, 4)
x_test.shape = (3200, 1, 50, 4)
y_train.shape = (15008, 1)
y_test.shape = (3200, 1)
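#Optional sanity check (a minimal sketch; assumes the encoder is callable on a sequence string, as it is used on the template below)
toy_onehot = encoder('ACGT' * 12 + 'AC') #50-nt toy sequence
print(toy_onehot.shape) #expected: (50, 4), given the [None, ...] expansion applied to the template below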
#Define sequence template and background
sequence_template = '$' * 50
pseudo_count = 1.0
onehot_template = encoder(sequence_template)[None, ...]
sequence_mask = get_sequence_masks([sequence_template])[0]
x_mean = (np.sum(x_train, axis=(0, 1)) + pseudo_count) / (x_train.shape[0] + 4. * pseudo_count)
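#Quick check (a sketch): with the Laplace pseudo-count, the background should be a proper per-position distribution
assert np.allclose(np.sum(x_mean, axis=-1), 1.0)
print(x_mean.shape) #expected: (50, 4)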
#Visualize background sequence distribution
plot_dna_logo(np.copy(x_mean), sequence_template=sequence_template, figsize=(10, 1), logo_height=1.0, plot_start=0, plot_end=50)
#Calculate mean training set kl-divergence against background
#Clip probabilities away from 0 and 1 so the log below stays finite
x_train_clipped = np.clip(np.copy(x_train[:, 0, :, :]), 1e-8, 1. - 1e-8)
#Per-position KL divergence of each training sequence against the background (log base 2, i.e. bits)
kl_divs = np.sum(x_train_clipped * np.log(x_train_clipped / np.tile(np.expand_dims(x_mean, axis=0), (x_train_clipped.shape[0], 1, 1))), axis=-1) / np.log(2.0)
#Average over unmasked template positions, then over sequences
x_mean_kl_divs = np.sum(kl_divs * sequence_mask, axis=-1) / np.sum(sequence_mask)
x_mean_kl_div = np.mean(x_mean_kl_divs)
print("Mean KL Div against background (bits) = " + str(x_mean_kl_div))
Mean KL Div against background (bits) = 1.9679329305814974
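#Intuition check (a sketch): a fully one-hot position has KL = -log2(p_base) against the background, where p_base is the background probability of the observed base. With a near-uniform background this is close to 2 bits per position, consistent with the ~1.97-bit mean above.
print(-np.log2(x_mean[0, :])) #per-base KL at position 0, one value per possible base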
#Build scrambler
#Scrambler network configuration
network_config = {
    'n_groups' : 5,
    'n_resblocks_per_group' : 4,
    'n_channels' : 32,
    'window_size' : 3,
    'dilation_rates' : [1, 2, 4, 2, 1],
    'drop_rate' : 0.0,
    'norm_mode' : 'instance',
    'mask_smoothing' : False,
    'mask_smoothing_window_size' : 5,
    'mask_smoothing_std' : 1.,
    'mask_drop_scales' : [1, 5],
    'mask_min_drop_rate' : 0.0,
    'mask_max_drop_rate' : 0.5,
    'label_input' : False
}
#Initialize scrambler
scrambler = Scrambler(
    scrambler_mode='inclusion',
    input_size_x=1,
    input_size_y=50,
    n_out_channels=4,
    input_templates=[onehot_template],
    input_backgrounds=[x_mean],
    batch_size=32,
    n_samples=32,
    sample_mode='gumbel',
    zeropad_input=False,
    mask_dropout=False,
    network_config=network_config
)
#Train scrambler
n_epochs = 10
train_history = scrambler.train(
    predictor,
    x_train,
    y_train,
    x_test,
    y_test,
    n_epochs,
    monitor_test_indices=np.arange(32).tolist(),
    monitor_batch_freq_dict={0 : 1, 100 : 5, 469 : 10},
    nll_mode='reconstruction',
    predictor_task='regression',
    entropy_mode='target',
    entropy_bits=0.125,
    entropy_weight=10.
)
Train on 15008 samples, validate on 3200 samples
Epoch 1/10
15008/15008 [==============================] - 82s 5ms/step - loss: 1.5905 - nll_loss: 1.4059 - entropy_loss: 0.1846 - val_loss: 0.8265 - val_nll_loss: 0.7049 - val_entropy_loss: 0.1216
Epoch 2/10
15008/15008 [==============================] - 60s 4ms/step - loss: 0.9412 - nll_loss: 0.8684 - entropy_loss: 0.0728 - val_loss: 0.6593 - val_nll_loss: 0.5736 - val_entropy_loss: 0.0857
Epoch 3/10
15008/15008 [==============================] - 60s 4ms/step - loss: 0.8443 - nll_loss: 0.7620 - entropy_loss: 0.0823 - val_loss: 0.6611 - val_nll_loss: 0.5764 - val_entropy_loss: 0.0847
Epoch 4/10
15008/15008 [==============================] - 60s 4ms/step - loss: 0.7730 - nll_loss: 0.6873 - entropy_loss: 0.0858 - val_loss: 0.6158 - val_nll_loss: 0.4936 - val_entropy_loss: 0.1222
Epoch 5/10
15008/15008 [==============================] - 60s 4ms/step - loss: 0.7396 - nll_loss: 0.6528 - entropy_loss: 0.0868 - val_loss: 0.6604 - val_nll_loss: 0.5796 - val_entropy_loss: 0.0808
Epoch 6/10
15008/15008 [==============================] - 61s 4ms/step - loss: 0.6970 - nll_loss: 0.6152 - entropy_loss: 0.0819 - val_loss: 0.6569 - val_nll_loss: 0.5934 - val_entropy_loss: 0.0634
Epoch 7/10
15008/15008 [==============================] - 60s 4ms/step - loss: 0.6697 - nll_loss: 0.5914 - entropy_loss: 0.0783 - val_loss: 0.6612 - val_nll_loss: 0.5859 - val_entropy_loss: 0.0753
Epoch 8/10
15008/15008 [==============================] - 60s 4ms/step - loss: 0.6474 - nll_loss: 0.5716 - entropy_loss: 0.0758 - val_loss: 0.6661 - val_nll_loss: 0.6020 - val_entropy_loss: 0.0641
Epoch 9/10
15008/15008 [==============================] - 61s 4ms/step - loss: 0.6304 - nll_loss: 0.5539 - entropy_loss: 0.0766 - val_loss: 0.6627 - val_nll_loss: 0.5984 - val_entropy_loss: 0.0644
Epoch 10/10
15008/15008 [==============================] - 59s 4ms/step - loss: 0.6097 - nll_loss: 0.5368 - entropy_loss: 0.0728 - val_loss: 0.6740 - val_nll_loss: 0.6147 - val_entropy_loss: 0.0593
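#Sanity check on the log above (inferred from the printed values, not from library documentation): the total loss is the sum of the NLL term and the weighted entropy term
print(1.4059 + 0.1846) #1.5905, matching epoch 1's reported loss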
#Save scrambler checkpoint
save_dir = 'saved_models'
model_name = 'optimus5_inclusion_scrambler_bits_0125_epochs_10'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name + '.h5')
scrambler.save_model(model_path)
with open(os.path.join(save_dir, model_name + '_train_history.pickle'), 'wb') as f:
    pickle.dump({'train_history' : train_history}, f)
Saved scrambler model at saved_models/optimus5_inclusion_scrambler_bits_0125_epochs_10.h5
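#A minimal sketch of reloading the pickled training history later (mirrors the dump above)
with open(os.path.join(save_dir, model_name + '_train_history.pickle'), 'rb') as f:
    train_history = pickle.load(f)['train_history']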
#Load models
save_dir = 'saved_models'
model_name = 'optimus5_inclusion_scrambler_bits_0125_epochs_10'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name + '.h5')
scrambler.load_model(model_path)
Loaded scrambler model from saved_models/optimus5_inclusion_scrambler_bits_0125_epochs_10.h5
#Plot training statistics
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(2 * 4, 3))
n_epochs_actual = len(train_history['nll_loss'])
ax1.plot(np.arange(1, n_epochs_actual + 1), train_history['nll_loss'], linewidth=3, color='green')
ax1.plot(np.arange(1, n_epochs_actual + 1), train_history['val_nll_loss'], linewidth=3, color='orange')
plt.sca(ax1)
plt.xlabel("Epochs", fontsize=14)
plt.ylabel("NLL", fontsize=14)
plt.xlim(1, n_epochs_actual)
plt.xticks([1, n_epochs_actual], [1, n_epochs_actual], fontsize=12)
plt.yticks(fontsize=12)
ax2.plot(np.arange(1, n_epochs_actual + 1), train_history['entropy_loss'], linewidth=3, color='green')
ax2.plot(np.arange(1, n_epochs_actual + 1), train_history['val_entropy_loss'], linewidth=3, color='orange')
plt.sca(ax2)
plt.xlabel("Epochs", fontsize=14)
plt.ylabel("Entropy Loss", fontsize=14)
plt.xlim(1, n_epochs_actual)
plt.xticks([1, n_epochs_actual], [1, n_epochs_actual], fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()
plt.show()
#Interpret the test set using the trained scrambler
pwm_test, sample_test, importance_scores_test = scrambler.interpret(x_test)
3200/3200 [==============================] - 2s 491us/step
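#Shape check (a sketch; shapes are inferred from how the outputs are indexed below, not from library documentation): pwm_test holds one scrambled PWM per input, sample_test holds n_samples one-hot samples drawn from each PWM, and importance_scores_test follows the (N, 1, 50, channels) input layout
print(pwm_test.shape, sample_test.shape, importance_scores_test.shape)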
#Visualize a few reconstructed sequence patterns
plot_examples = np.arange(5).tolist()
save_examples = []
importance_scores_test *= sequence_mask[None, None, :, None]
for test_ix in plot_examples :
    print("Test sequence " + str(test_ix) + ":")
    y_test_hat_ref = predictor.predict(x=[x_test[test_ix:test_ix+1, ...]], batch_size=1)[0, 0]
    y_test_hat = predictor.predict(x=[sample_test[test_ix, ...]], batch_size=32)[:32, 0].tolist()
    print(" - Prediction (original) = " + str(round(y_test_hat_ref, 2))[:4])
    print(" - Predictions (scrambled) = " + str([float(str(round(y_test_hat[i], 2))[:4]) for i in range(len(y_test_hat))]))
    save_figs = False
    if save_examples is not None and test_ix in save_examples :
        save_figs = True
    plot_dna_logo(x_test[test_ix, 0, :, :], sequence_template=sequence_template, figsize=(10, 1), plot_start=0, plot_end=50, plot_sequence_template=True, save_figs=save_figs, fig_name=model_name + "_test_ix_" + str(test_ix) + "_orig_sequence")
    plot_dna_logo(pwm_test[test_ix, 0, :, :], sequence_template=sequence_template, figsize=(10, 1), plot_start=0, plot_end=50, plot_sequence_template=True, save_figs=save_figs, fig_name=model_name + "_test_ix_" + str(test_ix) + "_scrambled_pwm")
    plot_dna_importance_scores(importance_scores_test[test_ix, 0, :, :].T, encoder.decode(x_test[test_ix, 0, :, :]), figsize=(10, 1), score_clip=None, sequence_template=sequence_template, plot_start=0, plot_end=50, save_figs=save_figs, fig_name=model_name + "_test_ix_" + str(test_ix) + "_scores")
Test sequence 0:
 - Prediction (original) = -1.0
 - Predictions (scrambled) = [-0.7, -0.9, -1.0, -1.6, -1.2, -1.3, -0.9, -1.5, -0.5, -1.4, -1.0, -1.2, -0.3, -0.7, -1.1, -1.2, -1.3, -0.7, -0.9, -1.1, -0.7, -1.1, -1.9, -1.0, -1.3, -1.1, -1.1, -0.5, -0.9, -0.7, -0.8, -1.6]
Test sequence 1:
 - Prediction (original) = -1.0
 - Predictions (scrambled) = [-1.1, -1.2, -1.3, -1.4, -0.5, -1.1, -0.9, -1.3, -1.0, -1.2, -1.1, -0.1, -1.0, -0.3, -1.0, -1.1, -1.2, -0.6, -1.1, -1.0, -0.4, -0.9, -0.1, -1.1, -1.3, -1.2, -1.1, -0.7, -1.0, -1.0, -1.3, -0.5]
Test sequence 2:
 - Prediction (original) = -0.9
 - Predictions (scrambled) = [-0.6, -0.0, -0.3, 0.08, -0.5, -0.7, -1.0, -1.5, -0.3, -0.8, -1.1, -0.9, -0.4, -0.9, -0.8, -0.3, -0.9, -0.4, -1.1, -0.7, -0.4, -1.1, -0.3, -0.3, -0.7, -0.4, -0.7, 0.36, -0.6, -0.8, -1.3, -0.0]
Test sequence 3:
 - Prediction (original) = -1.1
 - Predictions (scrambled) = [-0.2, 0.0, -0.0, -0.5, -0.4, -0.3, -0.8, 0.36, -0.2, 0.25, -0.6, -0.3, -0.7, -0.9, 0.63, -0.1, -0.2, 0.68, -0.6, -0.5, -1.2, 0.81, -0.1, 0.69, -0.8, 0.93, -1.0, -0.5, -1.0, -0.6, -0.2, -0.4]
Test sequence 4:
 - Prediction (original) = -1.2
 - Predictions (scrambled) = [-0.1, -0.5, -0.5, -0.5, -0.7, -0.8, -0.5, -0.8, -1.7, -0.9, -1.0, -1.0, -1.1, -1.4, -1.3, -0.8, -1.8, -1.4, -0.8, -1.1, -1.4, -0.5, -1.5, -1.0, -0.3, 0.14, -0.8, -1.7, -0.3, -1.2, -1.1, -0.6]