In [1]:
# Import modules and set options
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pymc3 as pm
from itertools import product

sns.set_context("notebook")

import theano.tensor as tt
from theano import shared
from sklearn import metrics

DATA_DIR = '../data/clean/'
SEEDS = 20090425, 19700903

Import bBAL data

In [2]:
bbal_pathogens = pd.read_csv(DATA_DIR + 'bbal_pathogens.csv', index_col=0)
bbal_pathogens.head()
Out[2]:
record_id day pathogen cfu_count
1 1012 1.0 Staphylococcus aureus 6.0
4 1041 1.0 Escherichia coli 6.0
5 1043 1.0 Haemophilus influenza 6.0
6 1050 1.0 Haemophilus influenza 6.0
7 1060 1.0 Staphylococcus aureus 6.0

Aggregate bBAL data

In [3]:
bbal_agg = (bbal_pathogens.groupby(['record_id', 'day', 'pathogen'])[['cfu_count']].max())
bbal_agg.head(10)
Out[3]:
cfu_count
record_id day pathogen
1001 10.0 Other 2.0
1003 5.0 Klebsiella pneumoniae 3.0
Other 2.0
1007 3.0 Enterobacter cloacae 2.0
Haemophilus influenza 6.0
Other 2.0
Streptococcus pneumonia 6.0
1011 4.0 Staphylococcus aureus 6.0
1012 1.0 Staphylococcus aureus 6.0
1015 9.0 Other 2.0

Data by patient, day and pathogen

In [4]:
pcr_path_lookup = {1: 'Staphylococcus aureus',
2: 'Streptococcus pneumonia',
3: 'Streptococcus Group B',
4: 'Acetinobacter baumannii',
5: 'Pseudomonas aeruginosa',
6: 'Haemophilus influenza',
7: 'Klebsiella pneumoniae',
8: 'Escherichia coli',
9: 'Enterobacter cloacae',
10: 'Stenotrophomonas maltophilia',
11: 'Enterobacter aerogenes',
12: 'Serratia marcescens',
13: 'Klebsiella oxytoca',
14: 'Proteus mirabilis',
15: 'Enterococcus faecalis',
16: 'Enterococcus faecium',
17: 'Candida albicans',
18: 'Other'}
In [5]:
def fill_pathogens(x, labels, lookup, fill_with=1):
    return (x.dropna()
          .drop_duplicates(subset=['pathogen'])
          .set_index('pathogen')
          .reindex(list(lookup.values()))
          .reset_index()
          .assign(record_id=labels[0], day=labels[1])
          .fillna(fill_with))
In [6]:
bbal_groups = []
for labels, group in bbal_pathogens.groupby(['record_id', 'day']):
    recid, day = labels
    group_full = fill_pathogens(group, labels, pcr_path_lookup, fill_with=0)
    bbal_groups.append(group_full)
In [7]:
bbal_complete = pd.concat(bbal_groups).reset_index(drop=True)
In [8]:
bbal_agg = (bbal_complete.groupby(['record_id', 'day', 'pathogen'])[['cfu_count']].max())
bbal_agg.head(10)
Out[8]:
cfu_count
record_id day pathogen
1001 10.0 Acetinobacter baumannii 0.0
Candida albicans 0.0
Enterobacter aerogenes 0.0
Enterobacter cloacae 0.0
Enterococcus faecalis 0.0
Enterococcus faecium 0.0
Escherichia coli 0.0
Haemophilus influenza 0.0
Klebsiella oxytoca 0.0
Klebsiella pneumoniae 0.0

Import HME data

In [9]:
hme_pathogens = pd.read_csv(DATA_DIR + 'hme_pathogens.csv', index_col=0)
hme_pathogens.head()
Out[9]:
record_id day pathogen pcr_count
3 1009 1.0 Staphylococcus aureus 66200000.0
4 1012 1.0 Staphylococcus aureus 34100.0
17 1027 1.0 Acetinobacter baumannii 671000.0
21 1031 1.0 Serratia marcescens 15800000.0
32 1045 1.0 Streptococcus pneumonia 897000.0

Aggregate HME data

In [10]:
hme_groups = []
for labels, group in hme_pathogens.groupby(['record_id', 'day']):
    recid, day = labels
    group_full = fill_pathogens(group, labels, pcr_path_lookup)
    hme_groups.append(group_full)
