#!/usr/bin/env python
# coding: utf-8

# In[1]:


get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
import theano.tensor as tt

# Random seeds for the MCMC sampler.
rseed = 20090425, 19700903


# ## Import Data

# In[2]:


# '-' marks a missing entry in the export.
hospitalized = pd.read_csv('data/hospitalized.csv', index_col=0,
                           low_memory=False, na_values=['-'])
hospitalized.index.is_unique


# In[3]:


hospitalized.shape


# Lookup from PCR result column to the virus (or subtype) it reports.

# In[4]:


pcr_lookup = {'pcr_result___1': 'RSV',
              'pcr_result___2': 'HMPV',
              'pcr_result___3': 'flu A',
              'pcr_result___4': 'flu B',
              'pcr_result___5': 'rhino',
              'pcr_result___6': 'PIV1',
              'pcr_result___7': 'PIV2',
              'pcr_result___8': 'PIV3',
              'pcr_result___13': 'H1N1',
              'pcr_result___14': 'H3N2',
              'pcr_result___15': 'Swine',
              'pcr_result___16': 'Swine H1',
              'pcr_result___17': 'flu C',
              'pcr_result___18': 'Adeno'}


# In[5]:


# Boolean "positive for virus" indicators. Influenza and PIV pool several
# PCR columns. NOTE(review): assumes pcr_result___* hold 0/1 flags.
hospitalized['RSV'] = hospitalized.pcr_result___1.astype(bool)
hospitalized['HMPV'] = hospitalized.pcr_result___2.astype(bool)
hospitalized['Rhino'] = hospitalized.pcr_result___5.astype(bool)
hospitalized['Influenza'] = (hospitalized.pcr_result___3
                             | hospitalized.pcr_result___4).astype(bool)
hospitalized['Adeno'] = hospitalized.pcr_result___18.astype(bool)
hospitalized['PIV'] = (hospitalized.pcr_result___6
                       | hospitalized.pcr_result___7
                       | hospitalized.pcr_result___8).astype(bool)
# True when none of the PCR columns in the lookup is positive.
hospitalized['No virus'] = hospitalized[list(pcr_lookup.keys())].sum(1) == 0
hospitalized['All'] = True


# In[6]:


viruses = ['RSV', 'HMPV', 'Rhino', 'Influenza', 'Adeno', 'PIV']


# In[7]:


def describe_by_var(var, round_to=1):
    """Summarize column *var* separately for patients positive for each virus.

    Returns a DataFrame with one row per entry of ``viruses`` (in that order)
    holding ``describe()`` statistics rounded to *round_to* decimals.
    """
    return (pd.concat({v: hospitalized.loc[hospitalized[v], var]
                       .describe()
                       .round(round_to) for v in viruses}, axis=1)[viruses].T)


# In[8]:


def normalize(x):
    """Centre *x* on its mean and scale by its standard deviation.

    Uses the argument's own ``mean``/``std`` methods, so it works for pandas
    objects (ddof=1, NaN skipped) as well as symbolic tensors.
    """
    return (x - x.mean()) / x.std()


# In[100]:


# Missing sex / gestational age fall into the 0 category of these flags.
hospitalized['male'] = (hospitalized.sex == 'M').astype(int)
hospitalized['premature'] = (hospitalized.gest_age < 37).astype(int)
hospitalized['age_years'] = hospitalized.age_months / 12
hospitalized['birth_weight_ctr'] = (hospitalized.birth_wt_child
                                    - hospitalized.birth_wt_child.mean())


# In[101]:


# 1 if the patient needed oxygen, ventilation or ICU care.
hospitalized['severity_ind'] = (
    hospitalized[['oxygen', 'vent', 'icu']].sum(1) > 0).astype(int)


# In[102]:


# 'QNS' (quantity not sufficient) becomes missing; '<1' is recorded as 0.
hospitalized['hospitalized_vitamin_d'] = (
    hospitalized.hosp_vitd.replace({'QNS': np.nan, '<1': 0}).astype(float))


# Diagnoses

