# Import modules and set options
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pymc3 as pm
from itertools import product
sns.set_context("notebook")
import theano.tensor as tt
from theano import shared
from sklearn import metrics
# Relative path to the cleaned CSV inputs.
DATA_DIR = '../data/clean/'
# Random seeds, one per sampling chain (the samplers below run njobs=2).
SEEDS = (20090425, 19700903)
Import bBAL data
# Raw bBAL culture results; the first CSV column is used as the row index.
bbal_pathogens = pd.read_csv(filepath_or_buffer=DATA_DIR + 'bbal_pathogens.csv',
                             index_col=0)
bbal_pathogens.head(5)
record_id | day | pathogen | cfu_count | |
---|---|---|---|---|
1 | 1012 | 1.0 | Staphylococcus aureus | 6.0 |
4 | 1041 | 1.0 | Escherichia coli | 6.0 |
5 | 1043 | 1.0 | Haemophilus influenza | 6.0 |
6 | 1050 | 1.0 | Haemophilus influenza | 6.0 |
7 | 1060 | 1.0 | Staphylococcus aureus | 6.0 |
Aggregate bBAL data
# Keep the largest CFU count for every (record_id, day, pathogen) combination.
bbal_agg = bbal_pathogens.groupby(by=['record_id', 'day', 'pathogen'])[['cfu_count']].max()
bbal_agg.head(n=10)
cfu_count | |||
---|---|---|---|
record_id | day | pathogen | |
1001 | 10.0 | Other | 2.0 |
1003 | 5.0 | Klebsiella pneumoniae | 3.0 |
Other | 2.0 | ||
1007 | 3.0 | Enterobacter cloacae | 2.0 |
Haemophilus influenza | 6.0 | ||
Other | 2.0 | ||
Streptococcus pneumonia | 6.0 | ||
1011 | 4.0 | Staphylococcus aureus | 6.0 |
1012 | 1.0 | Staphylococcus aureus | 6.0 |
1015 | 9.0 | Other | 2.0 |
Data by patient, day and pathogen
# Full pathogen panel, keyed 1..18 in panel order. The names are matched
# verbatim against the `pathogen` column of the data (see fill_pathogens), so
# spellings such as 'Acetinobacter' and 'Haemophilus influenza' are kept
# exactly as they appear in the CSVs.
pcr_path_lookup = dict(enumerate([
    'Staphylococcus aureus',
    'Streptococcus pneumonia',
    'Streptococcus Group B',
    'Acetinobacter baumannii',
    'Pseudomonas aeruginosa',
    'Haemophilus influenza',
    'Klebsiella pneumoniae',
    'Escherichia coli',
    'Enterobacter cloacae',
    'Stenotrophomonas maltophilia',
    'Enterobacter aerogenes',
    'Serratia marcescens',
    'Klebsiella oxytoca',
    'Proteus mirabilis',
    'Enterococcus faecalis',
    'Enterococcus faecium',
    'Candida albicans',
    'Other',
], start=1))
def fill_pathogens(x, labels, lookup, fill_with=1):
    """Expand one (record_id, day) group to one row per known pathogen.

    Parameters
    ----------
    x : DataFrame
        Rows for a single (record_id, day) group. Must contain a ``pathogen``
        column; the remaining columns (e.g. a count column) are aggregated.
    labels : tuple
        ``(record_id, day)`` of the group; written into every output row.
    lookup : dict
        Mapping whose values are the complete list of pathogen names. The
        output has exactly one row per name, in ``lookup`` order.
    fill_with : scalar, default 1
        Value written for pathogens that do not appear in ``x``.

    Returns
    -------
    DataFrame
        Columns ``pathogen``, ``record_id``, ``day`` followed by the other
        columns of ``x``. Pathogens in ``x`` whose names are not among
        ``lookup.values()`` are dropped by the reindex.
    """
    return (x.dropna()
            # A pathogen can appear several times in one group; keep its
            # maximum value rather than whichever row happens to come first,
            # matching the max-based aggregation used elsewhere.
            .groupby('pathogen').max()
            .reindex(list(lookup.values()))
            .reset_index()
            .assign(record_id=labels[0], day=labels[1])
            .fillna(fill_with))
# One fully expanded frame per (record_id, day) bBAL sample: every pathogen in
# the panel gets a row, with 0 for pathogens not reported in that sample.
bbal_groups = [fill_pathogens(group, labels, pcr_path_lookup, fill_with=0)
               for labels, group in bbal_pathogens.groupby(['record_id', 'day'])]