In [11]:
hme_complete = pd.concat(hme_groups).reset_index(drop=True)
In [12]:
hme_complete.head()
Out[12]:
pathogen record_id day pcr_count
0 Staphylococcus aureus 1002 5.0 37900.0
1 Streptococcus pneumonia 1002 5.0 1.0
2 Streptococcus Group B 1002 5.0 1.0
3 Acetinobacter baumannii 1002 5.0 1.0
4 Pseudomonas aeruginosa 1002 5.0 1.0
In [13]:
hme_agg = (hme_complete.assign(log_count=np.log10(hme_complete.pcr_count).round(1))
             .drop('pcr_count', axis=1)
             .dropna()
             .groupby(['record_id', 'day', 'pathogen'])[['log_count']].max())
hme_agg.head()
Out[13]:
log_count
record_id day pathogen
1002 5.0 Acetinobacter baumannii 0.0
Candida albicans 0.0
Enterobacter aerogenes 7.4
Enterobacter cloacae 0.0
Enterococcus faecalis 0.0
In [14]:
bbal_hme = (bbal_agg.join(hme_agg).fillna(0)
                    .reset_index()
                    .rename(columns={'log_count':'HME count',
                                     'cfu_count':'bBAL count'}))
In [15]:
bbal_hme.head()
Out[15]:
record_id day pathogen bBAL count HME count
0 1001 10.0 Acetinobacter baumannii 0.0 0.0
1 1001 10.0 Candida albicans 0.0 0.0
2 1001 10.0 Enterobacter aerogenes 0.0 0.0
3 1001 10.0 Enterobacter cloacae 0.0 0.0
4 1001 10.0 Enterococcus faecalis 0.0 0.0
In [16]:
pathogen_list = bbal_hme.pathogen.unique()
n_pathogens = pathogen_list.shape[0]

pathogen_encode = dict(zip(pathogen_list, range(n_pathogens)))
pathogen_decode = dict(zip(range(n_pathogens), pathogen_list))

bbal_hme['pathogen_id'] = bbal_hme.pathogen.replace(pathogen_encode)
In [17]:
n_pathogens
Out[17]:
18

Predict negative bBAL from HME

In [18]:
bbal_hme['bbal_negative'] = bbal_hme['bBAL count'] < 3
In [19]:
bbal_hme.head()
Out[19]:
record_id day pathogen bBAL count HME count pathogen_id bbal_negative
0 1001 10.0 Acetinobacter baumannii 0.0 0.0 0 True
1 1001 10.0 Candida albicans 0.0 0.0 1 True
2 1001 10.0 Enterobacter aerogenes 0.0 0.0 2 True
3 1001 10.0 Enterobacter cloacae 0.0 0.0 3 True
4 1001 10.0 Enterococcus faecalis 0.0 0.0 4 True
In [20]:
bbal_hme.shape
Out[20]:
(666, 7)
In [21]:
x_path = bbal_hme.pathogen_id.values
x_hme = bbal_hme['HME count'].values
y = bbal_hme.bbal_negative.astype(int).values

with pm.Model() as neg_pred_model:
    
    μ = pm.Normal('μ', 0, sd=10)
    σ = pm.HalfCauchy('σ', 2.5)
    θ_tilde = pm.Normal('θ_tilde', mu=0, sd=1, shape=n_pathogens)
    θ = pm.Deterministic('θ', μ + σ * θ_tilde)
    
    β = pm.Normal('β', 0, sd=10)
    π = pm.Deterministic('π', pm.math.invlogit(θ[x_path] + β*x_hme))
    
    pm.Bernoulli('likeihood', π, observed=y)
In [22]:
with neg_pred_model:
    
    neg_trace = pm.sample(3000, tune=2000, 
                          step_args = {'target_accept': 0.99}, 
                          njobs=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
 96%|█████████▋| 4820/5000 [00:19<00:00, 271.47it/s]/Users/fonnescj/Repos/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 1 contains 3 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize.
  % (self._chain_id, n_diverging))
100%|██████████| 5000/5000 [00:20<00:00, 246.62it/s]
In [23]:
invlogit = lambda x: 1 / (1 + np.exp(-x))
In [24]:
pm.forestplot(neg_trace, varnames=['θ'], ylabels=pathogen_list, transform=invlogit,
                   xtitle='Probability of negative bBAL with no HME detected', vline=-1)
