In this tutorial we illustrate input feature importance attribution and variant effect prediction.
import os
import numpy as np
import h5py
from keras import backend as K
from keras.layers import Conv2D
from keras.layers import GlobalAveragePooling2D
from pkg_resources import resource_filename
from janggu import Janggu
from janggu import Scorer
from janggu import inputlayer
from janggu import outputdense
from janggu.data import Bioseq
from janggu.data import Cover
from janggu.data import GenomicIndexer
from janggu.data import ReduceDim
from janggu.data import plotGenomeTrack
from janggu.data import LineTrack
from janggu.data import SeqTrack
from janggu.layers import DnaConv2D
from janggu import input_attribution
np.random.seed(1234)
/home/wkopp/anaconda3/envs/jdev/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. from ._conv import register_converters as _register_converters Using TensorFlow backend.
First, we need to specify the output directory in which the results are stored and load the datasets. We also specify the number of epochs to train the model and the sequence feature order.
# Sequence feature order used when encoding the DNA (passed to Bioseq below);
# order=3 encodes overlapping 3-mers instead of single nucleotides.
order = 3
# Number of training epochs for model fitting.
epochs = 100
# Janggu stores models, evaluation results and figures below JANGGU_OUTPUT.
os.environ['JANGGU_OUTPUT'] = '/home/wkopp/janggu_examples'
# Load the datasets.
# The pseudo genome represents just a concatenation of all sequences
# in sample.fa and sample2.fa. Therefore, the results should be almost
# identical to the models obtained from classify_fasta.py.
REFGENOME = resource_filename('janggu', 'resources/pseudo_genome.fa')
VCFFILE = resource_filename('janggu', 'resources/pseudo_snps.vcf')

# The ROI files contain regions spanning positive and negative examples,
# while PEAK_FILE only contains positive examples.
ROI_TRAIN_FILE = resource_filename('janggu', 'resources/roi_train.bed')
ROI_TEST_FILE = resource_filename('janggu', 'resources/roi_test.bed')
PEAK_FILE = resource_filename('janggu', 'resources/scores.bed')


def _load_dna(roi_file):
    """One-hot encoded DNA sequence for the given regions of interest."""
    return Bioseq.create_from_refgenome('dna', refgenome=REFGENOME,
                                        roi=roi_file,
                                        binsize=200,
                                        store_whole_genome=True,
                                        order=order)


def _load_labels(roi_file):
    """Binary peak labels for the given regions of interest."""
    return Cover.create_from_bed('peaks', roi=roi_file,
                                 bedfiles=PEAK_FILE,
                                 binsize=200,
                                 resolution=200)


# Training input and labels are purely defined by genomic coordinates.
DNA = _load_dna(ROI_TRAIN_FILE)
LABELS = _load_labels(ROI_TRAIN_FILE)
DNA_TEST = _load_dna(ROI_TEST_FILE)
LABELS_TEST = _load_labels(ROI_TEST_FILE)
Define and fit a new model
@inputlayer
@outputdense('sigmoid')
def double_stranded_model_dnaconv(inputs, inp, oup, params):
    """Keras model that scans both DNA strands for motif occurrences.

    The DnaConv2D layer wrapper offers an elegant way to scan both
    strands: internally it performs the convolution with the normal
    kernel weights and with the reverse-complemented weights.
    """
    # The name passed to inputs.use() must match the dataset name.
    with inputs.use('dna') as layer:
        conv = Conv2D(params[0], (params[1], 1), activation=params[2])
        layer = DnaConv2D(conv)(layer)
    output = GlobalAveragePooling2D(name='motif')(layer)
    return inputs, output
# Build a new model object from the network template above and fit it.
model = Janggu.create(template=double_stranded_model_dnaconv,
                      modelparams=(30, 21, 'relu'),
                      inputs=DNA,
                      outputs=ReduceDim(LABELS))
model.compile(optimizer='adadelta', loss='binary_crossentropy',
              metrics=['acc'])
hist = model.fit(DNA, ReduceDim(LABELS), epochs=epochs)

