import keras
from keras.models import Sequential, Model, load_model
import os
import pickle
import numpy as np
import pandas as pd
import scipy.sparse as sp
import scipy.io as spio
import matplotlib.pyplot as plt
from scrambler.models import *
from scrambler.utils import OneHotEncoder, get_sequence_masks
from scrambler.visualizations import plot_dna_logo, plot_dna_importance_scores
from apa_utils import load_apa_data, load_apa_predictor_cleavage_logodds, animate_apa_examples
Using TensorFlow backend.
# Load the APA data set and the cleavage log-odds predictor.
# Sequences are one-hot encoded over 205 positions; channels follow the order
# A, C, G, T (A -> 0, C -> 1, G -> 2, T -> 3).
encoder = OneHotEncoder(
    seq_length=205,
    channel_map={nucleotide: channel for channel, nucleotide in enumerate('ACGT')},
)

data_path = 'apa_doubledope_cached_set.csv'
x_train, y_train, x_test, y_test = load_apa_data(data_path, encoder)

predictor_path = 'saved_models/aparent_plasmid_iso_cut_distalpas_all_libs_no_sampleweights_sgd.h5'
predictor = load_apa_predictor_cleavage_logodds(predictor_path)
len(data_df) = 34748 (loaded) x_train.shape = (32992, 1, 205, 4) x_test.shape = (1728, 1, 205, 4) y_train.shape = (32992, 1) y_test.shape = (1728, 1)
# Show the predictor's input and output tensors, so the extra inputs
# constructed later can be checked against them.
print(f"predictor.inputs = {predictor.inputs}")
print(f"predictor.outputs = {predictor.outputs}")
predictor.inputs = [<tf.Tensor 'seq_input:0' shape=(?, 205, 4) dtype=float32>, <tf.Tensor 'lib_input:0' shape=(?, 13) dtype=float32>, <tf.Tensor 'distal_pas_input:0' shape=(?, 1) dtype=float32>] predictor.outputs = [<tf.Tensor 'rel_cut_logodds/Log:0' shape=(?, 4) dtype=float32>]
# Define the sequence template and the background distribution.
# NOTE(review): '$' in the template presumably marks the variable positions and
# the letters mark fixed sequence; confirm against get_sequence_masks.
sequence_template = 'CTTCCGATCT$$$$$$$$$$$$$$$$$$$$CATTACTCGCATCCA$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$CAGCCAATTAAGCC$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$CTAC'
pseudo_count = 1.0

# One-hot template with a leading batch axis added.
onehot_template = encoder(sequence_template)[None, ...]
sequence_mask = get_sequence_masks([sequence_template])[0]

# Per-position nucleotide frequencies over the training set. At this point
# x_train has a singleton axis 1, so summing over axes 0 and 1 leaves
# (positions, 4). The pseudo-count is added to each of the 4 nucleotides,
# which keeps every background probability strictly positive.
nucleotide_counts = x_train.sum(axis=(0, 1))
n_train_sequences = x_train.shape[0]
x_mean = (nucleotide_counts + pseudo_count) / (n_train_sequences + 4. * pseudo_count)
# Draw the background (per-position nucleotide frequency) distribution as a
# sequence logo over positions 0-205. A copy is passed so the plot cannot
# modify x_mean.
plot_dna_logo(
    np.copy(x_mean),
    sequence_template=sequence_template,
    figsize=(14, 0.65),
    logo_height=1.0,
    plot_start=0,
    plot_end=205,
)
# Mean KL divergence (in bits) of the training sequences against the
# background distribution x_mean.
#
# For each sequence the per-position KL is averaged over the positions
# weighted by sequence_mask. The result is then averaged over all sequences.
clip_eps = 1e-8
# Drop the singleton axis and clip away exact 0/1 so the log stays finite.
# np.clip already returns a new array, so no explicit copy is needed.
x_train_clipped = np.clip(x_train[:, 0, :, :], clip_eps, 1. - clip_eps)

