%matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pylab as plt
# Use seaborn's 'notebook' plotting context for every figure below.
sns.set_context('notebook')
# Fixed random seeds, looked up by index when sampling so that runs are repeatable.
seed_numbers = 20090425, 19700903
Import current dataset
# Location of the merged case/control + lab-results CSV export.
data_file = 'data/AGE Cases and Healthy Controls Merged with Lab Results Year 2 to 4 v05.23.2017.csv'

# Read the file, treating the listed numeric codes as missing in every column.
rotavirus_data = pd.read_csv(
    data_file,
    low_memory=False,
    na_values=[88, 99, 777, 888, 999],
)

# Give the virus result columns descriptive names.
rotavirus_data = rotavirus_data.rename(
    columns={
        'rotacdc': 'rotavirus',
        'sapo': 'sapovirus',
        'astro': 'astrovirus',
        'noro': 'norovirus',
    }
)
# Count the rows whose `houtcome` equals 0.
# NOTE(review): the coding of `houtcome` is not shown here — confirm what 0 denotes.
(rotavirus_data.houtcome==0).sum()
1
# Sum of the `hicu` column (missing values are skipped by `sum`).
# Presumably a 0/1 flag, which would make this a count — verify against the codebook.
rotavirus_data['hicu'].sum()
5.0
# The `*_ct` columns for sapo-, astro- and norovirus (presumably PCR cycle-threshold values — TODO confirm).
ct_cols = ['sapo_ct', 'astro_ct', 'noro_ct']
# The four virus result columns, using the names assigned by the rename at load time.
virus_cols = ['rotavirus','sapovirus','astrovirus','norovirus']
Specify columns needed to calculate score, and restrict to cases
# Flag records that are positive for at least one of the four viruses.
# Missing results are skipped by the row-wise sum, so a record with no
# results at all is flagged False rather than missing.
rotavirus_data['virus_pos'] = rotavirus_data[virus_cols].sum(axis=1) > 0

# Columns needed to compute the severity score and its related covariates.
severity_cols = [
    'caseid',
    'virus_pos',
    'daysill',
    'diarrh',
    'diarrhstart',
    'diarrhdays',
    'diarrhepiso',
    'vomit',
    'vomitstart',
    'vomitdays',
    'vomitepiso',
    'fever',
    'feverstart',
    'feverhigh',
    'feverdays',
    'feverdecimal',
    'measureby',
    'measurothsp',
    'newvisit',
    'irtherapy',
    'irtherapydur',
    'settingnew',
    'provider',
    'outadm',
    'inpelig',
    'surday',
    'oralrehydra',
    'behave',
    'eyes',
    'takefluids',
    'skintest',
]

# Restrict to cases (case == 1) and to the columns of interest.
# The explicit .copy() detaches the subset from `rotavirus_data`, so the column
# assignments that follow modify an independent frame instead of a slice
# (which pandas flags with SettingWithCopyWarning).
severity_data = rotavirus_data.loc[rotavirus_data.case==1,
                                   severity_cols + virus_cols + ct_cols].copy()

# 1 when more than one virus is positive for the case, otherwise 0.
severity_data['coinfection'] = (severity_data[virus_cols].sum(axis=1) > 1).astype(int)
# Start with no care level, then derive it from the provider code.
severity_data['care_level'] = None
for provider_code, level in ((1, 1), (2, 3), (3, 4)):
    severity_data.loc[severity_data.provider == provider_code, 'care_level'] = level

# Applied last so that it overrides the provider-based levels 3 and 4:
# provider 2 or 3 with outadm, inpelig and surday all equal to 1 becomes level 2.
overrides_to_level_2 = (
    severity_data.provider.isin([2,3])
    & (severity_data.outadm == 1)
    & (severity_data.inpelig == 1)
    & (severity_data.surday == 1)
)
severity_data.loc[overrides_to_level_2, 'care_level'] = 2