Out[24]:
<matplotlib.gridspec.GridSpec at 0x12ec57978>
In [25]:
pm.traceplot(neg_trace, varnames=['β', 'σ']);
In [43]:
def plot_pathogen(pathogen, trace, direction):

    fig, ax = plt.subplots()

    x_hme_pred = np.arange(0, 10)
    
    pathogen_ind = np.where(pathogen_list==pathogen)[0][0]

    p_pred = np.array([invlogit(trace['θ'][:, pathogen_ind] + trace['β']*x) for x in x_hme_pred])

    low, med, high = np.percentile(p_pred, [2.5, 50., 97.5], axis=1)

    ax.plot(x_hme_pred, med)
    ax.fill_between(x_hme_pred, low, high,
                    color='k', alpha=0.35, zorder=5,
                    label='95% posterior credible interval')
    
    y,x = bbal_hme.loc[bbal_hme.pathogen==pathogen, ['bbal_{}'.format(direction), 'HME count']].astype(int).values.T
    ax.plot(x, y + np.random.randn(len(y))*0.01, 'bo')
    
    ax.set_xlabel('Log-count of pathogen in HME');

    ax.set_ylabel('Probability of {} bBAL'.format(direction))
    ax.set_ylim(-0.1,1.1)

    ax.set_title(pathogen);
In [44]:
plot_pathogen('Haemophilus influenza', neg_trace, 'negative')
In [45]:
plot_pathogen('Streptococcus pneumonia', neg_trace, 'negative')

Here is a plot for a rarer species. Note the uncertainty:

In [46]:
plot_pathogen('Proteus mirabilis', neg_trace, 'negative')

Predict positive bBAL from positive HME

In [30]:
bbal_hme['bbal_positive'] = bbal_hme['bBAL count'] >= 3
In [31]:
x_path = bbal_hme.pathogen_id.values
x_hme = bbal_hme['HME count'].values
y = bbal_hme.bbal_positive.astype(int).values

with pm.Model() as pos_pred_model:
    
    μ = pm.Normal('μ', 0, sd=10)
    σ = pm.HalfCauchy('σ', 2.5)
    θ_tilde = pm.Normal('θ_tilde', mu=0, sd=1, shape=n_pathogens)
    θ = pm.Deterministic('θ', μ + σ * θ_tilde)
    
    
#     ϕ = pm.Normal('ϕ', 0, sd=10)
#     τ = pm.HalfCauchy('τ', 2.5)
#     β_tilde = pm.Normal('β_tilde', mu=0, sd=1, shape=n_pathogens)
#     β = pm.Deterministic('β', ϕ + τ * β_tilde)

    β = pm.Normal('β', 0, sd=10)
    
    π_exp = pm.Deterministic('π_exp', pm.math.invlogit(θ + β))
    
    π = pm.Deterministic('π', pm.math.invlogit(θ[x_path] + β*x_hme))
    
    pm.Bernoulli('likeihood', π, observed=y)
In [32]:
with pos_pred_model:
    pos_trace = pm.sample(3000, tune=2000, init='ADVI',
                          njobs=2,
                          random_seeds=SEEDS)
Auto-assigning NUTS sampler...
Initializing NUTS using advi...
Average Loss = 147.5:   8%|▊         | 16617/200000 [00:04<00:36, 4986.68it/s] 
Convergence archived at 16800
Interrupted at 16,800 [8%]: Average Loss = 212.81
100%|█████████▉| 4985/5000 [00:21<00:00, 294.94it/s]/Users/fonnescj/Repos/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 0 contains 1 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize.
  % (self._chain_id, n_diverging))
100%|██████████| 5000/5000 [00:21<00:00, 233.36it/s]
In [33]:
pm.forestplot(pos_trace, varnames=['θ'], ylabels=pathogen_list, transform=invlogit,
                   xtitle='Probability of positive bBAL with no HME', vline=0)
Out[33]:
<matplotlib.gridspec.GridSpec at 0x140beb978>
In [34]:
pm.traceplot(pos_trace, varnames=['β'])
Out[34]:
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x12ec74c18>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x12ea5e6a0>]], dtype=object)
In [48]:
plot_pathogen('Haemophilus influenza', pos_trace, 'positive')
In [49]:
plot_pathogen('Streptococcus pneumonia', pos_trace, 'positive')