# x_mean (positions, 4) broadcasts across the leading sequence axis, so there
# is no need to materialize one tiled copy per training sequence.
log_ratio = np.log(x_train_clipped / x_mean[None, :, :])
kl_divs = np.sum(x_train_clipped * log_ratio, axis=-1) / np.log(2.0)  # nats -> bits

x_mean_kl_divs = np.sum(kl_divs * sequence_mask, axis=-1) / np.sum(sequence_mask)
x_mean_kl_div = np.mean(x_mean_kl_divs)
print("Mean KL Div against background (bits) = " + str(x_mean_kl_div))
Mean KL Div against background (bits) = 1.8729476983107292
# For the sake of the example, remove the singleton axis 1. The sequences
# become (n_sequences, positions, 4) instead of (n_sequences, 1, positions, 4).
x_train, x_test = x_train[:, 0], x_test[:, 0]

print(x_train.shape)
print(x_test.shape)
(32992, 205, 4) (1728, 205, 4)
# Create the non-sequence inputs the predictor expects. Its inputs are
# 'seq_input', a 13-wide 'lib_input' and a 1-wide 'distal_pas_input'.
# Every example gets column 4 of the library one-hot set to 1 and an all-ones
# distal-PAS input.
def _library_onehot(n_rows, n_libs=13, lib_index=4):
    """Return an (n_rows, n_libs) float array that is zero except column `lib_index` (= 1)."""
    onehot = np.zeros((n_rows, n_libs))
    onehot[:, lib_index] = 1.
    return onehot


feat_1_train = _library_onehot(x_train.shape[0])
feat_1_test = _library_onehot(x_test.shape[0])

feat_2_train = np.ones((x_train.shape[0], 1))
feat_2_test = np.ones((x_test.shape[0], 1))

print(feat_1_train.shape)
print(feat_2_train.shape)
(32992, 13) (32992, 1)
# Scrambler network configuration, passed unchanged to Scrambler(...).
# Meanings follow the key names. The mask_* entries presumably control the
# smoothing and dropout of the learned importance mask; confirm in
# scrambler.models.
network_config = dict(
    n_groups=5,
    n_resblocks_per_group=4,
    n_channels=32,
    window_size=3,
    dilation_rates=[1, 2, 4, 2, 1],
    drop_rate=0.0,
    norm_mode='instance',
    mask_smoothing=True,
    mask_smoothing_window_size=5,
    mask_smoothing_std=1.,
    mask_drop_scales=[1, 5],
    mask_min_drop_rate=0.0,
    mask_max_drop_rate=0.5,
    label_input=False,
)
# Train one scrambler per cleavage position (0-3). Each one maximizes the
# predictor's cleavage log-odds at its position, and each checkpoint is saved
# to save_dir.
#
# The custom loss below needs the Keras backend. Import it explicitly instead
# of relying on `K` leaking through `from scrambler.models import *`.
from keras import backend as K

save_dir = 'saved_models'
# Create the checkpoint directory once, up front. exist_ok replaces the
# per-iteration isdir()/makedirs() check-then-act.
os.makedirs(save_dir, exist_ok=True)

n_epochs = 5