# Tabulate the resulting levels.
severity_data.care_level.value_counts()
3 2179 4 1274 1 155 2 97 Name: care_level, dtype: int64
Treatment variable
# Masks for the two non-zero treatment codes.
care_1_or_2 = severity_data.care_level.isin([1,2])
care_3_or_4_with_both = (
    severity_data.care_level.isin([3,4])
    & (severity_data.oralrehydra == 1)
    & (severity_data.irtherapy == 1)
)

# Default to 0, then assign in the same order as before (code 2 first, code 1 second).
severity_data['treatment'] = 0
severity_data.loc[care_1_or_2, 'treatment'] = 2
severity_data.loc[care_3_or_4_with_both, 'treatment'] = 1

# Convert the highest recorded temperature from Fahrenheit to Celsius.
severity_data['fever_high_c'] = (severity_data.feverhigh - 32) * 5 / 9
Dehydration severity
# Signs counted towards the stricter ("severe") definition.
severe_signs = [
    severity_data.eyes == 2,
    severity_data.behave > 3,
    severity_data.skintest == 3,
    severity_data.takefluids == 2,
]
# Signs counted towards the looser ("moderate") definition.
moderate_signs = [
    severity_data.eyes == 2,
    severity_data.behave > 2,
    severity_data.skintest >= 2,
    severity_data.takefluids >= 1,
]

# A case meets a definition when at least two of its signs are present.
severe = sum(sign.astype(int) for sign in severe_signs) >= 2
moderate = sum(sign.astype(int) for sign in moderate_signs) >= 2

# Severe contributes 1 point and moderate contributes 2, added together.
dehydration = severe.astype(int) + 2 * moderate.astype(int)
dehydration.value_counts()
2 1736 0 1004 3 965 dtype: int64
Visualization of score components
# Median of each raw score component, split by whether any virus was detected.
(severity_data.groupby('virus_pos')[['diarrhdays', 'diarrhepiso', 'vomitdays', 'vomitepiso', 'fever_high_c']].median())
diarrhdays | diarrhepiso | vomitdays | vomitepiso | fever_high_c | |
---|---|---|---|---|---|
virus_pos | |||||
False | 2.0 | 4.0 | 2.0 | 3.0 | 38.888889 |
True | 2.0 | 5.0 | 2.0 | 4.0 | 38.888889 |
# Summary statistics for the raw score components across all cases.
severity_data[['diarrhdays', 'diarrhepiso', 'vomitdays', 'vomitepiso', 'fever_high_c']].describe()
diarrhdays | diarrhepiso | vomitdays | vomitepiso | fever_high_c | |
---|---|---|---|---|---|
count | 2462.000000 | 2404.000000 | 3101.000000 | 3046.000000 | 1758.000000 |
mean | 2.835906 | 5.706739 | 2.049661 | 4.688116 | 38.853179 |
std | 1.873456 | 5.253945 | 1.391721 | 4.723547 | 0.815487 |
min | 1.000000 | 0.000000 | 1.000000 | 1.000000 | 36.666667 |
25% | 1.000000 | 3.000000 | 1.000000 | 2.000000 | 38.333333 |
50% | 2.000000 | 5.000000 | 2.000000 | 4.000000 | 38.888889 |
75% | 4.000000 | 7.000000 | 3.000000 | 6.000000 | 39.444444 |
max | 10.000000 | 77.000000 | 10.000000 | 77.000000 | 42.222222 |
# Histogram of diarrhea duration with reference lines at 4 and 6 days.
ax = severity_data.diarrhdays.hist()
ax.vlines([4,6], ymin=0, ymax=700)
ax.set_xlabel('Diarrhea days')
Text(0.5,0,'Diarrhea days')
# Histogram of diarrhea episodes with reference lines at 3 and 6; x-axis clipped to 0-30.
ax = severity_data.diarrhepiso.hist()
ax.vlines([3,6], ymin=0, ymax=1800)
ax.set_xlim(0, 30)
ax.set_xlabel('Diarrhea episodes')
Text(0.5,0,'Diarrhea episodes')
# Histogram of vomiting duration with reference lines at 1 and 3 days; x-axis clipped to 0-10.
ax = severity_data.vomitdays.hist()
ax.vlines([1,3], ymin=0, ymax=1400)
ax.set_xlim(0, 10)
ax.set_xlabel('Vomit days')
Text(0.5,0,'Vomit days')
# Histogram of vomiting episodes (20 bins) with reference lines at 1 and 5; x-axis clipped to 0-30.
ax = severity_data.vomitepiso.hist(bins=20)
ax.vlines([1,5], ymin=0, ymax=2000)
ax.set_xlim(0, 30)
ax.set_xlabel('Vomit episodes')
Text(0.5,0,'Vomit episodes')
# Histogram of maximum temperature in Celsius (10 bins) with reference lines at 38.4 and 39.
ax = severity_data.fever_high_c.hist(bins=10)
ax.vlines([38.4, 39], ymin=0, ymax=500)
ax.set_xlabel('Max. Temperature')
Text(0.5,0,'Max. Temperature')
Functions for scoring each component
def diarrh_duration_points(x):
    """Return the score points (0-3) for a diarrhea duration `x`.

    0 or NaN gives 0 points; 1 to 4 gives 1; above 4 up to 5 gives 2;
    6 or more gives 3.

    Raises:
        ValueError: if `x` falls outside every band (for example below 1
            but not 0, or strictly between 5 and 6).
    """
    if (x == 0) or np.isnan(x):
        return 0
    if 1 <= x <= 4:
        return 1
    if 4 < x <= 5:
        return 2
    if x >= 6:
        return 3
    raise ValueError('Invalid diarrhea duration value:', x)