# Report the final training loss and accuracy.
separator = '#' * 40
print(separator)
print('loss: {}, acc: {}'.format(hist.history['loss'][-1],
                                 hist.history['acc'][-1]))
print(separator)
Generated model-id: '6acf30da11d48d4d96ed0669bf3a52f4' Epoch 1/100 244/244 [==============================] - 4s 14ms/step - loss: 0.6253 - acc: 0.6480 Epoch 2/100 244/244 [==============================] - 2s 10ms/step - loss: 0.5209 - acc: 0.7661 Epoch 3/100 244/244 [==============================] - 2s 10ms/step - loss: 0.4645 - acc: 0.7958 Epoch 4/100 244/244 [==============================] - 2s 10ms/step - loss: 0.4248 - acc: 0.8201 Epoch 5/100 244/244 [==============================] - 2s 10ms/step - loss: 0.3963 - acc: 0.8330 Epoch 6/100 244/244 [==============================] - 2s 10ms/step - loss: 0.3720 - acc: 0.8441 Epoch 7/100 244/244 [==============================] - 2s 10ms/step - loss: 0.3488 - acc: 0.8535 Epoch 8/100 244/244 [==============================] - 2s 9ms/step - loss: 0.3287 - acc: 0.8642 Epoch 9/100 244/244 [==============================] - 2s 10ms/step - loss: 0.3064 - acc: 0.8775 Epoch 10/100 244/244 [==============================] - 2s 10ms/step - loss: 0.2868 - acc: 0.8881 Epoch 11/100 244/244 [==============================] - 2s 10ms/step - loss: 0.2666 - acc: 0.8989 Epoch 12/100 244/244 [==============================] - 2s 10ms/step - loss: 0.2482 - acc: 0.9084 Epoch 13/100 244/244 [==============================] - 2s 10ms/step - loss: 0.2305 - acc: 0.9148 Epoch 14/100 244/244 [==============================] - 2s 10ms/step - loss: 0.2159 - acc: 0.9233 Epoch 15/100 244/244 [==============================] - 2s 10ms/step - loss: 0.2011 - acc: 0.9297 Epoch 16/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1889 - acc: 0.9346 Epoch 17/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1777 - acc: 0.9396 Epoch 18/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1668 - acc: 0.9434 Epoch 19/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1571 - acc: 0.9465 Epoch 20/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1495 - acc: 
0.9500 Epoch 21/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1412 - acc: 0.9549 Epoch 22/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1343 - acc: 0.9576 Epoch 23/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1277 - acc: 0.9589 Epoch 24/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1208 - acc: 0.9616 Epoch 25/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1161 - acc: 0.9632 Epoch 26/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1106 - acc: 0.9668 Epoch 27/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1066 - acc: 0.9661 Epoch 28/100 244/244 [==============================] - 2s 10ms/step - loss: 0.1021 - acc: 0.9689 Epoch 29/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0989 - acc: 0.9694 Epoch 30/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0948 - acc: 0.9721 Epoch 31/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0911 - acc: 0.9709 Epoch 32/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0873 - acc: 0.9752 Epoch 33/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0841 - acc: 0.9750 Epoch 34/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0811 - acc: 0.9766 Epoch 35/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0783 - acc: 0.9781 Epoch 36/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0759 - acc: 0.9789 Epoch 37/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0732 - acc: 0.9804 Epoch 38/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0704 - acc: 0.9809 Epoch 39/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0682 - acc: 0.9834 Epoch 40/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0655 - acc: 0.9839 Epoch 41/100 244/244 
[==============================] - 2s 10ms/step - loss: 0.0637 - acc: 0.9842 Epoch 42/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0622 - acc: 0.9841 Epoch 43/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0600 - acc: 0.9851 Epoch 44/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0579 - acc: 0.9869 Epoch 45/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0567 - acc: 0.9858 Epoch 46/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0547 - acc: 0.9862 Epoch 47/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0535 - acc: 0.9866 Epoch 48/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0520 - acc: 0.9874 Epoch 49/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0502 - acc: 0.9889 Epoch 50/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0490 - acc: 0.9883 Epoch 51/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0472 - acc: 0.9889 Epoch 52/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0462 - acc: 0.9892 Epoch 53/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0449 - acc: 0.9895 Epoch 54/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0436 - acc: 0.9914 Epoch 55/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0423 - acc: 0.9900 Epoch 56/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0407 - acc: 0.9908 Epoch 57/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0400 - acc: 0.9913 Epoch 58/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0388 - acc: 0.9912 Epoch 59/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0372 - acc: 0.9922 Epoch 60/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0368 - acc: 0.9926 Epoch 61/100 244/244 [==============================] - 2s 
10ms/step - loss: 0.0355 - acc: 0.9930 Epoch 62/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0347 - acc: 0.9937 Epoch 63/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0341 - acc: 0.9932 Epoch 64/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0329 - acc: 0.9940 Epoch 65/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0319 - acc: 0.9939 Epoch 66/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0311 - acc: 0.9950 Epoch 67/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0310 - acc: 0.9946 Epoch 68/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0293 - acc: 0.9951 Epoch 69/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0289 - acc: 0.9946 Epoch 70/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0281 - acc: 0.9954 Epoch 71/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0272 - acc: 0.9956 Epoch 72/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0267 - acc: 0.9959 Epoch 73/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0259 - acc: 0.9963 Epoch 74/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0250 - acc: 0.9965 Epoch 75/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0244 - acc: 0.9965 Epoch 76/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0245 - acc: 0.9967 Epoch 77/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0235 - acc: 0.9967 Epoch 78/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0226 - acc: 0.9974 Epoch 79/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0221 - acc: 0.9976 Epoch 80/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0214 - acc: 0.9980 Epoch 81/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0212 - acc: 0.9978 
Epoch 82/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0203 - acc: 0.9974 Epoch 83/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0201 - acc: 0.9976 Epoch 84/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0194 - acc: 0.9981 Epoch 85/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0190 - acc: 0.9982 Epoch 86/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0184 - acc: 0.9986 Epoch 87/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0180 - acc: 0.9987 Epoch 88/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0177 - acc: 0.9983 Epoch 89/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0171 - acc: 0.9990 Epoch 90/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0165 - acc: 0.9990 Epoch 91/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0161 - acc: 0.9991 Epoch 92/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0161 - acc: 0.9987 Epoch 93/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0156 - acc: 0.9991 Epoch 94/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0151 - acc: 0.9995 Epoch 95/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0145 - acc: 0.9991 Epoch 96/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0140 - acc: 0.9991 Epoch 97/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0140 - acc: 0.9996 Epoch 98/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0136 - acc: 0.9995 Epoch 99/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0132 - acc: 0.9992 Epoch 100/100 244/244 [==============================] - 2s 10ms/step - loss: 0.0129 - acc: 0.9995 ######################################## loss: 0.012950192026199592, acc: 0.9994869821726305 ########################################
The toy example illustrates a binary classification task composed of Oct4 and Mafk binding sites. The true Oct4 binding sites have been labeled with ones, while the Mafk binding sites are labeled with zeros.
For a sanity check we inspect the predicted values for a few data points.
# Sanity check: inspect predicted probabilities for a few test regions.
pred = model.predict(DNA_TEST)
cov_pred = Cover.create_from_array('BindingProba', pred, LABELS_TEST.gindexer)