bbal_complete = pd.concat(bbal_groups).reset_index(drop=True)
# Index the table by (record_id, day, pathogen).
bbal_agg = (bbal_complete.groupby(['record_id', 'day', 'pathogen'])[['cfu_count']].max())
bbal_agg.head(10)
cfu_count | |||
---|---|---|---|
record_id | day | pathogen | |
1001 | 10.0 | Acetinobacter baumannii | 0.0 |
Candida albicans | 0.0 | ||
Enterobacter aerogenes | 0.0 | ||
Enterobacter cloacae | 0.0 | ||
Enterococcus faecalis | 0.0 | ||
Enterococcus faecium | 0.0 | ||
Escherichia coli | 0.0 | ||
Haemophilus influenza | 0.0 | ||
Klebsiella oxytoca | 0.0 | ||
Klebsiella pneumoniae | 0.0 |
Import HME data
# Raw HME PCR results; the first CSV column is used as the row index.
hme_pathogens = pd.read_csv(filepath_or_buffer=DATA_DIR + 'hme_pathogens.csv',
                            index_col=0)
hme_pathogens.head(5)
record_id | day | pathogen | pcr_count | |
---|---|---|---|---|
3 | 1009 | 1.0 | Staphylococcus aureus | 66200000.0 |
4 | 1012 | 1.0 | Staphylococcus aureus | 34100.0 |
17 | 1027 | 1.0 | Acetinobacter baumannii | 671000.0 |
21 | 1031 | 1.0 | Serratia marcescens | 15800000.0 |
32 | 1045 | 1.0 | Streptococcus pneumonia | 897000.0 |
Aggregate HME data
# One fully expanded frame per (record_id, day) HME sample: every pathogen in
# the panel gets a row, with a PCR count of 1 (the fill_pathogens default) for
# pathogens not reported, i.e. log10 count 0.
hme_groups = [fill_pathogens(group, labels, pcr_path_lookup)
              for labels, group in hme_pathogens.groupby(['record_id', 'day'])]
hme_complete = pd.concat(hme_groups).reset_index(drop=True)
hme_complete.head()
pathogen | record_id | day | pcr_count | |
---|---|---|---|---|
0 | Staphylococcus aureus | 1002 | 5.0 | 37900.0 |
1 | Streptococcus pneumonia | 1002 | 5.0 | 1.0 |
2 | Streptococcus Group B | 1002 | 5.0 | 1.0 |
3 | Acetinobacter baumannii | 1002 | 5.0 | 1.0 |
4 | Pseudomonas aeruginosa | 1002 | 5.0 | 1.0 |
# Replace raw PCR counts with log10 counts rounded to one decimal place, then
# keep the largest value per (record_id, day, pathogen).
hme_agg = hme_complete.assign(log_count=np.log10(hme_complete['pcr_count']).round(1))
hme_agg = hme_agg.drop('pcr_count', axis=1).dropna()
hme_agg = hme_agg.groupby(['record_id', 'day', 'pathogen'])[['log_count']].max()
hme_agg.head()
log_count | |||
---|---|---|---|
record_id | day | pathogen | |
1002 | 5.0 | Acetinobacter baumannii | 0.0 |
Candida albicans | 0.0 | ||
Enterobacter aerogenes | 7.4 | ||
Enterobacter cloacae | 0.0 | ||
Enterococcus faecalis | 0.0 |
# Left-join HME log counts onto every bBAL (record_id, day, pathogen) row.
# NOTE(review): rows with no matching HME record get 0 from fillna, which is
# the same value an unobserved pathogen in a sampled HME gets (log10(1) = 0),
# so "no HME sample" and "HME negative" are indistinguishable. Confirm this is
# intended.
bbal_hme = bbal_agg.join(hme_agg)
bbal_hme = bbal_hme.fillna(0).reset_index()
bbal_hme = bbal_hme.rename(columns={'log_count': 'HME count',
                                    'cfu_count': 'bBAL count'})
bbal_hme.head()
record_id | day | pathogen | bBAL count | HME count | |
---|---|---|---|---|---|
0 | 1001 | 10.0 | Acetinobacter baumannii | 0.0 | 0.0 |
1 | 1001 | 10.0 | Candida albicans | 0.0 | 0.0 |
2 | 1001 | 10.0 | Enterobacter aerogenes | 0.0 | 0.0 |
3 | 1001 | 10.0 | Enterobacter cloacae | 0.0 | 0.0 |
4 | 1001 | 10.0 | Enterococcus faecalis | 0.0 | 0.0 |
# Integer code for each pathogen, in order of first appearance. The codes
# index the per-pathogen parameter vector θ in the models below.
pathogen_list = bbal_hme.pathogen.unique()
n_pathogens = pathogen_list.shape[0]
pathogen_encode = {name: code for code, name in enumerate(pathogen_list)}
pathogen_decode = dict(enumerate(pathogen_list))
# `map` (rather than `replace`) always yields an integer column here, because
# every pathogen is in the dictionary. That column is used for fancy-indexing θ.
bbal_hme['pathogen_id'] = bbal_hme.pathogen.map(pathogen_encode)
n_pathogens
18
# Outcome for the first model: a bBAL count below 3 counts as negative.
bbal_hme['bbal_negative'] = bbal_hme['bBAL count'].lt(3)
bbal_hme.head()
record_id | day | pathogen | bBAL count | HME count | pathogen_id | bbal_negative | |
---|---|---|---|---|---|---|---|
0 | 1001 | 10.0 | Acetinobacter baumannii | 0.0 | 0.0 | 0 | True |
1 | 1001 | 10.0 | Candida albicans | 0.0 | 0.0 | 1 | True |
2 | 1001 | 10.0 | Enterobacter aerogenes | 0.0 | 0.0 | 2 | True |
3 | 1001 | 10.0 | Enterobacter cloacae | 0.0 | 0.0 | 3 | True |
4 | 1001 | 10.0 | Enterococcus faecalis | 0.0 | 0.0 | 4 | True |
# (rows, columns) of the modelling table.
bbal_hme.shape
(666, 7)
# Inputs for the negative-bBAL model: pathogen codes, HME log counts and a 0/1
# outcome that is 1 when the bBAL count is below 3.
x_path = bbal_hme.pathogen_id.values
x_hme = bbal_hme['HME count'].values
y = bbal_hme.bbal_negative.astype(int).values
with pm.Model() as neg_pred_model:
    # Non-centred hierarchical intercepts: θ = μ + σ·θ_tilde, one per pathogen.
    μ = pm.Normal('μ', 0, sd=10)
    σ = pm.HalfCauchy('σ', 2.5)
    θ_tilde = pm.Normal('θ_tilde', mu=0, sd=1, shape=n_pathogens)
    θ = pm.Deterministic('θ', μ + σ * θ_tilde)
    # Single slope on HME log count, shared by all pathogens.
    β = pm.Normal('β', 0, sd=10)
    π = pm.Deterministic('π', pm.math.invlogit(θ[x_path] + β*x_hme))
    pm.Bernoulli('likeihood', π, observed=y)
with neg_pred_model:
    # `target_accept` is a NUTS option and is passed via `nuts_kwargs`;
    # `step_args` is not a pm.sample parameter. One seed per chain.
    neg_trace = pm.sample(3000, tune=2000,
                          nuts_kwargs={'target_accept': 0.99},
                          njobs=2,
                          random_seed=list(SEEDS))
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... 96%|█████████▋| 4820/5000 [00:19<00:00, 271.47it/s]/Users/fonnescj/Repos/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 1 contains 3 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) 100%|██████████| 5000/5000 [00:20<00:00, 246.62it/s]
def invlogit(x):
    """Return the inverse logit (logistic function) of ``x``, elementwise."""
    return 1 / (1 + np.exp(-x))
# Per-pathogen intercepts θ on the probability scale. vline=-1 puts the
# reference line outside the [0, 1] axis (presumably to hide it).
pm.forestplot(neg_trace, varnames=['θ'], ylabels=pathogen_list,
              transform=invlogit,
              xtitle='Probability of negative bBAL with no HME detected',
              vline=-1)
<matplotlib.gridspec.GridSpec at 0x12ec57978>
# Trace plots for the shared HME slope β and the between-pathogen scale σ;
# the trailing semicolon suppresses the returned axes array in the notebook.
pm.traceplot(neg_trace, varnames=['β', 'σ']);
def plot_pathogen(pathogen, trace, direction):
    """Plot the posterior bBAL-outcome probability for one pathogen vs. HME count.

    Draws the posterior median and 95% credible band of
    ``invlogit(θ[pathogen] + β * hme)`` over HME log counts 0-9. The observed
    0/1 outcomes for that pathogen are overlaid with slight vertical jitter.

    Parameters
    ----------
    pathogen : str
        Pathogen name; must be one of ``pathogen_list``.
    trace : result of ``pm.sample``
        Must provide ``'θ'`` (draws x pathogens) and ``'β'`` (one value per draw).
    direction : str
        ``'negative'`` or ``'positive'``. Selects the ``bbal_<direction>``
        outcome column of ``bbal_hme`` and labels the y-axis.

    Raises
    ------
    ValueError
        If ``pathogen`` is not in ``pathogen_list``.
    """
    matches = np.flatnonzero(pathogen_list == pathogen)
    if matches.size == 0:
        raise ValueError('Unknown pathogen: {!r}'.format(pathogen))
    pathogen_ind = matches[0]

    _, ax = plt.subplots()
    x_hme_pred = np.arange(0, 10)
    # Rows: HME values; columns: posterior draws.
    p_pred = invlogit(trace['θ'][:, pathogen_ind]
                      + np.outer(x_hme_pred, trace['β']))
    low, med, high = np.percentile(p_pred, [2.5, 50., 97.5], axis=1)
    ax.plot(x_hme_pred, med)
    ax.fill_between(x_hme_pred, low, high,
                    color='k', alpha=0.35, zorder=5,
                    label='95% posterior credible interval')

    rows = bbal_hme.loc[bbal_hme.pathogen == pathogen]
    y = rows['bbal_{}'.format(direction)].astype(int).values
    # Keep HME log counts as floats: casting to int truncated e.g. 7.4 to 7.
    x = rows['HME count'].values
    ax.plot(x, y + np.random.randn(len(y))*0.01, 'bo')

    ax.set_xlabel('Log-count of pathogen in HME')
    ax.set_ylabel('Probability of {} bBAL'.format(direction))
    ax.set_ylim(-0.1, 1.1)
    ax.set_title(pathogen)
    ax.legend(loc='best')
# Posterior curves for two pathogens under the negative-bBAL model.
plot_pathogen('Haemophilus influenza', neg_trace, 'negative')
plot_pathogen('Streptococcus pneumonia', neg_trace, 'negative')
Here is the same plot for a rarer species. Note the greater uncertainty:
# Same plot for Proteus mirabilis under the negative-bBAL model.
plot_pathogen('Proteus mirabilis', neg_trace, 'negative')
# Outcome for the second model: a bBAL count of 3 or more counts as positive.
bbal_hme['bbal_positive'] = bbal_hme['bBAL count'] >= 3
x_path = bbal_hme.pathogen_id.values
x_hme = bbal_hme['HME count'].values
y = bbal_hme.bbal_positive.astype(int).values
with pm.Model() as pos_pred_model:
    # Non-centred hierarchical intercepts: θ = μ + σ·θ_tilde, one per pathogen.
    μ = pm.Normal('μ', 0, sd=10)
    σ = pm.HalfCauchy('σ', 2.5)
    θ_tilde = pm.Normal('θ_tilde', mu=0, sd=1, shape=n_pathogens)
    θ = pm.Deterministic('θ', μ + σ * θ_tilde)
    # Disabled alternative: pathogen-specific slopes with a hierarchical prior
    # (the likelihood would then need β[x_path] instead of β).
    # ϕ = pm.Normal('ϕ', 0, sd=10)
    # τ = pm.HalfCauchy('τ', 2.5)
    # β_tilde = pm.Normal('β_tilde', mu=0, sd=1, shape=n_pathogens)
    # β = pm.Deterministic('β', ϕ + τ * β_tilde)
    β = pm.Normal('β', 0, sd=10)
    # Implied probability for each pathogen at an HME log count of 1 (θ + β·1).
    π_exp = pm.Deterministic('π_exp', pm.math.invlogit(θ + β))
    π = pm.Deterministic('π', pm.math.invlogit(θ[x_path] + β*x_hme))
    pm.Bernoulli('likeihood', π, observed=y)
with pos_pred_model:
    # `random_seed` (singular) is the pm.sample keyword; with njobs=2 it takes
    # one seed per chain.
    pos_trace = pm.sample(3000, tune=2000, init='ADVI',
                          njobs=2,
                          random_seed=list(SEEDS))
Auto-assigning NUTS sampler... Initializing NUTS using advi... Average Loss = 147.5: 8%|▊ | 16617/200000 [00:04<00:36, 4986.68it/s] Convergence archived at 16800 Interrupted at 16,800 [8%]: Average Loss = 212.81 100%|█████████▉| 4985/5000 [00:21<00:00, 294.94it/s]/Users/fonnescj/Repos/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 0 contains 1 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) 100%|██████████| 5000/5000 [00:21<00:00, 233.36it/s]
# Per-pathogen intercepts θ of the positive-bBAL model on the probability
# scale (invlogit), i.e. at an HME log count of 0.
pm.forestplot(pos_trace, varnames=['θ'], ylabels=pathogen_list, transform=invlogit,
              xtitle='Probability of positive bBAL with no HME', vline=0)
<matplotlib.gridspec.GridSpec at 0x140beb978>
# Trace plot for the shared HME slope β. The trailing semicolon suppresses the
# returned axes array, as for the negative-bBAL trace plot.
pm.traceplot(pos_trace, varnames=['β']);
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x12ec74c18>, <matplotlib.axes._subplots.AxesSubplot object at 0x12ea5e6a0>]], dtype=object)
# Posterior curves for the same two pathogens under the positive-bBAL model.
plot_pathogen('Haemophilus influenza', pos_trace, 'positive')
plot_pathogen('Streptococcus pneumonia', pos_trace, 'positive')