In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
import theano.tensor as tt

rseed = 20090425, 19700903

Import Data

In [2]:
hospitalized = pd.read_csv('data/hospitalized.csv', index_col=0, low_memory=False, na_values=['-'])
hospitalized.index.is_unique
Out[2]:
True
In [3]:
hospitalized.shape
Out[3]:
(3168, 408)

Positive culture lookup

In [4]:
pcr_lookup = {'pcr_result___1': 'RSV',
'pcr_result___2': 'HMPV',
'pcr_result___3': 'flu A',
'pcr_result___4': 'flu B',
'pcr_result___5': 'rhino',
'pcr_result___6': 'PIV1',
'pcr_result___7': 'PIV2',
'pcr_result___8': 'PIV3',
'pcr_result___13': 'H1N1',
'pcr_result___14': 'H3N2',
'pcr_result___15': 'Swine',
'pcr_result___16': 'Swine H1',
'pcr_result___17': 'flu C',
'pcr_result___18': 'Adeno'}
In [5]:
hospitalized['RSV'] = hospitalized.pcr_result___1.astype(bool)
hospitalized['HMPV'] = hospitalized.pcr_result___2.astype(bool)
hospitalized['Rhino'] = hospitalized.pcr_result___5.astype(bool)
hospitalized['Influenza'] = (hospitalized.pcr_result___3 | hospitalized.pcr_result___4).astype(bool)
hospitalized['Adeno'] = hospitalized.pcr_result___18.astype(bool)
hospitalized['PIV'] = (hospitalized.pcr_result___6 | hospitalized.pcr_result___7 | hospitalized.pcr_result___8).astype(bool)
hospitalized['No virus'] = hospitalized[list(pcr_lookup.keys())].sum(1) == 0
hospitalized['All'] = True
In [6]:
viruses = ['RSV', 'HMPV', 'Rhino', 'Influenza', 'Adeno', 'PIV']
In [7]:
def describe_by_var(var, round_to=1):
    return(pd.concat({v:hospitalized.loc[hospitalized[v], var]
                      .describe()
                      .round(round_to) for v in viruses}, axis=1)[viruses].T)
In [8]:
normalize = lambda x: (x - x.mean())/x.std()
In [100]:
hospitalized['male'] = (hospitalized.sex=='M').astype(int)
hospitalized['premature'] = (hospitalized.gest_age < 37).astype(int)
hospitalized['age_years'] = hospitalized.age_months/12
hospitalized['birth_weight_ctr'] = hospitalized.birth_wt_child - hospitalized.birth_wt_child.mean()
In [101]:
hospitalized['severity_ind'] = (hospitalized[['oxygen', 
                                              'vent', 
                                              'icu']].sum(1) > 0).astype(int)
In [102]:
hospitalized['hospitalized_vitamin_d'] = hospitalized.hosp_vitd.replace({'QNS': None, 
                                        '<1': 0}).astype(float)

Diagnoses

In [103]:
diagnoses = [
    'adm_bronchiolitis',
    'adm_bronchopneumo',
    'adm_pneumo',
    'adm_sepsis'
]
In [104]:
hospitalized['any_coinf'] = hospitalized[viruses[:-2]].sum(1)>1
hospitalized.any_coinf.mean()
Out[104]:
0.19002525252525251
In [105]:
hospitalized['any_smokers'] = hospitalized[['cigarette_smokers', 'nargila_smokers']].sum(1)

Response variables (for some reason, gestational age does not seem to work)

In [106]:
covs = ['age_years', 'male', 'severity_ind', 
      'length_of_stay', 'hospitalized_vitamin_d', 'breastfed',
       'any_smokers', 'z_score', 'birth_weight_ctr']
In [107]:
analysis_columns = covs + viruses + diagnoses
In [134]:
rsv_dataset = hospitalized.loc[hospitalized.RSV, analysis_columns+['rsv_count']]
rhino_dataset = hospitalized.loc[hospitalized.Rhino, analysis_columns+['rhinovirus_count']]
adeno_dataset = hospitalized.loc[hospitalized.Adeno, analysis_columns+['adenovirus_count']]
hmpv_dataset = hospitalized.loc[hospitalized.HMPV, analysis_columns+['hmpv_count']]
flu_dataset = hospitalized.loc[hospitalized.Influenza, analysis_columns+['influenza_a_count', 'influenza_b_count',
       'influenza_c_count']]
flu_dataset['flu_count'] = flu_dataset[['influenza_a_count', 'influenza_b_count',
       'influenza_c_count']].fillna(0).sum(1)
flu_dataset = flu_dataset.drop(['influenza_a_count', 'influenza_b_count',
       'influenza_c_count'], axis=1)
piv_dataset = hospitalized.loc[hospitalized.PIV, analysis_columns+['piv1_count', 'piv2_count', 'piv3_count']]
piv_dataset['piv_count'] = piv_dataset[['piv1_count', 'piv2_count', 'piv3_count']].fillna(0).sum(1)
In [109]:
rsv_dataset.hospitalized_vitamin_d.hist(bins=20)
Out[109]:
<matplotlib.axes._subplots.AxesSubplot at 0x1194e9eb8>
In [110]:
rsv_dataset.plot(y='age_years', x='hospitalized_vitamin_d', 
                 kind='scatter', alpha=0.3)
Out[110]:
<matplotlib.axes._subplots.AxesSubplot at 0x1205e0d30>
In [111]:
rsv_dataset.isnull().sum()
Out[111]:
age_years                   0
male                        0
severity_ind                0
length_of_stay             14
hospitalized_vitamin_d    204
breastfed                   0
any_smokers                 0
z_score                     4
birth_weight_ctr            1
RSV                         0
HMPV                        0
Rhino                       0
Influenza                   0
Adeno                       0
PIV                         0
adm_bronchiolitis           0
adm_bronchopneumo           0
adm_pneumo                  0
adm_sepsis                  0
rsv_count                   0
dtype: int64

Drop those missing z-score and birth weight

In [114]:
rsv_dataset = rsv_dataset.dropna(subset=['z_score', 'birth_weight_ctr'])

Specify Models

RSV model, with indicators for presence of otehr viruses.

In [115]:
def virus_model(dataset, virus_name, output_column, iterations=10000):
    
    use_nuts = iterations<10000
    
    with pm.Model() as _model:
        
        dataset = dataset.dropna(subset=[output_column])

        # Prepare data
        other_viruses = viruses[:]
        other_viruses.remove(virus_name)
        X = dataset[other_viruses
                    +diagnoses
                    +covs].astype(float)
        # Dimensions of data
        n, k = X.shape

        virus_ct = dataset[output_column].values
        virus_centered = virus_ct - virus_ct.mean()
        
        # Extract predictors with missing values
        vitamin_D = X.pop('hospitalized_vitamin_d')
        exp_vitamin_D = np.exp(vitamin_D)
        vitamin_D_norm = normalize(exp_vitamin_D).fillna(99)
        
        los = X.pop('length_of_stay').fillna(999)

        # Impute missing vitamin D with 2-component normal mixture
        if (vitamin_D_norm==99).sum():
            p_high_D = pm.Dirichlet('p_high_D', np.ones(2))
            ϕ = pm.Normal('ϕ', 0, sd=1000, shape=2)
            τ = pm.Uniform('τ', 0, 100, shape=2)

            vitamin_D_imp = pm.NormalMixture('vitamin_D_imp', p_high_D, ϕ, sd=τ, 
                             observed=np.ma.masked_values(vitamin_D_norm, value=99))
        else:
            vitamin_D_imp = vitamin_D_norm

            
        if (los==999).sum():
            ψ = pm.Normal('ψ', 0, sd=1000)
            η = pm.Uniform('η', 0, 100)
            los_imp = pm.Lognormal('los_imp', ψ, η,
                                  observed=np.ma.masked_values(los, value=999))

        else:
            los_imp = los
            
        los_norm = normalize(los_imp)
            
        # Predictor coefficients
        β = pm.Normal('β', 0, sd=10000, shape=k)

        # Mean
        μ = pm.Normal('μ', 0, sd=10000)

        θ = (μ + tt.dot(X.values, β[:-2]) + β[-2]*los_norm 
                        + β[-1]*vitamin_D_imp)

        # Data likelihood
        σ = pm.HalfCauchy('σ', 5)
        pm.Normal('likelihood', θ, sd=σ, observed=virus_centered)
        
        if use_nuts:
            step = pm.NUTS(target_accept=0.99)
        else:
            metropolis_vars = [β, μ, σ]
            if (vitamin_D_norm==99).sum():
                metropolis_vars += [vitamin_D_imp, p_high_D, ϕ, τ]
            if (los==999).sum():
                metropolis_vars += [los_imp, ψ, η]
            step = pm.Metropolis(vars=metropolis_vars)
        trace = pm.sample(iterations, step=step, random_seed=rseed[0])
        
    return trace
In [116]:
rsv_trace = virus_model(rsv_dataset, 'RSV', 'rsv_count')
100%|██████████| 10000/10000 [00:29<00:00, 335.11it/s]

Estimates of parameter values. Positive values are associated with higher CT response. Last two parameters are hyperparameters, and should not be interpreted.

  • HMPV coinfection results in 5-10 point higher CT
  • Rhinovirus coinfection results in 0.5-1.5 point higher CT
  • each year of age beyond newborn associated with 1-2 point higher CT
In [117]:
rsv_labels = viruses+diagnoses+covs
rsv_labels.remove('RSV')
In [48]:
pm.forestplot(rsv_trace[-1000:], varnames=['β'], ylabels=rsv_labels, main='RSV load covariates')
Out[48]:
<matplotlib.gridspec.GridSpec at 0x11a0cb438>
In [118]:
pm.forestplot(rsv_trace[-1000:], varnames=['β'], ylabels=rsv_labels, main='RSV load covariates')
Out[118]:
<matplotlib.gridspec.GridSpec at 0x1236e9400>
In [119]:
pm.traceplot(rsv_trace[-1000:], varnames=['p_high_D'])
Out[119]:
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x1184c5a20>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x119a26390>]], dtype=object)

Here are the imputed vitamin D (normalized) values.

In [120]:
pm.forestplot(rsv_trace[-1000:], varnames=['vitamin_D_imp_missing'])
Out[120]:
<matplotlib.gridspec.GridSpec at 0x117a38518>

Rhinovirus model, with indicators for presence of other viruses.

In [121]:
rhino_dataset.isnull().sum()
Out[121]:
age_years                   0
male                        0
severity_ind                0
length_of_stay             16
hospitalized_vitamin_d    212
breastfed                   0
any_smokers                 0
z_score                     4
birth_weight_ctr            0
RSV                         0
HMPV                        0
Rhino                       0
Influenza                   0
Adeno                       0
PIV                         0
adm_bronchiolitis           0
adm_bronchopneumo           0
adm_pneumo                  0
adm_sepsis                  0
rhinovirus_count            1
dtype: int64
In [122]:
rhino_trace = virus_model(rhino_dataset.dropna(subset=['z_score', 'rhinovirus_count']), 
                          'Rhino', 'rhinovirus_count')
100%|██████████| 10000/10000 [00:25<00:00, 395.54it/s]
In [123]:
rhino_labels = viruses+diagnoses+covs
rhino_labels.remove('Rhino')
In [124]:
pm.forestplot(rhino_trace[-1000:], varnames=['β'], ylabels=rhino_labels, main='Rhinovirus load covariates')
Out[124]:
<matplotlib.gridspec.GridSpec at 0x12394c2e8>

HMPV model, with indicators for presence of other viruses.

In [125]:
hmpv_dataset.isnull().sum()
Out[125]:
age_years                  0
male                       0
severity_ind               0
length_of_stay             3
hospitalized_vitamin_d    33
breastfed                  0
any_smokers                0
z_score                    0
birth_weight_ctr           0
RSV                        0
HMPV                       0
Rhino                      0
Influenza                  0
Adeno                      0
PIV                        0
adm_bronchiolitis          0
adm_bronchopneumo          0
adm_pneumo                 0
adm_sepsis                 0
hmpv_count                 1
dtype: int64
In [126]:
hmpv_trace = virus_model(hmpv_dataset.dropna(subset=['hmpv_count']), 
                         'HMPV', 'hmpv_count')
100%|██████████| 10000/10000 [00:17<00:00, 581.15it/s]
In [127]:
hmpv_labels = viruses+diagnoses+covs
hmpv_labels.remove('HMPV')
In [128]:
pm.forestplot(hmpv_trace[-1000:], varnames=['β'], ylabels=hmpv_labels, main='HMPV load covariates')
Out[128]:
<matplotlib.gridspec.GridSpec at 0x1219161d0>

Adenovirus model, with indicators for presence of other viruses.

In [129]:
adeno_dataset.isnull().sum()
Out[129]:
age_years                  0
male                       0
severity_ind               0
length_of_stay             6
hospitalized_vitamin_d    96
breastfed                  0
any_smokers                0
z_score                    1
birth_weight_ctr           0
RSV                        0
HMPV                       0
Rhino                      0
Influenza                  0
Adeno                      0
PIV                        0
adm_bronchiolitis          0
adm_bronchopneumo          0
adm_pneumo                 0
adm_sepsis                 0
adenovirus_count           0
dtype: int64
In [130]:
adeno_trace = virus_model(adeno_dataset.dropna(subset=['z_score']), 
                          'Adeno', 'adenovirus_count')
100%|██████████| 10000/10000 [00:21<00:00, 468.50it/s]
In [131]:
adeno_labels = viruses+diagnoses+covs
adeno_labels.remove('Adeno')
In [132]:
pm.forestplot(adeno_trace[-1000:], varnames=['β'], ylabels=adeno_labels, main='Adenovirus load covariates')
Out[132]:
<matplotlib.gridspec.GridSpec at 0x11da3d278>

Influenza model, with indicators for presence of other viruses.

In [135]:
flu_dataset.isnull().sum()
Out[135]:
age_years                  0
male                       0
severity_ind               0
length_of_stay             0
hospitalized_vitamin_d    12
breastfed                  0
any_smokers                0
z_score                    0
birth_weight_ctr           0
RSV                        0
HMPV                       0
Rhino                      0
Influenza                  0
Adeno                      0
PIV                        0
adm_bronchiolitis          0
adm_bronchopneumo          0
adm_pneumo                 0
adm_sepsis                 0
flu_count                  0
dtype: int64
In [136]:
flu_trace = virus_model(flu_dataset, 
                        'Influenza', 'flu_count')
100%|██████████| 10000/10000 [00:13<00:00, 750.97it/s]
In [137]:
flu_labels = viruses+diagnoses+covs
flu_labels.remove('Influenza')
In [138]:
pm.forestplot(flu_trace[-1000:], varnames=['β'], ylabels=flu_labels, main='Influenza load covariates')
Out[138]:
<matplotlib.gridspec.GridSpec at 0x12390c550>

PIV model, with indicators for presence of other viruses.

In [139]:
piv_dataset.isnull().sum()
Out[139]:
age_years                   0
male                        0
severity_ind                0
length_of_stay              0
hospitalized_vitamin_d     30
breastfed                   0
any_smokers                 0
z_score                     0
birth_weight_ctr            0
RSV                         0
HMPV                        0
Rhino                       0
Influenza                   0
Adeno                       0
PIV                         0
adm_bronchiolitis           0
adm_bronchopneumo           0
adm_pneumo                  0
adm_sepsis                  0
piv1_count                141
piv2_count                162
piv3_count                 47
piv_count                   0
dtype: int64
In [140]:
piv_trace = virus_model(piv_dataset, 'PIV', 'piv_count')
100%|██████████| 10000/10000 [00:11<00:00, 834.07it/s]
In [141]:
piv_labels = viruses+diagnoses+covs
piv_labels.remove('PIV')
In [142]:
pm.forestplot(piv_trace[-1000:], varnames=['β'], ylabels=piv_labels, main='PIV load covariates')
Out[142]:
<matplotlib.gridspec.GridSpec at 0x12255b4e0>