def diarrh_episode_points(x):
    """Return the score points (0-3) for a diarrhea episode count `x`.

    0 or NaN gives 0 points; 1 up to (but excluding) 4 gives 1;
    4 to 5 gives 2; more than 5 gives 3.

    Raises:
        ValueError: if `x` falls outside every band (for example a value
            below 1 that is not 0).
    """
    if (x == 0) or np.isnan(x):
        return 0
    if 1 <= x < 4:
        return 1
    if 4 <= x <= 5:
        return 2
    if x > 5:
        return 3
    raise ValueError('Invalid stool value:', x)
def vomit_duration_points(x):
    """Return the score points (0-3) for a vomiting duration `x`.

    0 or NaN gives 0 points; exactly 1 gives 1; exactly 2 gives 2;
    more than 2 gives 3.

    Raises:
        ValueError: if `x` matches none of the above (for example 1.5
            or a negative value).
    """
    if (x == 0) or np.isnan(x):
        return 0
    if x == 1:
        return 1
    if x == 2:
        return 2
    if x > 2:
        return 3
    raise ValueError('Invalid vomit duration value:', x)
def vomit_episode_points(x):
    """Return the score points (0-3) for a vomiting episode count `x`.

    0 or NaN gives 0 points; exactly 1 gives 1; 2 to 4 gives 2;
    more than 4 gives 3.

    Raises:
        ValueError: if `x` matches none of the above (for example 1.5
            or a negative value).
    """
    if (x == 0) or np.isnan(x):
        return 0
    if x == 1:
        return 1
    if 2 <= x <= 4:
        return 2
    if x > 4:
        return 3
    raise ValueError('Invalid vomit episode value:', x)
def fever_points(x):
    """Return the score points (0-3) for a maximum temperature `x` in Celsius.

    The cut-offs are stated to one decimal place (up to 37.0, 37.1-38.4,
    38.5-38.9, 39.0 and above), but temperatures converted from Fahrenheit
    are continuous. The value is therefore rounded to one decimal before
    banding. Without that step, readings such as 38.44 or 38.94 fall between
    two bands and raise ValueError even though they are valid temperatures.

    Args:
        x: temperature in degrees Celsius; NaN means not recorded.

    Returns:
        0 for NaN or a rounded value of 37.0 or below, 1 for 37.1-38.4,
        2 for 38.5-38.9, and 3 for 39.0 or above.
    """
    if np.isnan(x):
        return 0
    # Work at the precision of the cut-offs so the bands tile the whole range.
    rounded = round(float(x), 1)
    if rounded <= 37.0:
        return 0
    if rounded <= 38.4:
        return 1
    if rounded <= 38.9:
        return 2
    return 3
Calculate points for each criterion
# Score each symptom component by applying its point function element-wise.
# Durations are scored directly from the *days* columns. The previous version
# first copied the whole frame just to alias them as "...hours", which was
# both wasteful and misleading.
diarrh_duration = severity_data.diarrhdays.apply(diarrh_duration_points)
diarrh_max_episodes = severity_data.diarrhepiso.apply(diarrh_episode_points)
vomit_duration = severity_data.vomitdays.apply(vomit_duration_points)
vomit_max_episodes = severity_data.vomitepiso.apply(vomit_episode_points)
fever_high = severity_data.fever_high_c.apply(fever_points)

# Care-related components: 3 points when newvisit == 1, 2 points when settingnew == 1.
# Both are cast to int explicitly so the two lines follow the same pattern.
emergency = 3 * (severity_data.newvisit==1).astype(int)
# A rehydration component is deliberately left out of the score:
# rehydration = (severity_data.irtherapydur==1).astype(int)
hospitalization = 2 * (severity_data.settingnew==1).astype(int)
Sum points to get score
# Symptoms-only subtotal: the five symptom components.
veskari_subset = (
    diarrh_duration
    + diarrh_max_episodes
    + vomit_duration
    + vomit_max_episodes
    + fever_high
)

# Full score: the symptom subtotal plus the care and dehydration components.
veskari_score = veskari_subset + emergency + dehydration + hospitalization

# Distribution of the full score, annotated with its mean and standard deviation.
n_bins = veskari_score.max() - 1
ax = sns.distplot(veskari_score, bins=n_bins, axlabel='MVS')
score_mean = np.round(veskari_score.mean(), 2)
score_std = np.round(veskari_score.std(), 2)
text = "mean={}\nstd={}".format(score_mean, score_std)
ax.text(0, 0.12, text)
Text(0,0.12,'mean=8.32\nstd=3.16')
By comparison, here is the distribution in the Schnadower et al. paper that validates the MVS:
# Attach the full score and the symptoms-only subtotal to the case table.
severity_data['mvs'] = veskari_score
severity_data['mvs_symptoms'] = veskari_subset
# Preview the first rows.
severity_data.head()
caseid | virus_pos | daysill | diarrh | diarrhstart | diarrhdays | diarrhepiso | vomit | vomitstart | vomitdays | ... | norovirus | sapo_ct | astro_ct | noro_ct | coinfection | care_level | treatment | fever_high_c | mvs | mvs_symptoms | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | EN1C0001 | False | 2.0 | 0.0 | NaN | NaN | NaN | 1.0 | 2.0 | 2.0 | ... | 0.0 | NaN | NaN | NaN | 0 | 4 | 0 | 39.444444 | 6 | 6 |
1 | EN1C0002 | False | 5.0 | 1.0 | 4.0 | 3.0 | 2.0 | 1.0 | 5.0 | 5.0 | ... | NaN | NaN | NaN | NaN | 0 | 4 | 0 | NaN | 9 | 7 |
2 | EN1C0003 | False | 3.0 | 1.0 | 3.0 | 3.0 | 3.0 | 0.0 | NaN | NaN | ... | NaN | NaN | NaN | NaN | 0 | 4 | 0 | NaN | 2 | 2 |
3 | EN1C0004 | False | 4.0 | 0.0 | NaN | NaN | NaN | 1.0 | 4.0 | 2.0 | ... | 0.0 | NaN | NaN | NaN | 0 | 4 | 0 | NaN | 4 | 4 |
4 | EN1C0006 | False | 3.0 | 0.0 | NaN | NaN | NaN | 1.0 | 3.0 | 3.0 | ... | NaN | NaN | NaN | NaN | 0 | 4 | 0 | NaN | 7 | 5 |
5 rows × 44 columns
# Export the scored case table to an Excel workbook under results/.
severity_data.to_excel('results/severity.xlsx')
# Keep only rows with a non-missing value in every virus column
# (dropna's default is to drop a row if any of the listed columns is missing).
severity_data_complete = severity_data.dropna(subset=virus_cols)
Drop those with no virus testing, and merge with severity score table
# Rows and columns remaining after dropping rows with missing virus results.
severity_data_complete.shape
(2886, 44)
# Display names for the `settingnew` codes.
settings = {1:'Inpatient', 2:'ED', 3:'Outpatient'}