for cut_pos in [0, 1, 2, 3] :
    print("Training scrambler for cleavage position = " + str(cut_pos) + ".")

    # Initialize scrambler
    scrambler = Scrambler(
        scrambler_mode='inclusion',
        input_size_x=None,
        input_size_y=205,
        n_out_channels=4,
        n_classes=4,
        input_templates=[onehot_template],
        input_backgrounds=[x_mean],
        batch_size=32,
        n_samples=32,
        sample_mode='gumbel',
        zeropad_input=False,
        mask_dropout=False,
        network_config=network_config
    )

    # y_pred_scrambled.shape = (batch_size, n_samples, n_classes)
    # cut_pos is bound as a default argument, so the loss keeps this
    # iteration's value rather than looking up the loop variable later.
    def maximize_cleavage_logodds(y_pred_non_scrambled, y_pred_scrambled, cut_pos=cut_pos) :
        # Negative mean, over the scrambled samples, of the predicted log-odds
        # at cut_pos. Minimizing this loss maximizes that log-odds.
        return -K.mean(y_pred_scrambled[..., cut_pos], axis=-1)

    _ = scrambler.train(
        predictor,
        x_train,
        y_train,
        x_test,
        y_test,
        n_epochs,
        extra_input_train=[feat_1_train, feat_2_train],
        extra_input_test=[feat_1_test, feat_2_test],
        monitor_test_indices=None,
        custom_loss_func=maximize_cleavage_logodds,
        entropy_mode='target',
        entropy_bits=0.1,
        entropy_weight=20.
    )

    # Save scrambler checkpoint.
    # NOTE(review): the name says "epochs_10" although n_epochs is 5. It is
    # kept byte-identical because the interpretation step loads checkpoints by
    # this exact name.
    model_name = 'apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_' + str(cut_pos)
    model_path = os.path.join(save_dir, model_name + '.h5')
    scrambler.save_model(model_path)