# In[103]:


diagnoses = [
    'adm_bronchiolitis',
    'adm_bronchopneumo',
    'adm_pneumo',
    'adm_sepsis'
]


# In[104]:


# More than one positive among RSV, HMPV, Rhino and Influenza.
# NOTE(review): viruses[:-2] leaves out Adeno and PIV -- confirm intended.
hospitalized['any_coinf'] = hospitalized[viruses[:-2]].sum(1) > 1
hospitalized.any_coinf.mean()


# In[105]:


hospitalized['any_smokers'] = hospitalized[['cigarette_smokers',
                                            'nargila_smokers']].sum(1)


# Covariates (predictors) for the viral-load models. Gestational age is left
# out because it did not work in the model; the reason is not yet known.

# In[106]:


covs = ['age_years', 'male', 'severity_ind', 'length_of_stay',
        'hospitalized_vitamin_d', 'breastfed', 'any_smokers', 'z_score',
        'birth_weight_ctr']


# In[107]:


analysis_columns = covs + viruses + diagnoses


# In[134]:


rsv_dataset = hospitalized.loc[hospitalized.RSV,
                               analysis_columns + ['rsv_count']]
rhino_dataset = hospitalized.loc[hospitalized.Rhino,
                                 analysis_columns + ['rhinovirus_count']]
adeno_dataset = hospitalized.loc[hospitalized.Adeno,
                                 analysis_columns + ['adenovirus_count']]
hmpv_dataset = hospitalized.loc[hospitalized.HMPV,
                                analysis_columns + ['hmpv_count']]

# Combined counts: min_count=1 keeps a row with no subtype count as NaN (rather
# than a spurious CT of 0) so it is dropped as a missing outcome.
# NOTE(review): if a patient has several subtype counts they are summed --
# confirm that is intended for CT values.
flu_dataset = hospitalized.loc[
    hospitalized.Influenza,
    analysis_columns + ['influenza_a_count', 'influenza_b_count',
                        'influenza_c_count']].copy()
flu_dataset['flu_count'] = flu_dataset[
    ['influenza_a_count', 'influenza_b_count',
     'influenza_c_count']].sum(axis=1, min_count=1)
flu_dataset = flu_dataset.drop(['influenza_a_count', 'influenza_b_count',
                                'influenza_c_count'], axis=1)

piv_dataset = hospitalized.loc[
    hospitalized.PIV,
    analysis_columns + ['piv1_count', 'piv2_count', 'piv3_count']].copy()
piv_dataset['piv_count'] = piv_dataset[
    ['piv1_count', 'piv2_count', 'piv3_count']].sum(axis=1, min_count=1)


# In[109]:


rsv_dataset.hospitalized_vitamin_d.hist(bins=20)


# In[110]:


rsv_dataset.plot(y='age_years', x='hospitalized_vitamin_d', kind='scatter',
                 alpha=0.3)


# In[111]:
rsv_dataset.isnull().sum()


# Drop patients missing z-score or birth weight.

# In[114]:


rsv_dataset = rsv_dataset.dropna(subset=['z_score', 'birth_weight_ctr'])


# ## Specify Models
# RSV model, with indicators for presence of other viruses.

# In[115]:


def virus_model(dataset, virus_name, output_column, iterations=10000):
    """Fit a Bayesian linear regression of one virus's mean-centred count.

    Predictors are indicators for the other viruses, the admission diagnoses
    and ``covs``. Missing ``hospitalized_vitamin_d`` values are imputed with a
    two-component normal mixture (on the standardized exp(vitamin D) scale).
    Missing ``length_of_stay`` values are imputed with a lognormal. Every other
    predictor must be observed, so rows missing one of them are dropped, as
    are rows missing the outcome.

    The coefficient vector ``β`` follows the column order
    ``viruses + diagnoses + covs`` with *virus_name* removed. This is the
    same order as the label lists used for the forest plots.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Patient rows; must contain all predictor columns and *output_column*.
    virus_name : str
        Entry of ``viruses`` being modelled; it is excluded from the predictors.
    output_column : str
        Column holding the count to model.
    iterations : int
        Number of samples. Below 10000 uses NUTS (target_accept=0.99);
        otherwise Metropolis over β, μ, σ and any imputation variables.

    Returns
    -------
    The trace returned by ``pm.sample``.

    Raises
    ------
    ValueError
        If *virus_name* is not in ``viruses`` or no complete rows remain.
    """
    use_nuts = iterations < 10000

    other_viruses = viruses[:]
    other_viruses.remove(virus_name)
    predictors = other_viruses + diagnoses + covs
    imputed = ['length_of_stay', 'hospitalized_vitamin_d']

    # A NaN in any non-imputed predictor would make the linear predictor NaN,
    # so such rows are dropped together with rows missing the outcome.
    required = [output_column] + [c for c in predictors if c not in imputed]
    dataset = dataset.dropna(subset=required)
    if dataset.empty:
        raise ValueError('No complete rows left for %s after dropping missing '
                         'values' % virus_name)

    X = dataset[predictors].astype(float)
    k = X.shape[1]

    virus_ct = dataset[output_column].values
    virus_centered = virus_ct - virus_ct.mean()

    # Positions in β (= positions in `predictors`) of the two imputed
    # predictors and of all the fully observed ones.
    los_idx = predictors.index('length_of_stay')
    vit_d_idx = predictors.index('hospitalized_vitamin_d')
    observed_idx = np.array([i for i in range(k)
                             if i not in (los_idx, vit_d_idx)], dtype=int)
    # drop() keeps the remaining columns in `predictors` order,
    # matching observed_idx.
    X_observed = X.drop(imputed, axis=1).values

    # NOTE(review): vitamin D is exponentiated before standardizing. If
    # hosp_vitd is a raw concentration, exp() lets the largest values dominate
    # (or overflow) -- confirm this is intended; a log may have been meant.
    exp_vitamin_D = np.exp(X['hospitalized_vitamin_d'])
    vitamin_D_norm = normalize(exp_vitamin_D)
    los = X['length_of_stay']
    vitamin_D_missing = bool(vitamin_D_norm.isnull().any())
    los_missing = bool(los.isnull().any())

    with pm.Model():
        # Impute missing vitamin D with a 2-component normal mixture.
        # A masked observed array makes PyMC3 add a
        # 'vitamin_D_imp_missing' free variable.
        if vitamin_D_missing:
            p_high_D = pm.Dirichlet('p_high_D', np.ones(2))
            ϕ = pm.Normal('ϕ', 0, sd=1000, shape=2)
            τ = pm.Uniform('τ', 0, 100, shape=2)
            vitamin_D_imp = pm.NormalMixture(
                'vitamin_D_imp', p_high_D, ϕ, sd=τ,
                observed=np.ma.masked_invalid(vitamin_D_norm.values))
        else:
            vitamin_D_imp = vitamin_D_norm

        # Impute missing length of stay with a lognormal.
        if los_missing:
            ψ = pm.Normal('ψ', 0, sd=1000)
            η = pm.Uniform('η', 0, 100)
            los_imp = pm.Lognormal('los_imp', ψ, η,
                                   observed=np.ma.masked_invalid(los.values))
        else:
            los_imp = los

        los_norm = normalize(los_imp)

        # Predictor coefficients, in `predictors` order.
        β = pm.Normal('β', 0, sd=10000, shape=k)

        # Mean
        μ = pm.Normal('μ', 0, sd=10000)

        θ = (μ
             + tt.dot(X_observed, β[observed_idx])
             + β[los_idx] * los_norm
             + β[vit_d_idx] * vitamin_D_imp)

        # Data likelihood
        σ = pm.HalfCauchy('σ', 5)
        pm.Normal('likelihood', θ, sd=σ, observed=virus_centered)

        if use_nuts:
            step = pm.NUTS(target_accept=0.99)
        else:
            metropolis_vars = [β, μ, σ]
            if vitamin_D_missing:
                metropolis_vars += [vitamin_D_imp, p_high_D, ϕ, τ]
            if los_missing:
                metropolis_vars += [los_imp, ψ, η]
            step = pm.Metropolis(vars=metropolis_vars)

        trace = pm.sample(iterations, step=step, random_seed=rseed[0])

    return trace


# In[116]:


rsv_trace = virus_model(rsv_dataset, 'RSV', 'rsv_count')


# Estimates of the regression coefficients (β). Positive values are associated
# with a higher CT response.
#
# Observations recorded from an earlier run (re-check after re-running):
#
# - HMPV coinfection associated with a 5-10 point higher CT
# - Rhinovirus coinfection associated with a 0.5-1.5 point higher CT
# - each year of age beyond newborn associated with a 1-2 point higher CT

# In[117]:


rsv_labels = viruses + diagnoses + covs
rsv_labels.remove('RSV')


# In[118]:


pm.forestplot(rsv_trace[-1000:], varnames=['β'], ylabels=rsv_labels,
              main='RSV load covariates')


# In[119]:


pm.traceplot(rsv_trace[-1000:], varnames=['p_high_D'])


# Imputed vitamin D values (standardized scale).

# In[120]:


pm.forestplot(rsv_trace[-1000:], varnames=['vitamin_D_imp_missing'])


# Rhinovirus model, with indicators for presence of other viruses.

# In[121]:


rhino_dataset.isnull().sum()


# In[122]:


rhino_trace = virus_model(
    rhino_dataset.dropna(subset=['z_score', 'rhinovirus_count']),
    'Rhino', 'rhinovirus_count')


# In[123]:


rhino_labels = viruses + diagnoses + covs
rhino_labels.remove('Rhino')


# In[124]:


pm.forestplot(rhino_trace[-1000:], varnames=['β'], ylabels=rhino_labels,
              main='Rhinovirus load covariates')


# HMPV model, with indicators for presence of other viruses.

# In[125]:


hmpv_dataset.isnull().sum()


# In[126]:


hmpv_trace = virus_model(hmpv_dataset.dropna(subset=['hmpv_count']), 'HMPV',
                         'hmpv_count')


# In[127]:


hmpv_labels = viruses + diagnoses + covs
hmpv_labels.remove('HMPV')


# In[128]:


pm.forestplot(hmpv_trace[-1000:], varnames=['β'], ylabels=hmpv_labels,
              main='HMPV load covariates')


# Adenovirus model, with indicators for presence of other viruses.
# In[129]-[142]: the adenovirus, influenza and PIV models all follow the same
# steps -- inspect missingness, fit, build coefficient labels, forest-plot --
# so the fit/label/plot part is shared through one helper.


def _fit_and_plot_virus(data, name, count_column, title):
    """Fit ``virus_model`` for one virus and forest-plot its ``β`` estimates.

    The plot uses the last 1000 samples. The labels are
    ``viruses + diagnoses + covs`` with *name* removed. Returns
    ``(trace, labels)``.
    """
    fitted = virus_model(data, name, count_column)
    coef_labels = viruses + diagnoses + covs
    coef_labels.remove(name)
    pm.forestplot(fitted[-1000:], varnames=['β'], ylabels=coef_labels,
                  main=title)
    return fitted, coef_labels


# Adenovirus (patients missing z-score excluded).
adeno_dataset.isnull().sum()
adeno_trace, adeno_labels = _fit_and_plot_virus(
    adeno_dataset.dropna(subset=['z_score']), 'Adeno', 'adenovirus_count',
    'Adenovirus load covariates')


# Influenza model, with indicators for presence of other viruses.
flu_dataset.isnull().sum()
flu_trace, flu_labels = _fit_and_plot_virus(
    flu_dataset, 'Influenza', 'flu_count', 'Influenza load covariates')


# PIV model, with indicators for presence of other viruses.
piv_dataset.isnull().sum()
piv_trace, piv_labels = _fit_and_plot_virus(
    piv_dataset, 'PIV', 'piv_count', 'PIV load covariates')