#!/usr/bin/env python
# coding: utf-8

# In[1]:

# ``get_ipython`` only exists inside IPython/Jupyter; when this export is run
# as a plain script the inline-plotting magic is simply skipped.
try:
    _ipython = get_ipython()
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('matplotlib', 'inline')

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pylab as plt

sns.set_context('notebook')

# seed_numbers[0] is passed as ``random_seed`` to the MCMC sampling later on.
seed_numbers = 20090425, 19700903


# Import current dataset

# In[2]:

data_file = 'data/AGE Cases and Healthy Controls Merged with Lab Results Year 2 to 4 v05.23.2017.csv'
# The codes 88, 99, 777, 888 and 999 are read as missing values; the lab
# result columns are renamed to full virus names.
rotavirus_data = (pd.read_csv(data_file, low_memory=False,
                              na_values=[88, 99, 777, 888, 999])
                  .rename(columns={'rotacdc': 'rotavirus',
                                   'sapo': 'sapovirus',
                                   'astro': 'astrovirus',
                                   'noro': 'norovirus'}))


# In[3]:

(rotavirus_data.houtcome == 0).sum()


# In[4]:

rotavirus_data['hicu'].sum()


# In[5]:

# Lab ``*_ct`` columns carried into the severity table
# (presumably PCR cycle-threshold values -- confirm against the codebook).
ct_cols = ['sapo_ct', 'astro_ct', 'noro_ct']


# In[6]:

virus_cols = ['rotavirus', 'sapovirus', 'astrovirus', 'norovirus']


# In[7]:

# True when at least one of the four virus columns is positive
# (missing results contribute 0 to the row sum).
rotavirus_data['virus_pos'] = rotavirus_data[virus_cols].sum(axis=1) > 0


# Specify columns needed to calculate the score, and restrict to cases.

# In[8]:

severity_cols = ['caseid', 'virus_pos', 'daysill', 'diarrh', 'diarrhstart', 'diarrhdays',
                 'diarrhepiso', 'vomit', 'vomitstart', 'vomitdays', 'vomitepiso', 'fever',
                 'feverstart', 'feverhigh', 'feverdays', 'feverdecimal', 'measureby',
                 'measurothsp', 'newvisit', 'irtherapy', 'irtherapydur', 'settingnew',
                 'provider', 'outadm', 'inpelig', 'surday', 'oralrehydra', 'behave',
                 'eyes', 'takefluids', 'skintest']

# .copy() gives severity_data its own storage, so the column assignments
# below modify this frame unambiguously (no SettingWithCopyWarning).
severity_data = rotavirus_data.loc[rotavirus_data.case == 1,
                                   severity_cols + virus_cols + ct_cols].copy()


# In[9]:

# 1 when more than one virus was detected in the case, otherwise 0.
severity_data['coinfection'] = (severity_data[virus_cols].sum(axis=1) > 1).astype(int)


# In[10]:

# care_level from provider: 1 -> 1, 2 -> 3, 3 -> 4 (None for anything else).
# Providers 2 or 3 with outadm, inpelig and surday all equal to 1 are then
# recoded to care level 2.
severity_data['care_level'] = None
severity_data.loc[severity_data.provider == 1, 'care_level'] = 1
severity_data.loc[severity_data.provider == 2, 'care_level'] = 3
severity_data.loc[severity_data.provider == 3, 'care_level'] = 4
severity_data.loc[severity_data.provider.isin([2, 3])
                  & (severity_data.outadm == 1)
                  & (severity_data.inpelig == 1)
                  & (severity_data.surday == 1), 'care_level'] = 2
# In[11]:

severity_data.care_level.value_counts()


# Treatment variable

# In[12]:

# treatment: 2 for care levels 1-2; 1 for care levels 3-4 with both oral
# rehydration and IV therapy recorded; 0 otherwise.
severity_data['treatment'] = 0
severity_data.loc[severity_data.care_level.isin([1, 2]), 'treatment'] = 2
severity_data.loc[severity_data.care_level.isin([3, 4])
                  & (severity_data.oralrehydra == 1)
                  & (severity_data.irtherapy == 1), 'treatment'] = 1


# In[13]:

# Fahrenheit -> Celsius (assumes feverhigh is recorded in Fahrenheit);
# vectorized, missing values stay NaN.
severity_data['fever_high_c'] = (severity_data.feverhigh - 32) * 5 / 9


# Dehydration severity

# In[14]:

# "severe" / "moderate": at least 2 of 4 dehydration signs, with stricter
# thresholds for severe. Every severe criterion implies the matching moderate
# criterion, so a severe case is always also moderate. Missing values compare
# False and therefore count as absent.
severe = ((severity_data.eyes == 2).astype(int)
          + (severity_data.behave > 3).astype(int)
          + (severity_data.skintest == 3).astype(int)
          + (severity_data.takefluids == 2).astype(int)) >= 2
moderate = ((severity_data.eyes == 2).astype(int)
            + (severity_data.behave > 2).astype(int)
            + (severity_data.skintest >= 2).astype(int)
            + (severity_data.takefluids >= 1).astype(int)) >= 2


# In[15]:

# Points: none -> 0, moderate only -> 2, severe (and hence moderate) -> 1 + 2 = 3.
dehydration = severe.astype(int) + (moderate.astype(int) * 2)
dehydration.value_counts()


# Visualization of score components

# In[16]:

component_cols = ['diarrhdays', 'diarrhepiso', 'vomitdays', 'vomitepiso', 'fever_high_c']
severity_data.groupby('virus_pos')[component_cols].median()


# In[17]:

severity_data[component_cols].describe()


# In[18]:

def _plot_component(series, cutoffs, ymax, xlabel, xlim=None, bins=10):
    """Draw a histogram of one score component on a fresh figure.

    ``cutoffs`` are drawn as vertical lines from 0 to ``ymax``; ``xlim``
    optionally restricts the x-axis. A new figure is opened so that, when
    run as a script, successive histograms do not overlay each other.
    """
    plt.figure()
    series.hist(bins=bins)
    plt.vlines(cutoffs, ymin=0, ymax=ymax)
    if xlim is not None:
        plt.xlim(*xlim)
    plt.xlabel(xlabel)


_plot_component(severity_data.diarrhdays, [4, 6], 700, 'Diarrhea days')


# In[19]:

_plot_component(severity_data.diarrhepiso, [3, 6], 1800, 'Diarrhea episodes', xlim=(0, 30))


# In[20]:

_plot_component(severity_data.vomitdays, [1, 3], 1400, 'Vomit days', xlim=(0, 10))


# In[21]:

_plot_component(severity_data.vomitepiso, [1, 5], 2000, 'Vomit episodes', xlim=(0, 30), bins=20)


# In[22]:

_plot_component(severity_data.fever_high_c, [38.4, 39], 500, 'Max. Temperature')


# Functions for scoring each component.
# Each returns 0-3 points; zero or missing (NaN) values score 0. Bands are
# closed on the upper end so that every non-negative value maps to a band,
# including non-integer values that previously fell between bands.

# In[23]:

def diarrh_duration_points(x):
    """Return points for diarrhea duration ``x`` in days.

    0 or NaN -> 0; up to 4 days -> 1; over 4 and up to 5 days -> 2;
    over 5 days -> 3.

    Raises:
        ValueError: if ``x`` is negative.
    """
    if np.isnan(x) or x == 0:
        return 0
    if x < 0:
        raise ValueError('Invalid diarrhea duration value:', x)
    if x <= 4:
        return 1
    if x <= 5:
        return 2
    return 3


def diarrh_episode_points(x):
    """Return points for the maximum number of diarrheal stools ``x``.

    0 or NaN -> 0; fewer than 4 -> 1; 4 to 5 -> 2; more than 5 -> 3.

    Raises:
        ValueError: if ``x`` is negative.
    """
    if np.isnan(x) or x == 0:
        return 0
    if x < 0:
        raise ValueError('Invalid stool value:', x)
    if x < 4:
        return 1
    if x <= 5:
        return 2
    return 3


def vomit_duration_points(x):
    """Return points for vomiting duration ``x`` in days.

    0 or NaN -> 0; up to 1 day -> 1; over 1 and up to 2 days -> 2;
    over 2 days -> 3.

    Raises:
        ValueError: if ``x`` is negative.
    """
    if np.isnan(x) or x == 0:
        return 0
    if x < 0:
        raise ValueError('Invalid vomit duration value:', x)
    if x <= 1:
        return 1
    if x <= 2:
        return 2
    return 3


def vomit_episode_points(x):
    """Return points for the maximum number of vomiting episodes ``x``.

    0 or NaN -> 0; fewer than 2 -> 1; 2 to 4 -> 2; more than 4 -> 3.

    Raises:
        ValueError: if ``x`` is negative.
    """
    if np.isnan(x) or x == 0:
        return 0
    if x < 0:
        raise ValueError('Invalid vomit episode value:', x)
    if x < 2:
        return 1
    if x <= 4:
        return 2
    return 3