print('Oct4 predictions scores should be greater than Mafk scores:')
# The first regions in the test set are Oct4-bound ...
print('Prediction score examples for Oct4')
for idx in (0, 1, 2, 3):
    print('{}.: {}'.format(idx, cov_pred[idx]))
# ... and the last regions are Mafk-bound (indexed from the end).
print('Prediction score examples for Mafk')
for offset in (1, 2, 3, 4):
    print('{}.: {}'.format(offset, cov_pred[-offset]))
Oct4 predictions scores should be greater than Mafk scores: Prediction score examples for Oct4 0.: [[[[0.99882954]]]] 1.: [[[[0.9025411]]]] 2.: [[[[0.99821955]]]] 3.: [[[[0.98289168]]]] Prediction score examples for Mafk 1.: [[[[0.59950411]]]] 2.: [[[[0.00201734]]]] 3.: [[[[0.00082711]]]] 4.: [[[[2.10625944e-06]]]]
In order to perform input feature attribution, we utilize the 'input_attribution' method for a genomic region of interest.
Underneath, the result is illustrated visually. It highlights an Oct4 binding site occurring at the left peak.
# Extract the 4th interval, which represents an Oct4 bound region,
# and perform input feature importance attribution on it.
gi = DNA.gindexer[3]
chrom, start, end = gi.chrom, gi.start, gi.end
attr_oct = input_attribution(model, DNA, chrom=chrom, start=start, end=end)
# Visualize the important sequence features.
plotGenomeTrack(SeqTrack(attr_oct[0]), chrom, start, end)
By comparison, the input attribution for a Mafk binding site highlights a Mafk motif in the center.
# For comparison, extract an interval representing a Mafk bound
# region and visualize its important features.
gi = DNA.gindexer[7796]
chrom, start, end = gi.chrom, gi.start, gi.end
attr_mafk = input_attribution(model, DNA, chrom=chrom, start=start, end=end)
plotGenomeTrack(SeqTrack(attr_mafk[0]), chrom, start, end)
In order to perform variant effect prediction, we need the DNA sequence of the whole genome loaded into a Bioseq object and a VCF file containing single nucleotide variants.
The result of this analysis is stored in two files: scores.hdf5 and snps.bed.gz.
# Output directory for the variant effect prediction results.
vcfoutput = os.path.join(os.environ['JANGGU_OUTPUT'], 'vcfoutput')
os.makedirs(vcfoutput, exist_ok=True)
# Perform variant effect prediction using the Bioseq object and
# a VCF file; results are written into vcfoutput.
scoresfile, variantsfile = model.predict_variant_effect(DNA,
                                                        VCFFILE,
                                                        conditions=['feature'],
                                                        output_folder=vcfoutput)