Training scrambler for cleavage position = 0. Train on 32992 samples, validate on 1728 samples Epoch 1/5 32992/32992 [==============================] - 180s 5ms/step - loss: 3.4686 - nll_loss: 3.3073 - entropy_loss: 0.1613 - val_loss: 3.2598 - val_nll_loss: 3.1952 - val_entropy_loss: 0.0646 Epoch 2/5 32992/32992 [==============================] - 163s 5ms/step - loss: 3.2132 - nll_loss: 3.1495 - entropy_loss: 0.0636 - val_loss: 3.1753 - val_nll_loss: 3.0945 - val_entropy_loss: 0.0808 Epoch 3/5 32992/32992 [==============================] - 163s 5ms/step - loss: 3.1444 - nll_loss: 3.0763 - entropy_loss: 0.0681 - val_loss: 3.1162 - val_nll_loss: 3.0327 - val_entropy_loss: 0.0835 Epoch 4/5 32992/32992 [==============================] - 172s 5ms/step - loss: 3.0962 - nll_loss: 3.0251 - entropy_loss: 0.0710 - val_loss: 3.0844 - val_nll_loss: 3.0077 - val_entropy_loss: 0.0767 Epoch 5/5 32992/32992 [==============================] - 162s 5ms/step - loss: 3.0751 - nll_loss: 3.0051 - entropy_loss: 0.0700 - val_loss: 3.0701 - val_nll_loss: 2.9821 - val_entropy_loss: 0.0880 Saved scrambler model at saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_0.h5 Training scrambler for cleavage position = 1. 
Train on 32992 samples, validate on 1728 samples Epoch 1/5 32992/32992 [==============================] - 186s 6ms/step - loss: 1.4403 - nll_loss: 1.4013 - entropy_loss: 0.0390 - val_loss: 1.4186 - val_nll_loss: 1.3926 - val_entropy_loss: 0.0261 Epoch 2/5 32992/32992 [==============================] - 164s 5ms/step - loss: 1.4065 - nll_loss: 1.3860 - entropy_loss: 0.0204 - val_loss: 1.3974 - val_nll_loss: 1.3818 - val_entropy_loss: 0.0156 Epoch 3/5 32992/32992 [==============================] - 164s 5ms/step - loss: 1.3897 - nll_loss: 1.3760 - entropy_loss: 0.0137 - val_loss: 1.3815 - val_nll_loss: 1.3690 - val_entropy_loss: 0.0125 Epoch 4/5 32992/32992 [==============================] - 164s 5ms/step - loss: 1.3767 - nll_loss: 1.3612 - entropy_loss: 0.0155 - val_loss: 1.3690 - val_nll_loss: 1.3515 - val_entropy_loss: 0.0175 Epoch 5/5 32992/32992 [==============================] - 164s 5ms/step - loss: 1.3444 - nll_loss: 1.3233 - entropy_loss: 0.0211 - val_loss: 1.3080 - val_nll_loss: 1.2800 - val_entropy_loss: 0.0280 Saved scrambler model at saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_1.h5 Training scrambler for cleavage position = 2. 
Train on 32992 samples, validate on 1728 samples Epoch 1/5 32992/32992 [==============================] - 196s 6ms/step - loss: 3.6210 - nll_loss: 3.4691 - entropy_loss: 0.1519 - val_loss: 3.3546 - val_nll_loss: 3.2851 - val_entropy_loss: 0.0696 Epoch 2/5 32992/32992 [==============================] - 166s 5ms/step - loss: 3.2659 - nll_loss: 3.2138 - entropy_loss: 0.0521 - val_loss: 3.1926 - val_nll_loss: 3.1233 - val_entropy_loss: 0.0692 Epoch 3/5 32992/32992 [==============================] - 178s 5ms/step - loss: 3.1582 - nll_loss: 3.0994 - entropy_loss: 0.0588 - val_loss: 3.1292 - val_nll_loss: 3.0730 - val_entropy_loss: 0.0562 Epoch 4/5 32992/32992 [==============================] - 165s 5ms/step - loss: 3.1201 - nll_loss: 3.0595 - entropy_loss: 0.0607 - val_loss: 3.1054 - val_nll_loss: 3.0619 - val_entropy_loss: 0.0436 Epoch 5/5 32992/32992 [==============================] - 165s 5ms/step - loss: 3.1010 - nll_loss: 3.0396 - entropy_loss: 0.0614 - val_loss: 3.0931 - val_nll_loss: 3.0397 - val_entropy_loss: 0.0533 Saved scrambler model at saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_2.h5 Training scrambler for cleavage position = 3. 
Train on 32992 samples, validate on 1728 samples Epoch 1/5 32992/32992 [==============================] - 199s 6ms/step - loss: 4.2247 - nll_loss: 4.1326 - entropy_loss: 0.0921 - val_loss: 4.0715 - val_nll_loss: 3.9446 - val_entropy_loss: 0.1269 Epoch 2/5 32992/32992 [==============================] - 167s 5ms/step - loss: 3.9482 - nll_loss: 3.8660 - entropy_loss: 0.0822 - val_loss: 3.8639 - val_nll_loss: 3.7782 - val_entropy_loss: 0.0856 Epoch 3/5 32992/32992 [==============================] - 177s 5ms/step - loss: 3.8204 - nll_loss: 3.7241 - entropy_loss: 0.0964 - val_loss: 3.7941 - val_nll_loss: 3.6414 - val_entropy_loss: 0.1527 Epoch 4/5 32992/32992 [==============================] - 166s 5ms/step - loss: 3.7544 - nll_loss: 3.6487 - entropy_loss: 0.1058 - val_loss: 3.7285 - val_nll_loss: 3.6036 - val_entropy_loss: 0.1249 Epoch 5/5 32992/32992 [==============================] - 166s 5ms/step - loss: 3.7147 - nll_loss: 3.6035 - entropy_loss: 0.1112 - val_loss: 3.7091 - val_nll_loss: 3.6212 - val_entropy_loss: 0.0879 Saved scrambler model at saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_3.h5
# Load the checkpoint for each cleavage position and interpret the test set
# with it. The per-position results are stacked along a new leading axis,
# indexed in cut_pos order.
#
# NOTE(review): this reuses the `scrambler` object created by the training
# loop, so that cell must have run first in the same session.
save_dir = 'saved_models'