# One (virus, setting code) pair per panel: viruses down the rows, settings across.
combinations = [(virus, setting+1) for virus in ('rotavirus', 'sapovirus', 'astrovirus', 'norovirus')
                for setting in range(3)]

fig, axes = plt.subplots(4, 3, figsize=(12,8), sharex=True, sharey=True)

for ax, (virus, setting) in zip(axes.ravel(), combinations):
    # Each comparison must be parenthesised: `&` binds tighter than `==`, so
    # `a & b == 1` is evaluated as `(a & b) == 1`. Both conditions are also
    # taken from the same complete-case frame so that the mask lines up with
    # the rows being selected.
    in_panel = ((severity_data_complete.settingnew == setting) &
                (severity_data_complete[virus] == 1))
    sns.distplot(severity_data_complete.loc[in_panel, 'mvs'],
                 ax=ax,
                 bins=10,
                 axlabel='{} {}'.format(settings[setting], virus))

# Adjust spacing only after every panel has its content and axis label.
fig.tight_layout();
from patsy import dmatrix
import theano.tensor as tt
from theano import shared
WARNING (theano.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
# Model terms: the four virus main effects plus all six pairwise interactions.
formula = 'rotavirus + sapovirus + astrovirus + norovirus'
formula += ' + rotavirus:sapovirus + rotavirus:astrovirus + rotavirus:norovirus'
formula += ' + sapovirus:astrovirus + sapovirus:norovirus + astrovirus:norovirus'
# Observed design matrix built by patsy from the complete-case data.
X = np.asarray(dmatrix(formula, severity_data_complete))
# Prediction design matrix with one row per formula term (10 rows, 11 columns).
# Column 0 is all ones; rows 0-3 switch on a single one of columns 1-4, and
# rows 4-9 switch on two of columns 1-4 together with one of columns 5-10.
# NOTE(review): this layout assumes dmatrix orders X's columns as intercept,
# main effects, then interactions in formula order — confirm against the
# design_info of the dmatrix result.
X_pred = np.array([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
[1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0],
[1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1]])
from pymc3 import Model, sample, Normal, HalfCauchy, Deterministic
with Model() as severity_model:

    # Priors: one coefficient per design-matrix column, and a residual scale.
    β = Normal('β', 0, sd=10, shape=X.shape[1])
    σ = HalfCauchy('σ', 1)

    # Normal likelihood for the observed scores around the linear predictor.
    linear_predictor = β.dot(shared(X.T))
    observed_scores = severity_data_complete.mvs.values
    y = Normal('y', linear_predictor, sd=σ, observed=observed_scores)

    # Expected score for each row of the prediction design matrix.
    μ = Deterministic('μ', β.dot(shared(X_pred.T)))

    # First four predictions, and every other prediction relative to the first.
    main_effects = Deterministic('main_effects', μ[:4])
    contrasts = Deterministic('contrasts', μ[1:] - μ[0])

    # A single chain with a fixed seed.
    trace = sample(1000, tune=2000, chains=1, random_seed=seed_numbers[0])
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... Sequential sampling (1 chains in 1 job) NUTS: [σ, β] 100%|██████████| 3000/3000 [00:23<00:00, 127.38it/s] Only one chain was sampled, this makes it impossible to run some convergence checks
While severity tends to increase with rotavirus, sapovirus and astrovirus infection, coinfection tends to reduce severity relative to main effects (with the exception of astrovirus:norovirus).
from pymc3 import forestplot
# Labels: the raw formula terms for the coefficients, and a 'a+b' style for predictions.
coefficient_labels = ['intercept'] + formula.split(' + ')
combination_labels = formula.replace(':', '+').split(' + ')

# Regression coefficients.
forestplot(trace, varnames=['β'], ylabels=coefficient_labels);

# Expected score for every single virus and virus pair.
forestplot(trace, varnames=['μ'], ylabels=combination_labels, quartiles=False,
           main='Expected severity');

# Larger, high-resolution figure for the four single-virus predictions.
plt.figure(figsize=(10,6), dpi=300)
forestplot(trace, varnames=['main_effects'], ylabels=combination_labels[:4],
           quartiles=False, plot_kwargs={'linewidth':5, 'markersize':7})
<matplotlib.gridspec.GridSpec at 0x7fe6e1494e10>
# Every term except the first, matching the contrasts taken against the first prediction.
contrast_labels = formula.replace(':', '+').split(' + ')[1:]
forestplot(trace, varnames=['contrasts'], ylabels=contrast_labels, quartiles=False,
           main='Effect relative to rotavirus only');
from pymc3.stats import _hpd_df
def median(x):
    """Return the median of `x` along its first axis as a Series named 'median'."""
    column_medians = np.median(x, axis=0)
    return pd.Series(column_medians, name='median')
# Summary statistic wrapping pymc3's private `_hpd_df` with 0.05 as its second argument
# (presumably alpha; the 'hpd_2.5' / 'hpd_97.5' columns selected from the summaries are consistent with a 95% interval).
hpd = lambda x: _hpd_df(x, 0.05)
from pymc3 import summary
# Posterior median and HPD bounds of the expected score, one row per formula term.
term_labels = np.array(formula.split(' + '))
summary_table = (
    summary(trace, varnames=['μ'], stat_funcs=[median, hpd])
    .set_index(term_labels)
    .round(1)
    [['median', 'hpd_2.5', 'hpd_97.5']]
)
summary_table.columns = ['median', 'lower 95%', 'upper 95%']

# Save to Excel and display.
summary_table.to_excel('results/predicted_veskari.xlsx')
summary_table
median | lower 95% | upper 95% | |
---|---|---|---|
rotavirus | 10.2 | 9.8 | 10.6 |
sapovirus | 8.6 | 8.2 | 9.0 |
astrovirus | 8.3 | 7.7 | 9.0 |
norovirus | 8.4 | 8.2 | 8.7 |
rotavirus:sapovirus | 10.4 | 9.0 | 11.8 |
rotavirus:astrovirus | 10.6 | 8.9 | 12.7 |
rotavirus:norovirus | 9.6 | 8.4 | 10.8 |
sapovirus:astrovirus | 7.5 | 6.2 | 8.7 |
sapovirus:norovirus | 8.7 | 7.7 | 10.0 |
astrovirus:norovirus | 9.8 | 8.6 | 11.1 |
# Posterior median and HPD bounds of each contrast, labelled by every term after the first.
contrast_labels = np.array(formula.split(' + '))[1:]
contrast_table = (
    summary(trace, varnames=['contrasts'], stat_funcs=[median, hpd])
    .set_index(contrast_labels)
    .round(1)
    [['median', 'hpd_2.5', 'hpd_97.5']]
)
contrast_table.columns = ['median', 'lower 95%', 'upper 95%']

# Save to Excel and display.
contrast_table.to_excel('results/contrasts.xlsx')
contrast_table
median | lower 95% | upper 95% | |
---|---|---|---|
sapovirus | -1.6 | -2.2 | -1.1 |
astrovirus | -1.9 | -2.7 | -1.1 |
norovirus | -1.8 | -2.2 | -1.3 |
rotavirus:sapovirus | 0.2 | -1.3 | 1.7 |
rotavirus:astrovirus | 0.4 | -1.5 | 2.3 |
rotavirus:norovirus | -0.7 | -1.8 | 0.6 |
sapovirus:astrovirus | -2.8 | -4.1 | -1.5 |
sapovirus:norovirus | -1.5 | -2.6 | -0.2 |
astrovirus:norovirus | -0.4 | -1.8 | 0.8 |
from pymc3 import sample_ppc
# Draw 500 posterior predictive samples from the fitted interaction model.
# The draw is seeded with the second configured seed so that the posterior
# predictive check is repeatable across kernel restarts.
with severity_model:
    severity_ppc = sample_ppc(trace, samples=500, random_seed=seed_numbers[1])
100%|██████████| 500/500 [00:00<00:00, 546.42it/s]
from scipy.stats import percentileofscore
# Posterior predictive check: for each observation, find where the observed
# score falls within that observation's own posterior predictive draws.
# severity_ppc['y'] has one row per draw and one column per observation, so it
# is transposed before zipping. Without the transpose, draw i is paired with
# observation i and only as many points as there are draws are produced.
ppc_percentiles = [np.round(percentileofscore(draws, observed)/100, 2)
                   for draws, observed in zip(severity_ppc['y'].T,
                                              severity_data_complete.mvs.values)]
plt.hist(ppc_percentiles)
(array([68., 74., 48., 47., 51., 63., 42., 40., 29., 38.]), array([0.02 , 0.118, 0.216, 0.314, 0.412, 0.51 , 0.608, 0.706, 0.804, 0.902, 1. ]), <a list of 10 Patch objects>)
# Main-effects-only model: rebinds `formula`, `X` and `X_pred` used by the interaction model above.
formula = 'rotavirus + sapovirus + astrovirus + norovirus'
# Observed design matrix built by patsy from the complete-case data.
X = np.asarray(dmatrix(formula, severity_data_complete))
# Prediction design matrix with one row per virus: column 0 is all ones and
# row k switches on column k+1.
# NOTE(review): assumes dmatrix puts the intercept first, followed by the viruses in formula order — confirm.
X_pred = np.array([[1, 1, 0, 0, 0],
[1, 0, 1, 0, 0],
[1, 0, 0, 1, 0],
[1, 0, 0, 0, 1]])
with Model() as avg_severity_model:

    # Priors: one coefficient per design-matrix column, and a residual scale.
    β = Normal('β', 0, sd=10, shape=X.shape[1])
    σ = HalfCauchy('σ', 1)

    # Normal likelihood for the observed scores around the linear predictor.
    linear_predictor = β.dot(shared(X.T))
    observed_scores = severity_data_complete.mvs.values
    y = Normal('y', linear_predictor, sd=σ, observed=observed_scores)

    # Expected score for each row of the prediction design matrix.
    μ = Deterministic('μ', β.dot(shared(X_pred.T)))

    # A single chain with a fixed seed.
    avg_trace = sample(1000, tune=2000, chains=1, random_seed=seed_numbers[0])
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... Sequential sampling (1 chains in 1 job) NUTS: [σ, β] 100%|██████████| 3000/3000 [00:13<00:00, 229.53it/s] Only one chain was sampled, this makes it impossible to run some convergence checks
# Posterior median and HPD bounds of the expected score per virus, rounded to one decimal.
avg_summary = summary(avg_trace, varnames=['μ'], stat_funcs=[median, hpd]).round(1)
avg_summary.set_index(np.array(formula.split(' + ')))
median | hpd_2.5 | hpd_97.5 | |
---|---|---|---|
rotavirus | 10.1 | 9.7 | 10.5 |
sapovirus | 8.5 | 8.2 | 8.9 |
astrovirus | 8.3 | 7.8 | 8.9 |
norovirus | 8.4 | 8.2 | 8.7 |