# NOTE(review): the two assignments below overwrite the values returned by
# predict_variant_effect with hard-coded paths — presumably they point to the
# same files, making one of the two redundant. Confirm and drop one.
scoresfile = os.path.join(vcfoutput, 'scores.hdf5')
variantsfile = os.path.join(vcfoutput, 'snps.bed.gz')
scores.hdf5 contains a variety of scores for each variant. The most important ones are refscore and altscore which are used to derive the score difference and the logoddsscore.
# Open the variant effect prediction scores; the score differences
# (reference vs. alternative allele) are parsed into a Cover object
# below for visualization. NOTE: f stays open because it is read again
# further down.
f = h5py.File(scoresfile, 'r')
# List the available score datasets.
for key in f.keys():
    print(key)
altscore diffscore labels logoddsscore refscore
Finally, we can visualize the variant effect predictions (the score differences in this case) in their genomic context alongside other genomic tracks.
gindexer = GenomicIndexer.create_from_file(variantsfile, None, None)

# Wrap the per-variant score differences in a Cover object anchored at
# the variant positions. store_whole_genome=False would work here too.
snpcov = Cover.create_from_array('snps', f['diffscore'],
                                 gindexer,
                                 store_whole_genome=True,
                                 padding_value=np.nan)

# Plot the score differences alongside the Oct4 input attribution.
gi = DNA.gindexer[3]
chrom, start, end = gi.chrom, gi.start, gi.end
plotGenomeTrack([LineTrack(snpcov, linestyle="None"),
                 SeqTrack(attr_oct[0])],
                chrom, start, end)
The score difference shows a dip around the site that is indicated as most important from the input attribution as well.
It is also possible to export the variant effect predictions as bigwig files for further exploration in e.g. IGV. To this end, use the export_to_bigwig method.
# Export the per-variant score differences as bigwig files into ./snps
# for further exploration in genome browsers such as IGV.
os.makedirs('./snps', exist_ok=True)
snpcov.export_to_bigwig('./snps')