def fever_points(x):
    """Return points for the maximum temperature ``x`` in degrees Celsius.

    ``x`` is first rounded to one decimal, because values converted from
    Fahrenheit are continuous. For example, 98.6 F is exactly 37.0 C and
    101.2 F is 38.44 C; with the unrounded tests these fell between the
    bands and raised ValueError.
    After rounding: NaN or <= 37.0 -> 0; 37.1-38.4 -> 1; 38.5-38.9 -> 2;
    >= 39.0 -> 3.
    """
    if np.isnan(x):
        return 0
    x = round(float(x), 1)
    if x <= 37.0:
        return 0
    if x <= 38.4:
        return 1
    if x <= 38.9:
        return 2
    return 3


# Calculate points for each criterion

# In[24]:

# The duration scorers take days, so the *days* columns are passed directly.
diarrh_duration = severity_data.diarrhdays.apply(diarrh_duration_points)
diarrh_max_episodes = severity_data.diarrhepiso.apply(diarrh_episode_points)
vomit_duration = severity_data.vomitdays.apply(vomit_duration_points)
vomit_max_episodes = severity_data.vomitepiso.apply(vomit_episode_points)
fever_high = severity_data.fever_high_c.apply(fever_points)
emergency = (severity_data.newvisit == 1) * 3
# IV rehydration (irtherapydur) is intentionally not scored.
hospitalization = 2 * (severity_data.settingnew == 1).astype(int)


# Sum points to get score

# In[25]:

# Symptom-only subscore (no health-care-use or dehydration components).
veskari_subset = (diarrh_duration + diarrh_max_episodes + vomit_duration
                  + vomit_max_episodes + fever_high)


# In[26]:

veskari_score = veskari_subset + emergency + dehydration + hospitalization


# In[27]:

plt.figure()
ax = sns.distplot(veskari_score, bins=int(veskari_score.max()) - 1, axlabel='MVS')
text = "mean={}\nstd={}".format(np.round(veskari_score.mean(), 2),
                                np.round(veskari_score.std(), 2))
ax.text(0, 0.12, text)


# By comparison, here is the distribution in the Schnadower et al. paper that validates the MVS:
#
# ![](http://d.pr/i/3ISR94+)

# In[28]:

severity_data['mvs'] = veskari_score
severity_data['mvs_symptoms'] = veskari_subset


# In[29]:

severity_data.head()


# In[30]:

severity_data.to_excel('results/severity.xlsx')


# Drop cases with no virus testing (missing result in any virus column).

# In[31]:

severity_data_complete = severity_data.dropna(subset=virus_cols)


# In[32]:

severity_data_complete.shape


# ## Scores by setting and virus

# In[33]:

settings = {1: 'Inpatient', 2: 'ED', 3: 'Outpatient'}


# In[34]:

combinations = [(virus, setting) for virus in virus_cols for setting in settings]

fig, axes = plt.subplots(4, 3, figsize=(12, 8), sharex=True, sharey=True)
for ax, (virus, setting) in zip(axes.ravel(), combinations):
    label = '{} {}'.format(settings[setting], virus)
    # Both comparisons must be parenthesized: `&` binds tighter than `==`.
    # Both masks come from the same frame, so their indexes align.
    in_group = ((severity_data_complete.settingnew == setting)
                & (severity_data_complete[virus] == 1))
    scores = severity_data_complete.loc[in_group, 'mvs']
    if scores.empty:
        # distplot cannot draw an empty sample; leave the panel labelled.
        ax.set_xlabel(label + ' (no cases)')
        continue
    sns.distplot(scores, ax=ax, bins=10, axlabel=label)
plt.tight_layout()


# ## Virus model for severity

# In[35]:

from patsy import dmatrix
import theano.tensor as tt
from theano import shared


# In[36]:

formula = 'rotavirus + sapovirus + astrovirus + norovirus'
formula += ' + rotavirus:sapovirus + rotavirus:astrovirus + rotavirus:norovirus'
formula += ' + sapovirus:astrovirus + sapovirus:norovirus + astrovirus:norovirus'
# Columns: intercept, 4 main effects, 6 pairwise interactions.
X = np.asarray(dmatrix(formula, severity_data_complete))

term_labels = formula.split(' + ')
combo_labels = formula.replace(':', '+').split(' + ')


# In[37]:

# Prediction rows: each single virus, then each pair of viruses (main
# effects plus their interaction), in the same order as the formula terms.
X_pred = np.array([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                   [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                   [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                   [1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
                   [1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0],
                   [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                   [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
                   [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0],
                   [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1]])
if X_pred.shape[1] != X.shape[1]:
    raise ValueError('X_pred has {} columns but the design matrix has {}'
                     .format(X_pred.shape[1], X.shape[1]))


# In[38]:

from pymc3 import Model, sample, Normal, HalfCauchy, Deterministic

with Model() as severity_model:

    β = Normal('β', 0, sd=10, shape=X.shape[1])
    σ = HalfCauchy('σ', 1)

    y = Normal('y', β.dot(shared(X.T)), sd=σ, observed=severity_data_complete.mvs.values)

    # Expected score for each prediction row.
    μ = Deterministic('μ', β.dot(shared(X_pred.T)))
    main_effects = Deterministic('main_effects', μ[:4])
    # Difference from the rotavirus-only row.
    contrasts = Deterministic('contrasts', μ[1:] - μ[0])

    trace = sample(1000, tune=2000, chains=1, random_seed=seed_numbers[0])


# While severity tends to increase with rotavirus, sapovirus and astrovirus infection, coinfection tends to reduce severity relative to main effects (with the exception of astrovirus:norovirus).

# In[39]:

from pymc3 import forestplot

forestplot(trace, varnames=['β'], ylabels=['intercept'] + term_labels);


# In[40]:

forestplot(trace, varnames=['μ'], ylabels=combo_labels, quartiles=False,
           main='Expected severity');


# In[41]:

plt.figure(figsize=(10, 6), dpi=300)
forestplot(trace, varnames=['main_effects'], ylabels=combo_labels[:4], quartiles=False,
           plot_kwargs={'linewidth': 5, 'markersize': 7})


# In[42]:

forestplot(trace, varnames=['contrasts'], ylabels=combo_labels[1:], quartiles=False,
           main='Effect relative to rotavirus only');


# In[43]:

from pymc3.stats import _hpd_df


def median(x):
    """Posterior median of each column of the trace array ``x``."""
    return pd.Series(np.median(x, 0), name='median')


def hpd(x):
    """95% highest-posterior-density interval (columns hpd_2.5 / hpd_97.5)."""
    return _hpd_df(x, 0.05)


# In[44]:

from pymc3 import summary

summary_table = (summary(trace, varnames=['μ'], stat_funcs=[median, hpd])
                 .set_index(np.array(term_labels))
                 .round(1)[['median', 'hpd_2.5', 'hpd_97.5']])
summary_table.columns = 'median', 'lower 95%', 'upper 95%'
summary_table.to_excel('results/predicted_veskari.xlsx')
summary_table


# In[45]:

contrast_table = (summary(trace, varnames=['contrasts'], stat_funcs=[median, hpd])
                  .set_index(np.array(term_labels)[1:])
                  .round(1)[['median', 'hpd_2.5', 'hpd_97.5']])
contrast_table.columns = 'median', 'lower 95%', 'upper 95%'
contrast_table.to_excel('results/contrasts.xlsx')
contrast_table


# ### Posterior predictive checks

# In[46]:

from pymc3 import sample_ppc

with severity_model:
    severity_ppc = sample_ppc(trace, samples=500)


# In[47]:

from scipy.stats import percentileofscore

# For each observation, locate its observed MVS among the posterior predictive
# draws for that same observation. severity_ppc['y'] is (draws, observations),
# so iterate over its transpose (one row of draws per observation).
ppc_percentiles = [np.round(percentileofscore(draws, observed) / 100, 2)
                   for draws, observed in zip(severity_ppc['y'].T,
                                              severity_data_complete.mvs.values)]
plt.figure()
plt.hist(ppc_percentiles)


# ## Average severities by virus

# In[48]:

formula = 'rotavirus + sapovirus + astrovirus + norovirus'
X = np.asarray(dmatrix(formula, severity_data_complete))


# In[49]:

# One prediction row per virus (intercept + that virus's main effect).
X_pred = np.array([[1, 1, 0, 0, 0],
                   [1, 0, 1, 0, 0],
                   [1, 0, 0, 1, 0],
                   [1, 0, 0, 0, 1]])


# In[50]:

with Model() as avg_severity_model:

    β = Normal('β', 0, sd=10, shape=X.shape[1])
    σ = HalfCauchy('σ', 1)

    y = Normal('y', β.dot(shared(X.T)), sd=σ, observed=severity_data_complete.mvs.values)

    μ = Deterministic('μ', β.dot(shared(X_pred.T)))

    avg_trace = sample(1000, tune=2000, chains=1, random_seed=seed_numbers[0])


# In[51]:

(summary(avg_trace, varnames=['μ'], stat_funcs=[median, hpd]).round(1)
 .set_index(np.array(formula.split(' + '))))