%matplotlib inline
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
import theano.tensor as tt
# Random seeds; rseed[0] is the one handed to pm.sample in virus_model.
rseed = (20090425, 19700903)

# Raw hospitalisation table: first CSV column becomes the index and a bare
# '-' is read as a missing value.
hospitalized = pd.read_csv(
    'data/hospitalized.csv',
    index_col=0,
    na_values=['-'],
    low_memory=False,
)
# Sanity check: the row index should contain no duplicate labels.
hospitalized.index.is_unique
True
# Dimensions of the loaded table as (rows, columns).
hospitalized.shape
(3168, 408)
Positive culture lookup
# Maps each raw PCR result flag column to the pathogen it reports.
pcr_lookup = {
    'pcr_result___1': 'RSV',
    'pcr_result___2': 'HMPV',
    'pcr_result___3': 'flu A',
    'pcr_result___4': 'flu B',
    'pcr_result___5': 'rhino',
    'pcr_result___6': 'PIV1',
    'pcr_result___7': 'PIV2',
    'pcr_result___8': 'PIV3',
    'pcr_result___13': 'H1N1',
    'pcr_result___14': 'H3N2',
    'pcr_result___15': 'Swine',
    'pcr_result___16': 'Swine H1',
    'pcr_result___17': 'flu C',
    'pcr_result___18': 'Adeno',
}
# Boolean indicator columns, one per virus group, built from the raw PCR
# result flags.
hospitalized['RSV'] = hospitalized['pcr_result___1'].astype(bool)
hospitalized['HMPV'] = hospitalized['pcr_result___2'].astype(bool)
hospitalized['Rhino'] = hospitalized['pcr_result___5'].astype(bool)
# Influenza: flu A or flu B flag set.
hospitalized['Influenza'] = (
    hospitalized['pcr_result___3'] | hospitalized['pcr_result___4']
).astype(bool)
hospitalized['Adeno'] = hospitalized['pcr_result___18'].astype(bool)
# PIV: any of the PIV1 / PIV2 / PIV3 flags set.
hospitalized['PIV'] = (
    hospitalized['pcr_result___6']
    | hospitalized['pcr_result___7']
    | hospitalized['pcr_result___8']
).astype(bool)
# True where the PCR flag columns listed in pcr_lookup sum to zero.
hospitalized['No virus'] = hospitalized[list(pcr_lookup.keys())].sum(axis=1) == 0
hospitalized['All'] = True

# Virus groups analysed below; the order is relied on by later cells.
viruses = [
    'RSV',
    'HMPV',
    'Rhino',
    'Influenza',
    'Adeno',
    'PIV',
]
def describe_by_var(var, round_to=1):
    """Summarise column `var` separately for each virus group.

    For every name in `viruses`, the rows of `hospitalized` whose indicator
    column for that virus is true are selected, `var` is summarised with
    `describe()` and the statistics are rounded to `round_to` decimals.

    Returns a DataFrame with one row per virus (in the order of `viruses`)
    and one column per summary statistic.
    """
    summaries = {}
    for virus in viruses:
        positive_values = hospitalized.loc[hospitalized[virus], var]
        summaries[virus] = positive_values.describe().round(round_to)
    by_virus = pd.concat(summaries, axis=1)
    return by_virus[viruses].T
def normalize(x):
    """Return `x` shifted by its mean and divided by its standard deviation.

    Relies only on `x.mean()` and `x.std()`, so it works for any object that
    provides those methods and supports arithmetic.
    """
    centered = x - x.mean()
    return centered/x.std()
# Derived covariates used by the models below.

# 1 when sex is recorded as 'M', otherwise 0.
hospitalized['male'] = hospitalized['sex'].eq('M').astype(int)
# 1 when gestational age is below 37, otherwise 0 (a missing gestational age
# compares as False and therefore also yields 0).
hospitalized['premature'] = hospitalized['gest_age'].lt(37).astype(int)
hospitalized['age_years'] = hospitalized['age_months']/12
# Birth weight centred on the mean over all rows.
hospitalized['birth_weight_ctr'] = (
    hospitalized['birth_wt_child'] - hospitalized['birth_wt_child'].mean()
)
# 1 when the oxygen, vent and icu columns sum to more than zero.
hospitalized['severity_ind'] = (
    hospitalized[['oxygen', 'vent', 'icu']].sum(axis=1) > 0
).astype(int)
# Numeric vitamin D: 'QNS' entries become missing and '<1' is recoded as 0.
hospitalized['hospitalized_vitamin_d'] = (
    hospitalized['hosp_vitd']
    .replace({'QNS': None, '<1': 0})
    .astype(float)
)
Diagnoses
# Admission diagnosis columns used as predictors in the models.
diagnoses = [
    'adm_bronchiolitis',
    'adm_bronchopneumo',
    'adm_pneumo',
    'adm_sepsis',
]
# Co-infection flag: True when more than one of the selected virus indicators
# is set for a row.
# NOTE(review): viruses[:-2] drops the last two entries of `viruses`
# ('Adeno' and 'PIV'), so those never count towards a co-infection —
# confirm this exclusion is intended.
hospitalized['any_coinf'] = hospitalized[viruses[:-2]].sum(1)>1
# Proportion of rows flagged as co-infected.
hospitalized.any_coinf.mean()
0.19002525252525251
# Row-wise sum of the cigarette and nargila smoker columns (a total, not a
# 0/1 flag, despite the column name).
hospitalized['any_smokers'] = (
    hospitalized[['cigarette_smokers', 'nargila_smokers']].sum(axis=1)
)
Covariates (for some reason, gestational age does not seem to work, so it is not included)
# Covariates entering every virus model; the order here determines the column
# order of the design matrix built in virus_model.
covs = [
    'age_years',
    'male',
    'severity_ind',
    'length_of_stay',
    'hospitalized_vitamin_d',
    'breastfed',
    'any_smokers',
    'z_score',
    'birth_weight_ctr',
]

# Columns carried into each per-virus dataset (count columns are added per
# virus).
analysis_columns = covs + viruses + diagnoses
# Per-virus analysis datasets: the rows positive for each virus, restricted to
# the analysis columns plus that virus's count column(s).
#
# Each subset is taken as an explicit copy so that the derived columns added
# below are written to an independent frame rather than to a slice of
# `hospitalized` (which triggers pandas' chained-assignment warning and has
# ambiguous semantics).
rsv_dataset = hospitalized.loc[hospitalized.RSV,
                               analysis_columns + ['rsv_count']].copy()
rhino_dataset = hospitalized.loc[hospitalized.Rhino,
                                 analysis_columns + ['rhinovirus_count']].copy()
adeno_dataset = hospitalized.loc[hospitalized.Adeno,
                                 analysis_columns + ['adenovirus_count']].copy()
hmpv_dataset = hospitalized.loc[hospitalized.HMPV,
                                analysis_columns + ['hmpv_count']].copy()

# Influenza: combine the type-specific counts into a single flu_count.
flu_count_columns = ['influenza_a_count', 'influenza_b_count',
                     'influenza_c_count']
flu_dataset = hospitalized.loc[hospitalized.Influenza,
                               analysis_columns + flu_count_columns].copy()
flu_counts = flu_dataset[flu_count_columns]
# Sum the counts that are present.  A row with no count at all stays missing
# instead of becoming 0, so that virus_model's dropna on the output column
# excludes it rather than modelling a spurious zero.
flu_dataset['flu_count'] = (
    flu_counts.fillna(0).sum(axis=1).where(flu_counts.notnull().any(axis=1))
)
flu_dataset = flu_dataset.drop(flu_count_columns, axis=1)

# PIV: same combination across PIV1-3.  The type-specific columns are kept in
# the dataset, as before.
piv_count_columns = ['piv1_count', 'piv2_count', 'piv3_count']
piv_dataset = hospitalized.loc[hospitalized.PIV,
                               analysis_columns + piv_count_columns].copy()
piv_counts = piv_dataset[piv_count_columns]
piv_dataset['piv_count'] = (
    piv_counts.fillna(0).sum(axis=1).where(piv_counts.notnull().any(axis=1))
)
# Distribution of vitamin D among RSV-positive patients.
rsv_dataset['hospitalized_vitamin_d'].hist(bins=20)
<matplotlib.axes._subplots.AxesSubplot at 0x1194e9eb8>
# Age (years) against vitamin D for RSV-positive patients.
rsv_dataset.plot.scatter(
    x='hospitalized_vitamin_d',
    y='age_years',
    alpha=0.3,
)
<matplotlib.axes._subplots.AxesSubplot at 0x1205e0d30>
# Number of missing values in each column of the RSV dataset.
rsv_dataset.isnull().sum()
age_years 0 male 0 severity_ind 0 length_of_stay 14 hospitalized_vitamin_d 204 breastfed 0 any_smokers 0 z_score 4 birth_weight_ctr 1 RSV 0 HMPV 0 Rhino 0 Influenza 0 Adeno 0 PIV 0 adm_bronchiolitis 0 adm_bronchopneumo 0 adm_pneumo 0 adm_sepsis 0 rsv_count 0 dtype: int64
Drop those missing z-score and birth weight
# Keep only RSV rows where both z_score and birth_weight_ctr are present;
# these predictors are not imputed by the model.
rsv_dataset = rsv_dataset.dropna(
    axis=0,
    how='any',
    subset=['z_score', 'birth_weight_ctr'],
)
RSV model, with indicators for presence of other viruses.
def virus_model(dataset, virus_name, output_column, iterations=10000):
    """Fit a linear regression of a virus count column on the other predictors.

    The response is `dataset[output_column]`, centred on its mean.  The
    predictors are the indicator columns of the other viruses, the admission
    diagnoses and the covariates in `covs`.  Missing vitamin D and length of
    stay are imputed inside the model: vitamin D with a two-component normal
    mixture, length of stay with a lognormal.

    Parameters
    ----------
    dataset : DataFrame containing the `viruses`, `diagnoses` and `covs`
        columns plus `output_column`.  Rows with a missing `output_column`
        are dropped.
    virus_name : entry of `viruses` to leave out of the predictors.
    output_column : name of the response column.
    iterations : number of samples requested from `pm.sample`.  Values below
        10000 use NUTS; otherwise Metropolis is used.

    Returns
    -------
    The trace returned by `pm.sample`.

    Raises
    ------
    ValueError
        If `virus_name` is not in `viruses` (from `list.remove`).

    Notes
    -----
    Ordering of the coefficient vector β (length k): first the columns of X
    after vitamin D and length of stay have been removed (other viruses,
    diagnoses, remaining covariates, in that order), then β[-2] for length
    of stay and β[-1] for vitamin D.  Any labelling of β must follow this
    order rather than the raw order of `covs`.
    """
    # Sampler choice depends only on the requested number of iterations.
    use_nuts = iterations<10000
    with pm.Model() as _model:
        dataset = dataset.dropna(subset=[output_column])
        # Prepare data
        other_viruses = viruses[:]
        other_viruses.remove(virus_name)
        X = dataset[other_viruses
                    +diagnoses
                    +covs].astype(float)
        # Dimensions of data
        # k is taken before the two pops below, so it still counts vitamin D
        # and length of stay; n is not used further.
        n, k = X.shape
        virus_ct = dataset[output_column].values
        virus_centered = virus_ct - virus_ct.mean()
        # Extract predictors with missing values
        vitamin_D = X.pop('hospitalized_vitamin_d')
        # NOTE(review): vitamin D is exponentiated before normalising —
        # confirm this transformation is intended.
        exp_vitamin_D = np.exp(vitamin_D)
        # 99 (vitamin D) and 999 (length of stay) are sentinels marking
        # missing entries; they are masked out again below.
        vitamin_D_norm = normalize(exp_vitamin_D).fillna(99)
        los = X.pop('length_of_stay').fillna(999)
        # Impute missing vitamin D with 2-component normal mixture
        if (vitamin_D_norm==99).sum():
            p_high_D = pm.Dirichlet('p_high_D', np.ones(2))
            ϕ = pm.Normal('ϕ', 0, sd=1000, shape=2)
            τ = pm.Uniform('τ', 0, 100, shape=2)
            # Masked entries of the observed array are the values to impute.
            vitamin_D_imp = pm.NormalMixture('vitamin_D_imp', p_high_D, ϕ, sd=τ,
                                observed=np.ma.masked_values(vitamin_D_norm, value=99))
        else:
            vitamin_D_imp = vitamin_D_norm
        if (los==999).sum():
            ψ = pm.Normal('ψ', 0, sd=1000)
            η = pm.Uniform('η', 0, 100)
            # NOTE(review): η is passed positionally; which Lognormal
            # parameter it binds to depends on the installed pymc3
            # signature — confirm it is the intended scale parameter.
            los_imp = pm.Lognormal('los_imp', ψ, η,
                            observed=np.ma.masked_values(los, value=999))
        else:
            los_imp = los
        # Length of stay is standardised after imputation.
        los_norm = normalize(los_imp)
        # Predictor coefficients
        β = pm.Normal('β', 0, sd=10000, shape=k)
        # Mean
        μ = pm.Normal('μ', 0, sd=10000)
        # Linear predictor: the last two coefficients belong to length of stay
        # and vitamin D respectively.
        θ = (μ + tt.dot(X.values, β[:-2]) + β[-2]*los_norm
                    + β[-1]*vitamin_D_imp)
        # Data likelihood
        σ = pm.HalfCauchy('σ', 5)
        pm.Normal('likelihood', θ, sd=σ, observed=virus_centered)
        if use_nuts:
            step = pm.NUTS(target_accept=0.99)
        else:
            # Metropolis is given the regression parameters plus whichever
            # imputation variables were actually created above.
            metropolis_vars = [β, μ, σ]
            if (vitamin_D_norm==99).sum():
                metropolis_vars += [vitamin_D_imp, p_high_D, ϕ, τ]
            if (los==999).sum():
                metropolis_vars += [los_imp, ψ, η]
            step = pm.Metropolis(vars=metropolis_vars)
        trace = pm.sample(iterations, step=step, random_seed=rseed[0])
    return trace
# Fit the RSV load model with the default number of iterations.
rsv_trace = virus_model(
    rsv_dataset,
    virus_name='RSV',
    output_column='rsv_count',
)
100%|██████████| 10000/10000 [00:29<00:00, 335.11it/s]
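Estimates of the regression coefficients β. Positive values are associated with a higher CT response. Note that the model places the coefficients for length of stay and vitamin D (the two predictors whose missing values are imputed) in the last two positions of β.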
# Labels for the β forest plot.  virus_model orders β as: other viruses,
# diagnoses, the covariates without length_of_stay / hospitalized_vitamin_d,
# then length_of_stay (β[-2]) and hospitalized_vitamin_d (β[-1]).  The labels
# must follow that order; using the raw order of `covs` attaches the wrong
# names to the last six coefficients.
rsv_labels = [label for label in viruses + diagnoses + covs
              if label not in ('RSV', 'length_of_stay', 'hospitalized_vitamin_d')]
rsv_labels += ['length_of_stay', 'hospitalized_vitamin_d']
pm.forestplot(rsv_trace[-1000:], varnames=['β'], ylabels=rsv_labels,
              main='RSV load covariates')
<matplotlib.gridspec.GridSpec at 0x11a0cb438>
# Forest plot of the β coefficients from the last 1000 RSV samples (this
# repeats the plot drawn in the preceding cell).
pm.forestplot(rsv_trace[-1000:], varnames=['β'], ylabels=rsv_labels, main='RSV load covariates')
<matplotlib.gridspec.GridSpec at 0x1236e9400>
# Trace plot of p_high_D, the weights of the two-component vitamin D mixture,
# over the last 1000 samples.
pm.traceplot(rsv_trace[-1000:], varnames=['p_high_D'])
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x1184c5a20>, <matplotlib.axes._subplots.AxesSubplot object at 0x119a26390>]], dtype=object)
Here are the imputed vitamin D (normalized) values.
# 'vitamin_D_imp_missing' holds the sampled values for the masked (missing)
# vitamin D entries; plotted from the last 1000 samples.
pm.forestplot(rsv_trace[-1000:], varnames=['vitamin_D_imp_missing'])
<matplotlib.gridspec.GridSpec at 0x117a38518>
Rhinovirus model, with indicators for presence of other viruses.
# Number of missing values in each column of the rhinovirus dataset.
rhino_dataset.isnull().sum()
age_years 0 male 0 severity_ind 0 length_of_stay 16 hospitalized_vitamin_d 212 breastfed 0 any_smokers 0 z_score 4 birth_weight_ctr 0 RSV 0 HMPV 0 Rhino 0 Influenza 0 Adeno 0 PIV 0 adm_bronchiolitis 0 adm_bronchopneumo 0 adm_pneumo 0 adm_sepsis 0 rhinovirus_count 1 dtype: int64
# Fit the rhinovirus load model on rows with z_score and the count present.
rhino_trace = virus_model(
    rhino_dataset.dropna(subset=['z_score', 'rhinovirus_count']),
    virus_name='Rhino',
    output_column='rhinovirus_count',
)
100%|██████████| 10000/10000 [00:25<00:00, 395.54it/s]
# Labels for the β forest plot.  virus_model orders β as: other viruses,
# diagnoses, the covariates without length_of_stay / hospitalized_vitamin_d,
# then length_of_stay (β[-2]) and hospitalized_vitamin_d (β[-1]).  The labels
# must follow that order; using the raw order of `covs` attaches the wrong
# names to the last six coefficients.
rhino_labels = [label for label in viruses + diagnoses + covs
                if label not in ('Rhino', 'length_of_stay', 'hospitalized_vitamin_d')]
rhino_labels += ['length_of_stay', 'hospitalized_vitamin_d']
pm.forestplot(rhino_trace[-1000:], varnames=['β'], ylabels=rhino_labels,
              main='Rhinovirus load covariates')
<matplotlib.gridspec.GridSpec at 0x12394c2e8>
HMPV model, with indicators for presence of other viruses.
# Number of missing values in each column of the HMPV dataset.
hmpv_dataset.isnull().sum()
age_years 0 male 0 severity_ind 0 length_of_stay 3 hospitalized_vitamin_d 33 breastfed 0 any_smokers 0 z_score 0 birth_weight_ctr 0 RSV 0 HMPV 0 Rhino 0 Influenza 0 Adeno 0 PIV 0 adm_bronchiolitis 0 adm_bronchopneumo 0 adm_pneumo 0 adm_sepsis 0 hmpv_count 1 dtype: int64
# Fit the HMPV load model on rows with the count present.
hmpv_trace = virus_model(
    hmpv_dataset.dropna(subset=['hmpv_count']),
    virus_name='HMPV',
    output_column='hmpv_count',
)
100%|██████████| 10000/10000 [00:17<00:00, 581.15it/s]
# Labels for the β forest plot.  virus_model orders β as: other viruses,
# diagnoses, the covariates without length_of_stay / hospitalized_vitamin_d,
# then length_of_stay (β[-2]) and hospitalized_vitamin_d (β[-1]).  The labels
# must follow that order; using the raw order of `covs` attaches the wrong
# names to the last six coefficients.
hmpv_labels = [label for label in viruses + diagnoses + covs
               if label not in ('HMPV', 'length_of_stay', 'hospitalized_vitamin_d')]
hmpv_labels += ['length_of_stay', 'hospitalized_vitamin_d']
pm.forestplot(hmpv_trace[-1000:], varnames=['β'], ylabels=hmpv_labels,
              main='HMPV load covariates')
<matplotlib.gridspec.GridSpec at 0x1219161d0>
Adenovirus model, with indicators for presence of other viruses.
# Number of missing values in each column of the adenovirus dataset.
adeno_dataset.isnull().sum()
age_years 0 male 0 severity_ind 0 length_of_stay 6 hospitalized_vitamin_d 96 breastfed 0 any_smokers 0 z_score 1 birth_weight_ctr 0 RSV 0 HMPV 0 Rhino 0 Influenza 0 Adeno 0 PIV 0 adm_bronchiolitis 0 adm_bronchopneumo 0 adm_pneumo 0 adm_sepsis 0 adenovirus_count 0 dtype: int64
# Fit the adenovirus load model on rows with z_score present.
adeno_trace = virus_model(
    adeno_dataset.dropna(subset=['z_score']),
    virus_name='Adeno',
    output_column='adenovirus_count',
)
100%|██████████| 10000/10000 [00:21<00:00, 468.50it/s]
# Labels for the β forest plot.  virus_model orders β as: other viruses,
# diagnoses, the covariates without length_of_stay / hospitalized_vitamin_d,
# then length_of_stay (β[-2]) and hospitalized_vitamin_d (β[-1]).  The labels
# must follow that order; using the raw order of `covs` attaches the wrong
# names to the last six coefficients.
adeno_labels = [label for label in viruses + diagnoses + covs
                if label not in ('Adeno', 'length_of_stay', 'hospitalized_vitamin_d')]
adeno_labels += ['length_of_stay', 'hospitalized_vitamin_d']
pm.forestplot(adeno_trace[-1000:], varnames=['β'], ylabels=adeno_labels,
              main='Adenovirus load covariates')
<matplotlib.gridspec.GridSpec at 0x11da3d278>
Influenza model, with indicators for presence of other viruses.
# Number of missing values in each column of the influenza dataset.
flu_dataset.isnull().sum()
age_years 0 male 0 severity_ind 0 length_of_stay 0 hospitalized_vitamin_d 12 breastfed 0 any_smokers 0 z_score 0 birth_weight_ctr 0 RSV 0 HMPV 0 Rhino 0 Influenza 0 Adeno 0 PIV 0 adm_bronchiolitis 0 adm_bronchopneumo 0 adm_pneumo 0 adm_sepsis 0 flu_count 0 dtype: int64
# Fit the influenza load model on the combined flu count.
flu_trace = virus_model(
    flu_dataset,
    virus_name='Influenza',
    output_column='flu_count',
)
100%|██████████| 10000/10000 [00:13<00:00, 750.97it/s]
# Labels for the β forest plot.  virus_model orders β as: other viruses,
# diagnoses, the covariates without length_of_stay / hospitalized_vitamin_d,
# then length_of_stay (β[-2]) and hospitalized_vitamin_d (β[-1]).  The labels
# must follow that order; using the raw order of `covs` attaches the wrong
# names to the last six coefficients.
flu_labels = [label for label in viruses + diagnoses + covs
              if label not in ('Influenza', 'length_of_stay', 'hospitalized_vitamin_d')]
flu_labels += ['length_of_stay', 'hospitalized_vitamin_d']
pm.forestplot(flu_trace[-1000:], varnames=['β'], ylabels=flu_labels,
              main='Influenza load covariates')
<matplotlib.gridspec.GridSpec at 0x12390c550>
PIV model, with indicators for presence of other viruses.
# Number of missing values in each column of the PIV dataset.
piv_dataset.isnull().sum()
age_years 0 male 0 severity_ind 0 length_of_stay 0 hospitalized_vitamin_d 30 breastfed 0 any_smokers 0 z_score 0 birth_weight_ctr 0 RSV 0 HMPV 0 Rhino 0 Influenza 0 Adeno 0 PIV 0 adm_bronchiolitis 0 adm_bronchopneumo 0 adm_pneumo 0 adm_sepsis 0 piv1_count 141 piv2_count 162 piv3_count 47 piv_count 0 dtype: int64
# Fit the PIV load model on the combined PIV count.
piv_trace = virus_model(
    piv_dataset,
    virus_name='PIV',
    output_column='piv_count',
)
100%|██████████| 10000/10000 [00:11<00:00, 834.07it/s]
# Labels for the β forest plot.  virus_model orders β as: other viruses,
# diagnoses, the covariates without length_of_stay / hospitalized_vitamin_d,
# then length_of_stay (β[-2]) and hospitalized_vitamin_d (β[-1]).  The labels
# must follow that order; using the raw order of `covs` attaches the wrong
# names to the last six coefficients.
piv_labels = [label for label in viruses + diagnoses + covs
              if label not in ('PIV', 'length_of_stay', 'hospitalized_vitamin_d')]
piv_labels += ['length_of_stay', 'hospitalized_vitamin_d']
pm.forestplot(piv_trace[-1000:], varnames=['β'], ylabels=piv_labels,
              main='PIV load covariates')
<matplotlib.gridspec.GridSpec at 0x12255b4e0>