per_cut_pwms = []
per_cut_samples = []
per_cut_scores = []
for cut_pos in [0, 1, 2, 3] :
    print("Interpreting for cleavage position = " + str(cut_pos) + ".")

    model_name = 'apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_' + str(cut_pos)
    model_path = os.path.join(save_dir, model_name + '.h5')
    scrambler.load_model(model_path)

    # Interpret the test set using the trained scrambler
    pwm_t, sample_t, importance_scores_t = scrambler.interpret(x_test)
    per_cut_pwms.append(pwm_t)
    per_cut_samples.append(sample_t)
    per_cut_scores.append(importance_scores_t)

# np.stack(axis=0) is equivalent to concatenating the arrays after adding a
# leading axis to each.
pwm_test = np.stack(per_cut_pwms, axis=0)
sample_test = np.stack(per_cut_samples, axis=0)
importance_scores_test = np.stack(per_cut_scores, axis=0)
Interpreting for cleavage position = 0. Loaded scrambler model from saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_0.h5 1728/1728 [==============================] - 3s 2ms/step Interpreting for cleavage position = 1. Loaded scrambler model from saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_1.h5 1728/1728 [==============================] - 1s 800us/step Interpreting for cleavage position = 2. Loaded scrambler model from saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_2.h5 1728/1728 [==============================] - 1s 805us/step Interpreting for cleavage position = 3. Loaded scrambler model from saved_models/apa_inclusion_scrambler_smooth_target_bits_01_epochs_10_deeper_cut_pos_3.h5 1728/1728 [==============================] - 1s 795us/step
# Visualize a few reconstructed sequence patterns.
#
# Zero the importance scores at positions excluded by the template mask.
# The mask is broadcast over the leading (cut_pos, example) axes and the
# trailing channel axis.
# NOTE(review): assumes importance_scores_test is (cut_pos, example,
# position, channel); confirm against scrambler.interpret.
importance_scores_test *= sequence_mask[None, None, :, None]

plot_examples = [3, 5, 6]
save_examples = []
seq_len = 205
# 0-based start of the cleavage-site marker for each cut_pos
# (76 + 5, 76 + 15, 76 + 25, 76 + 35).
cuts = [76 + 5, 76 + 15, 76 + 25, 76 + 35]
cut_marker = 'CCC'

for test_ix in plot_examples :
    print("Test sequence " + str(test_ix) + ":")

    save_figs = save_examples is not None and test_ix in save_examples
    # model_name is whatever the preceding interpretation loop left bound.
    fig_prefix = model_name + "_test_ix_" + str(test_ix)

    plot_dna_logo(x_test[test_ix, :, :], sequence_template=sequence_template, figsize=(14, 0.65), plot_start=0, plot_end=seq_len, plot_sequence_template=True, save_figs=save_figs, fig_name=fig_prefix + "_orig_sequence")

    # Plot interpretation PWM for each cleavage position
    for cut_ix, cut_pos in enumerate([0, 1, 2, 3]) :
        # Mark the position where cleavage is being maximized. Replace exactly
        # len(cut_marker) characters so the marker template keeps the full
        # sequence length.
        cut_start = cuts[cut_ix]
        cut_template = 'N' * seq_len
        cut_template = cut_template[:cut_start] + cut_marker + cut_template[cut_start + len(cut_marker):]

        # Each figure gets a distinct name that includes cut_pos, so saved
        # marker and PWM figures do not overwrite each other.
        cut_fig_prefix = fig_prefix + "_cut_pos_" + str(cut_pos)
        plot_dna_logo(np.zeros((seq_len, 4)), sequence_template=cut_template, figsize=(14, 0.65), plot_start=0, plot_end=seq_len, plot_sequence_template=True, save_figs=save_figs, fig_name=cut_fig_prefix + "_cut_marker")

        # Plot Scrambler interpretation (PWM)
        plot_dna_logo(pwm_test[cut_ix, test_ix, :, :], sequence_template=sequence_template, figsize=(14, 0.65), plot_start=0, plot_end=seq_len, plot_sequence_template=True, save_figs=save_figs, fig_name=cut_fig_prefix + "_scrambld_pwm")
Test sequence 3:
Test sequence 5:
Test sequence 6: