#!/usr/bin/env python
# coding: utf-8

# In[1]:

# Import modules and set options
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pymc3 as pm
from itertools import product
sns.set_context("notebook")
import theano.tensor as tt
from theano import shared
from sklearn import metrics

DATA_DIR = '../data/clean/'
# One seed per sampling chain (the models below are sampled with njobs=2).
SEEDS = 20090425, 19700903


# Import bBAL data

# In[2]:

bbal_pathogens = pd.read_csv(DATA_DIR + 'bbal_pathogens.csv', index_col=0)
bbal_pathogens.head()


# Aggregate bBAL data

# In[3]:

# Largest culture count observed per patient, day and pathogen.
bbal_agg = (bbal_pathogens.groupby(['record_id', 'day', 'pathogen'])[['cfu_count']].max())
bbal_agg.head(10)


# Data by patient, day and pathogen

# In[4]:

# NOTE(review): keys look like PCR panel codes. The values must match the
# pathogen labels in the CSVs exactly (spelling included), because
# fill_pathogens reindexes on them and silently drops any label not listed here.
pcr_path_lookup = {1: 'Staphylococcus aureus',
                   2: 'Streptococcus pneumonia',
                   3: 'Streptococcus Group B',
                   4: 'Acetinobacter baumannii',
                   5: 'Pseudomonas aeruginosa',
                   6: 'Haemophilus influenza',
                   7: 'Klebsiella pneumoniae',
                   8: 'Escherichia coli',
                   9: 'Enterobacter cloacae',
                   10: 'Stenotrophomonas maltophilia',
                   11: 'Enterobacter aerogenes',
                   12: 'Serratia marcescens',
                   13: 'Klebsiella oxytoca',
                   14: 'Proteus mirabilis',
                   15: 'Enterococcus faecalis',
                   16: 'Enterococcus faecium',
                   17: 'Candida albicans',
                   18: 'Other'}


# In[5]:

def fill_pathogens(x, labels, lookup, fill_with=1, keep_max=None):
    """Expand one (record_id, day) group to exactly one row per pathogen in `lookup`.

    Parameters
    ----------
    x : DataFrame
        Rows of a single (record_id, day) group; must contain a 'pathogen' column.
        Rows containing any NaN are discarded first.
    labels : tuple
        The (record_id, day) group key; written onto every output row.
    lookup : dict
        Its values are the complete, ordered list of pathogen names to emit.
        Pathogens in `x` that are not among these values are dropped.
    fill_with : scalar, default 1
        Value written into the cells of pathogens that are absent from `x`.
    keep_max : str, optional
        Column name. When a pathogen occurs more than once in `x`, keep the row
        with the largest value in this column. When None, the first occurrence
        is kept (the original behaviour).

    Returns
    -------
    DataFrame
        One row per value of `lookup`, in `lookup` order, with a 'pathogen' column.
    """
    rows = x.dropna()
    if keep_max is not None:
        # Put the largest count first so drop_duplicates(keep='first') keeps it.
        # A stable sort leaves ties in their original order.
        rows = rows.sort_values(keep_max, ascending=False, kind='mergesort')
    return (rows
            # reindex below needs a unique pathogen index.
            .drop_duplicates(subset=['pathogen'])
            .set_index('pathogen')
            .reindex(list(lookup.values()))
            .reset_index()
            .assign(record_id=labels[0], day=labels[1])
            # After dropna, the only NaNs left are the rows added by reindex.
            .fillna(fill_with))


# In[6]:

# Missing bBAL pathogens are recorded as a count of 0.
bbal_groups = [fill_pathogens(group, labels, pcr_path_lookup, fill_with=0,
                              keep_max='cfu_count')
               for labels, group in bbal_pathogens.groupby(['record_id', 'day'])]


# In[7]:

bbal_complete = pd.concat(bbal_groups).reset_index(drop=True)


# In[8]:

# Recomputed on the pathogen-complete table: every (record_id, day) now has
# a row for every pathogen in pcr_path_lookup.
bbal_agg = (bbal_complete.groupby(['record_id', 'day', 'pathogen'])[['cfu_count']].max())
bbal_agg.head(10)


# Import HME data

# In[9]:

hme_pathogens = pd.read_csv(DATA_DIR + 'hme_pathogens.csv', index_col=0)
hme_pathogens.head()


# Aggregate HME data

# In[10]:

# Missing HME pathogens get pcr_count 1 (the default), i.e. log10 count 0 below.
hme_groups = [fill_pathogens(group, labels, pcr_path_lookup, keep_max='pcr_count')
              for labels, group in hme_pathogens.groupby(['record_id', 'day'])]


# In[11]:

hme_complete = pd.concat(hme_groups).reset_index(drop=True)


# In[12]:

hme_complete.head()


# In[13]:

# HME counts are modelled on the log10 scale, rounded to one decimal place.
hme_agg = (hme_complete.assign(log_count=np.log10(hme_complete.pcr_count).round(1))
           .drop('pcr_count', axis=1)
           .dropna()
           .groupby(['record_id', 'day', 'pathogen'])[['log_count']].max())
hme_agg.head()


# In[14]:

# Left join on (record_id, day, pathogen). NOTE(review): a bBAL row with no
# matching HME row gets HME count 0, i.e. it is treated the same as "not
# detected in HME" -- confirm that is intended.
bbal_hme = (bbal_agg.join(hme_agg).fillna(0)
            .reset_index()
            .rename(columns={'log_count': 'HME count', 'cfu_count': 'bBAL count'}))


# In[15]:

bbal_hme.head()


# In[16]:

# Integer codes for pathogens, used to index the per-pathogen intercepts θ.
pathogen_list = bbal_hme.pathogen.unique()
n_pathogens = pathogen_list.shape[0]
pathogen_encode = dict(zip(pathogen_list, range(n_pathogens)))
pathogen_decode = dict(zip(range(n_pathogens), pathogen_list))
bbal_hme['pathogen_id'] = bbal_hme.pathogen.map(pathogen_encode)


# In[17]:

n_pathogens


# ## Predict negative bBAL from HME

# In[18]:

# NOTE(review): the threshold 3 is applied to the raw 'bBAL count' column;
# confirm its scale (raw cfu vs. log10 cfu).
bbal_hme['bbal_negative'] = bbal_hme['bBAL count'] < 3


# In[19]:

bbal_hme.head()


# In[20]:

bbal_hme.shape


# In[21]:

x_path = bbal_hme.pathogen_id.values
x_hme = bbal_hme['HME count'].values
y = bbal_hme.bbal_negative.astype(int).values

with pm.Model() as neg_pred_model:
    # Non-centred hierarchical intercepts: θ = μ + σ·θ_tilde, one per pathogen.
    μ = pm.Normal('μ', 0, sd=10)
    σ = pm.HalfCauchy('σ', 2.5)
    θ_tilde = pm.Normal('θ_tilde', mu=0, sd=1, shape=n_pathogens)
    θ = pm.Deterministic('θ', μ + σ * θ_tilde)

    # Common slope on the HME log-count, shared by all pathogens.
    β = pm.Normal('β', 0, sd=10)

    π = pm.Deterministic('π', pm.math.invlogit(θ[x_path] + β*x_hme))
    pm.Bernoulli('likelihood', π, observed=y)


# In[22]:

with neg_pred_model:
    # target_accept has to go through nuts_kwargs. The previous `step_args`
    # keyword is not a pm.sample argument and was silently ignored.
    neg_trace = pm.sample(3000, tune=2000, nuts_kwargs={'target_accept': 0.99},
                          njobs=2, random_seed=list(SEEDS))


# In[23]:

def invlogit(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise for arrays."""
    return 1 / (1 + np.exp(-x))


# In[24]:

# With the invlogit transform the axis is a probability in (0, 1), so
# vline=-1 lies off the plot (presumably to suppress the reference line).
pm.forestplot(neg_trace, varnames=['θ'], ylabels=pathogen_list, transform=invlogit,
              xtitle='Probability of negative bBAL with no HME detected', vline=-1)


# In[25]:

pm.traceplot(neg_trace, varnames=['β', 'σ']);


# In[43]:

def plot_pathogen(pathogen, trace, direction):
    """Plot the posterior probability of a `direction` bBAL against HME log-count.

    Draws the posterior median and 95% credible band of
    invlogit(θ[pathogen] + β * x) for x = 0..9. Observed outcomes for the
    pathogen are overlaid as points with a small vertical jitter.

    Parameters
    ----------
    pathogen : str
        Pathogen name; must be a key of `pathogen_encode`.
    trace : MultiTrace
        Posterior samples with 'θ' (samples x pathogens) and scalar 'β'.
    direction : str
        Selects the observed column 'bbal_<direction>' (e.g. 'negative',
        'positive') and is used in the y-axis label.

    Raises
    ------
    KeyError
        If `pathogen` is not one of the modelled pathogens.
    """
    fig, ax = plt.subplots()
    x_hme_pred = np.arange(0, 10)
    pathogen_ind = pathogen_encode[pathogen]
    theta = trace['θ'][:, pathogen_ind]
    beta = trace['β']
    # Shape (len(x_hme_pred), n_samples): one row per HME log-count.
    p_pred = invlogit(theta[np.newaxis, :] + np.outer(x_hme_pred, beta))
    low, med, high = np.percentile(p_pred, [2.5, 50., 97.5], axis=1)
    ax.plot(x_hme_pred, med)
    ax.fill_between(x_hme_pred, low, high, color='k', alpha=0.35, zorder=5,
                    label='95% posterior credible interval')

    observed = bbal_hme.loc[bbal_hme.pathogen == pathogen]
    y = observed['bbal_{}'.format(direction)].astype(int).values
    # Keep HME log-counts as floats; casting them to int truncated e.g. 4.7 -> 4.
    x = observed['HME count'].values
    # A small vertical jitter keeps overlapping 0/1 outcomes visible.
    ax.plot(x, y + np.random.randn(len(y))*0.01, 'bo')

    ax.set_xlabel('Log-count of pathogen in HME')
    ax.set_ylabel('Probability of {} bBAL'.format(direction))
    ax.set_ylim(-0.1, 1.1)
    ax.set_title(pathogen)
    ax.legend(loc='best')


# In[44]:

plot_pathogen('Haemophilus influenza', neg_trace, 'negative')


# In[45]:

plot_pathogen('Streptococcus pneumonia', neg_trace, 'negative')


# Here is a plot for a rarer species.
# Note the uncertainty:

# In[46]:

plot_pathogen('Proteus mirabilis', neg_trace, 'negative')


# ## Predict positive bBAL from positive HME

# In[30]:

# NOTE(review): this is the exact complement of `bbal_negative` (`< 3`). With
# the same predictors and sign-symmetric priors, the model below is therefore
# a mirror image of neg_pred_model (θ and β flip sign) -- confirm a separate
# fit is wanted.
bbal_hme['bbal_positive'] = bbal_hme['bBAL count'] >= 3


# In[31]:

x_path = bbal_hme.pathogen_id.values
x_hme = bbal_hme['HME count'].values
y = bbal_hme.bbal_positive.astype(int).values

with pm.Model() as pos_pred_model:
    # Non-centred hierarchical intercepts: θ = μ + σ·θ_tilde, one per pathogen.
    μ = pm.Normal('μ', 0, sd=10)
    σ = pm.HalfCauchy('σ', 2.5)
    θ_tilde = pm.Normal('θ_tilde', mu=0, sd=1, shape=n_pathogens)
    θ = pm.Deterministic('θ', μ + σ * θ_tilde)

    # NOTE: alternative (currently unused) with pathogen-specific slopes:
    # ϕ = pm.Normal('ϕ', 0, sd=10)
    # τ = pm.HalfCauchy('τ', 2.5)
    # β_tilde = pm.Normal('β_tilde', mu=0, sd=1, shape=n_pathogens)
    # β = pm.Deterministic('β', ϕ + τ * β_tilde)

    # Common slope on the HME log-count, shared by all pathogens.
    β = pm.Normal('β', 0, sd=10)

    # Per-pathogen probability of a positive bBAL at an HME log-count of 1.
    π_exp = pm.Deterministic('π_exp', pm.math.invlogit(θ + β))
    π = pm.Deterministic('π', pm.math.invlogit(θ[x_path] + β*x_hme))
    pm.Bernoulli('likelihood', π, observed=y)


# In[32]:

with pos_pred_model:
    # pm.sample's keyword is `random_seed` (a list, one seed per job). The
    # previous `random_seeds` was silently ignored, so the seeds never applied.
    pos_trace = pm.sample(3000, tune=2000, init='ADVI', njobs=2,
                          random_seed=list(SEEDS))


# In[33]:

pm.forestplot(pos_trace, varnames=['θ'], ylabels=pathogen_list, transform=invlogit,
              xtitle='Probability of positive bBAL with no HME', vline=0)


# In[34]:

pm.traceplot(pos_trace, varnames=['β']);


# In[48]:

plot_pathogen('Haemophilus influenza', pos_trace, 'positive')


# In[49]:

plot_pathogen('Streptococcus pneumonia', pos_trace, 'positive')