import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute
from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from keras import regularizers
from keras import backend as K
import keras.losses
import tensorflow as tf
from tensorflow.python.framework import ops
import isolearn.keras as iso
import numpy as np
import logging
logging.getLogger('tensorflow').setLevel(logging.ERROR)
import pandas as pd
import os
import pickle
import scipy.sparse as sp
import scipy.io as spio
import matplotlib.pyplot as plt
import isolearn.io as isoio
import isolearn.keras as isol
from genesis.visualization import *
from genesis.generator import *
from genesis.predictor import *
from genesis.optimizer import *
import sklearn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.stats import pearsonr
import seaborn as sns
from matplotlib import colors
import editdistance
#Select a subset of a list by a list of indices
def subselect_list(li, ixs) :
    return [li[ix] for ix in ixs]
class IdentityEncoder(iso.SequenceEncoder) :
def __init__(self, seq_len, channel_map) :
super(IdentityEncoder, self).__init__('identity', (seq_len, len(channel_map)))
self.seq_len = seq_len
self.n_channels = len(channel_map)
self.encode_map = channel_map
        self.decode_map = {
            ix : nt for nt, ix in self.encode_map.items()
        }
def encode(self, seq) :
encoding = np.zeros((self.seq_len, self.n_channels))
for i in range(len(seq)) :
if seq[i] in self.encode_map :
channel_ix = self.encode_map[seq[i]]
encoding[i, channel_ix] = 1.
return encoding
def encode_inplace(self, seq, encoding) :
for i in range(len(seq)) :
if seq[i] in self.encode_map :
channel_ix = self.encode_map[seq[i]]
encoding[i, channel_ix] = 1.
def encode_inplace_sparse(self, seq, encoding_mat, row_index) :
        raise NotImplementedError()
def decode(self, encoding) :
seq = ''
for pos in range(0, encoding.shape[0]) :
argmax_nt = np.argmax(encoding[pos, :])
seq += self.decode_map[argmax_nt]
return seq
def decode_sparse(self, encoding_mat, row_index) :
        raise NotImplementedError()
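#Illustrative sketch (not called anywhere below): round-trip a short sequence through the
#IdentityEncoder with a hypothetical ACGT channel map and toy 4-nt sequence.
def _example_identity_encoder() :
    acgt_map = {'A' : 0, 'C' : 1, 'G' : 2, 'T' : 3}
    encoder = IdentityEncoder(4, acgt_map)
    onehot = encoder.encode('ACGT') #Shape (4, 4) one-hot matrix
    return encoder.decode(onehot)   #Returns 'ACGT'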
#Plot joint histograms
def plot_joint_histo(measurements, labels, x_label, y_label, colors=None, n_bins=50, figsize=(6, 4), legend_outside=False, save_fig=False, fig_name="default_1", fig_dpi=150, min_val=None, max_val=None, max_y_val=None) :
min_hist_val = np.min(measurements[0])
max_hist_val = np.max(measurements[0])
for i in range(1, len(measurements)) :
min_hist_val = min(min_hist_val, np.min(measurements[i]))
max_hist_val = max(max_hist_val, np.max(measurements[i]))
if min_val is not None :
min_hist_val = min_val
if max_val is not None :
max_hist_val = max_val
hists = []
bin_edges = []
means = []
for i in range(len(measurements)) :
hist, b_edges = np.histogram(measurements[i], range=(min_hist_val, max_hist_val), bins=n_bins, density=True)
hists.append(hist)
bin_edges.append(b_edges)
means.append(np.mean(measurements[i]))
bin_width = bin_edges[0][1] - bin_edges[0][0]
f = plt.figure(figsize=figsize)
for i in range(len(measurements)) :
if colors is not None :
plt.bar(bin_edges[i][1:] - bin_width/2., hists[i], width=bin_width, linewidth=2, edgecolor='black', color=colors[i], label=labels[i])
else :
plt.bar(bin_edges[i][1:] - bin_width/2., hists[i], width=bin_width, linewidth=2, edgecolor='black', label=labels[i])
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlim(min_hist_val, max_hist_val)
if max_y_val is not None :
plt.ylim(0, max_y_val)
plt.xlabel(x_label, fontsize=14)
plt.ylabel(y_label, fontsize=14)
if colors is not None :
for i in range(len(measurements)) :
plt.axvline(x=means[i], linewidth=2, color=colors[i], linestyle="--")
if not legend_outside :
plt.legend(fontsize=14, loc='upper left')
else :
plt.legend(fontsize=14, bbox_to_anchor=(1.04,1), loc="upper left")
plt.tight_layout()
if save_fig :
plt.savefig(fig_name + ".eps")
plt.savefig(fig_name + ".svg")
plt.savefig(fig_name + ".png", dpi=fig_dpi, transparent=True)
plt.show()
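#Illustrative sketch of plot_joint_histo on synthetic data (the random draws, group labels
#and colors below are made up for the example).
def _example_plot_joint_histo() :
    scores_a = np.random.normal(0., 1., size=1000)
    scores_b = np.random.normal(1., 1., size=1000)
    plot_joint_histo(
        [scores_a, scores_b],
        labels=['model A', 'model B'],
        x_label='Predicted score',
        y_label='Density',
        colors=['deepskyblue', 'orange'],
        n_bins=40
    )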
#Plot joint comparison plots (violin / strip / bar)
def plot_joint_cmp(measurements, labels, y_label, plot_type='violin', colors=None, figsize=(6, 4), legend_outside=False, save_fig=False, fig_name="default_1", fig_dpi=150, min_y_val=None, max_y_val=None, violin_bw=None, violin_cut=None) :
f = plt.figure(figsize=figsize)
sns_g = None
if colors is not None :
if plot_type == 'violin' :
if violin_bw is None :
if violin_cut is None :
sns_g = sns.violinplot(data=measurements, palette=colors, scale='width')
else :
sns_g = sns.violinplot(data=measurements, palette=colors, scale='width', cut=violin_cut)
else :
if violin_cut is None :
sns_g = sns.violinplot(data=measurements, palette=colors, scale='width', bw=violin_bw)
else :
sns_g = sns.violinplot(data=measurements, palette=colors, scale='width', bw=violin_bw, cut=violin_cut)
elif plot_type == 'strip' :
sns_g = sns.stripplot(data=measurements, palette=colors, alpha=0.1, jitter=0.3, linewidth=2, edgecolor='black') #, x=labels
for i in range(len(measurements)) :
                    plt.plot([i, i + 1], [np.median(measurements[i]), np.median(measurements[i])], linewidth=2, color=colors[i], linestyle="--")
elif plot_type == 'bar' :
for i in range(len(measurements)) :
plt.bar([i], [np.percentile(measurements[i], 100)], width=0.4, color=colors[i], label=str(i) + ") " + labels[i], linewidth=2, edgecolor='black')
plt.bar([i+0.2], [np.percentile(measurements[i], 95)], width=0.4, color=colors[i], linewidth=2, edgecolor='black')
plt.bar([i+0.4], [np.percentile(measurements[i], 80)], width=0.4, color=colors[i], linewidth=2, edgecolor='black')
plt.bar([i+0.6], [np.percentile(measurements[i], 50)], width=0.4, color=colors[i], linewidth=2, edgecolor='black')
else :
if plot_type == 'violin' :
if violin_bw is None :
if violin_cut is None :
sns_g = sns.violinplot(data=measurements, scale='width')
else :
sns_g = sns.violinplot(data=measurements, scale='width', cut=violin_cut)
else :
if violin_cut is None :
sns_g = sns.violinplot(data=measurements, scale='width', bw=violin_bw)
else :
sns_g = sns.violinplot(data=measurements, scale='width', bw=violin_bw, cut=violin_cut)
elif plot_type == 'strip' :
sns_g = sns.stripplot(data=measurements, alpha=0.1, jitter=0.3, linewidth=2, edgecolor='black') #, x=labels
elif plot_type == 'bar' :
for i in range(len(measurements)) :
plt.bar([i], [np.percentile(measurements[i], 100)], width=0.25, label=str(i) + ") " + labels[i], linewidth=2, edgecolor='black')
plt.bar([i+0.125], [np.percentile(measurements[i], 95)], width=0.25, linewidth=2, edgecolor='black')
plt.bar([i+0.25], [np.percentile(measurements[i], 80)], width=0.25, linewidth=2, edgecolor='black')
plt.bar([i+0.375], [np.percentile(measurements[i], 50)], width=0.25, linewidth=2, edgecolor='black')
plt.xticks(np.arange(len(labels)), fontsize=14)
plt.yticks(fontsize=14)
#plt.xlim(min_hist_val, max_hist_val)
if min_y_val is not None and max_y_val is not None :
plt.ylim(min_y_val, max_y_val)
plt.ylabel(y_label, fontsize=14)
if plot_type not in ['violin', 'strip'] :
if not legend_outside :
plt.legend(fontsize=14, loc='upper left')
else :
plt.legend(fontsize=14, bbox_to_anchor=(1.04,1), loc="upper left")
else :
if not legend_outside :
f.get_axes()[0].legend(fontsize=14, loc="upper left", labels=[str(label_i) + ") " + label for label_i, label in enumerate(labels)])
else :
f.get_axes()[0].legend(fontsize=14, bbox_to_anchor=(1.04,1), loc="upper left", labels=[str(label_i) + ") " + label for label_i, label in enumerate(labels)])
plt.tight_layout()
if save_fig :
plt.savefig(fig_name + ".eps")
plt.savefig(fig_name + ".svg")
plt.savefig(fig_name + ".png", dpi=fig_dpi, transparent=True)
plt.show()
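#Illustrative sketch of plot_joint_cmp as a violin plot; the synthetic groups, labels and
#colors are made up for the example.
def _example_plot_joint_cmp() :
    scores_a = np.random.normal(0., 1., size=1000)
    scores_b = np.random.normal(1., 1., size=1000)
    plot_joint_cmp(
        [scores_a, scores_b],
        labels=['model A', 'model B'],
        y_label='Predicted score',
        plot_type='violin',
        colors=['deepskyblue', 'orange']
    )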
#Load generated data from models to be evaluated
def load_sequences(file_path, split_on_tab=True, seq_template=None, max_n_sequences=1000000, select_best_fitness=False, predictor=None, batch_size=32) :
seqs = []
with open(file_path, "rt") as f :
for l in f.readlines() :
l_strip = l.strip()
seq = l_strip
if split_on_tab :
seq = l_strip.split("\t")[0]
if seq_template is not None :
seq = ''.join([
seq_template[j] if seq_template[j] != 'N' else seq[j]
for j in range(len(seq))
])
seqs.append(seq)
if select_best_fitness and predictor is not None :
onehots = np.expand_dims(np.concatenate([
np.expand_dims(acgt_encoder.encode(seq), axis=0) for seq in seqs
], axis=0), axis=1)
#Predict fitness
score_pred = predictor.predict(x=[onehots], batch_size=batch_size)
score_pred = np.ravel(score_pred[:, 0])
sort_index = np.argsort(score_pred)[::-1]
seqs = [
seqs[sort_index[i]] for i in range(len(seqs))
]
return seqs[:max_n_sequences]
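#Illustrative sketch of load_sequences; the file path and sequence template below are made up.
#Ranking by fitness (select_best_fitness=True) additionally assumes the global acgt_encoder
#and a Keras predictor model.
def _example_load_sequences() :
    return load_sequences(
        'generated_seqs_example.txt', #Hypothetical tab-separated file of generated sequences
        split_on_tab=True,
        seq_template='N' * 90,        #Fully unconstrained 90-nt template
        max_n_sequences=1000
    )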
#Metric helper functions
def compute_edit_distance(onehots, opt_len=100) :
    shuffle_index = np.arange(onehots.shape[0])
    shuffle_index = shuffle_index[::-1] #Deterministic reversal used as the pairing; swap in np.random.shuffle(shuffle_index) for a random pairing
seqs = [acgt_encoder.decode(onehots[i, :, :, 0]) for i in range(onehots.shape[0])]
seqs_shuffled = [seqs[shuffle_index[i]] for i in range(onehots.shape[0])]
edit_distances = np.ravel([float(editdistance.eval(seq_1, seq_2)) for seq_1, seq_2 in zip(seqs, seqs_shuffled)])
edit_distances /= opt_len
mean_edit_distance = np.mean(edit_distances)
return edit_distances, mean_edit_distance
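#Illustrative sketch of compute_edit_distance on a toy batch of one-hot sequences; it assumes
#the global acgt_encoder used throughout this script, and the random one-hots are made up.
def _example_compute_edit_distance() :
    n_seqs, seq_len = 8, 90
    random_ixs = np.random.randint(0, 4, size=(n_seqs, seq_len))
    onehots = np.eye(4)[random_ixs][..., None] #Shape (n_seqs, seq_len, 4, 1)
    return compute_edit_distance(onehots, opt_len=seq_len)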
#Evaluate metrics for each model
def compute_metrics(seqs, n_seqs_to_test=960, batch_size=64, opt_len=90) :
n_seqs_to_test = min(len(seqs), n_seqs_to_test)
onehots = np.expand_dims(np.concatenate([
np.expand_dims(acgt_encoder.encode(seq), axis=0) for seq in seqs
], axis=0), axis=-1)
#Predict fitness
score_pred = saved_predictor.predict(x=[np.moveaxis(onehots[:n_seqs_to_test], 3, 1)], batch_size=batch_size)
score_pred = np.ravel(score_pred[:, 0])
#Compare pair-wise edit distances
edit_dists, _ = compute_edit_distance(onehots[:n_seqs_to_test], opt_len=opt_len)
return score_pred, edit_dists
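#Illustrative sketch of compute_metrics on sequences loaded from a hypothetical file; it
#relies on the globals acgt_encoder and saved_predictor that are expected to be defined elsewhere.
def _example_compute_metrics() :
    seqs = load_sequences('generated_seqs_example.txt', seq_template='N' * 90)
    score_pred, edit_dists = compute_metrics(seqs, n_seqs_to_test=256, opt_len=90)
    print("Mean predicted score = " + str(round(np.mean(score_pred), 4)))
    print("Mean normalized edit distance = " + str(round(np.mean(edit_dists), 4)))
    return score_pred, edit_dists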
#Keep TensorFlow from pre-allocating all GPU memory (TF1-style session configuration)
from keras.backend.tensorflow_backend import set_session
def contain_tf_gpu_mem_usage() :
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
set_session(sess)
contain_tf_gpu_mem_usage()