#!/usr/bin/env python
# coding: utf-8

# In[1]:

# Notebook export: get_ipython() is only defined when run under IPython/Jupyter.
get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc3 as pm
import pdb

pm.__version__

# In[2]:

flu = pd.read_csv('data/flu_organism.csv', index_col=0)

# In[3]:

flu.shape

# In[4]:

flu.PatientID.unique().shape

# ### Turn into table of unique patients, with appropriate organism columns
#
# Create indicator for flu-only.

# In[5]:


def _starts_with(names, prefix):
    """Return a boolean Series: True where str(name) starts with `prefix`.

    Values are stringified first, so missing organism names become 'nan'
    and simply do not match.
    """
    return names.astype(str).str.startswith(prefix)


def _flu_with(names, other_prefix):
    """Return True when one patient's organisms include influenza AND an
    organism whose name starts with `other_prefix`.

    Counting matching rows (the previous `len(...) > 1` test) also fired for
    patients with two influenza rows and no co-infecting organism at all.
    """
    return bool(_starts_with(names, 'Influenza').any()
                and _starts_with(names, other_prefix).any())


def _n_patients_with(patient_flag, diagnosis):
    """Count unique patients for whom both a per-patient flag and a row-level
    diagnosis column of `flu` are true.

    `patient_flag` is indexed by PatientID while `diagnosis` shares the row
    index of `flu`; combining them directly with `&` aligns on index labels
    and pairs unrelated entries. The flag is therefore mapped onto rows
    through PatientID first.
    """
    on_rows = flu.PatientID.map(patient_flag).astype(bool)
    both = on_rows & diagnosis.fillna(0).astype(bool)
    return flu.loc[both, 'PatientID'].nunique()


# A patient is flu-only when every organism recorded for them is influenza.
flu_only = flu.groupby('PatientID')['OrganismName'].apply(
    lambda s: _starts_with(s, 'Influenza').all()).astype(bool)

# In[6]:

flu_only.head()

# In[7]:

flu_only.name = 'flu_only'

# In[8]:

flu = flu.merge(pd.DataFrame(flu_only), left_on='PatientID', right_index=True)

# In[13]:

# Influenza together with Staphylococcus aureus, by organism name.
flu_staph_organism = flu.groupby('PatientID')['OrganismName'].apply(
    lambda s: _flu_with(s, 'Staphylococcus aureus')).astype(bool)

# In[14]:

flu_staph_organism.name = 'flu_staph_org'

# In[15]:

flu = flu.merge(pd.DataFrame(flu_staph_organism), left_on='PatientID', right_index=True)

# In[12]:

flu_staph_organism.mean()

# See if there are any flu-only or flu-other individuals with an ICD9
# diagnosis of S. aureus

# In[18]:

# Influenza together with an organism labelled 'Other...'.
flu_other = flu.groupby('PatientID')['OrganismName'].apply(
    lambda s: _flu_with(s, 'Other')).astype(bool)

# In[19]:

flu_other.name = 'flu_other'

# In[20]:

flu = flu.merge(pd.DataFrame(flu_other), left_on='PatientID', right_index=True)

# The original analysis reported no cases of flu-only or flu-other via
# organism that have an S. aureus diagnosis via ICD9.
# NOTE(review): that check combined a PatientID-indexed Series with a
# row-indexed column, so the indexes did not line up; re-confirm the finding
# with the aligned patient counts below.

# In[21]:

_n_patients_with(flu_other, flu.s_aureus_icd9)

# In[22]:

_n_patients_with(flu_only, flu.s_aureus_icd9)

# Same for pneumo

# In[23]:

_n_patients_with(flu_other, flu.pneumo_icd9)

# In[24]:

flu.groupby('PatientID')['OrganismName'].value_counts()

# Create indicators for coinfection type

# In[25]:

# One 0/1 column per distinct value of `Type`, named after that value.
for org in flu.Type.unique():
    flu[org] = (flu.Type == org).astype(int)

# Create data frame of unique patients

# In[26]:

flu_unique = flu.drop_duplicates(subset=['PatientID']).set_index('PatientID')

# In[27]:

flu_unique.s_aureus_icd9.sum()

# In[28]:

# Both sides are indexed by PatientID, so this assignment aligns correctly.
flu_unique['flu_only'] = flu_only

# In[29]:

flu_unique.flu_only.mean()

# In[23]:

# NOTE(review): `flu_staph` is presumably one of the `Type` indicator columns
# created above -- confirm against the data.
flu_unique.flu_staph.mean()

# Several missing values for admission to time on ECMO

# In[30]:

flu_unique.AdmitToTimeOnHours.isnull().mean()

# Since we need this field to calculate event time, we will have to drop
# individuals with missing values from the survival analysis.
# In[31]:

flu_unique = flu_unique.dropna(subset=['AdmitToTimeOnHours', 'HoursECMO'])
assert not flu_unique.AdmitToTimeOnHours.isnull().sum()

# Create variables for use in analysis

# In[45]:

# ECMO Type
flu_unique['VA'] = flu_unique.Mode.isin(['VA', 'VV-VA'])
# Set "Other" type to NA (there are only a couple)
flu_unique.loc[flu_unique.Mode == 'Other', 'VA'] = None

# Create oxygen index

# In[46]:

flu_unique['OI'] = flu_unique.FiO2 * flu_unique.MAP / flu_unique.PO2

# In[47]:

flu_unique.OI.hist(bins=20)

# In[48]:

covariates = ['AgeYears', 'pH', 'OI', 'VA']

# Get counts of missing values

# In[49]:

flu_unique[covariates].isnull().mean()

# In[50]:

flu_unique[covariates].hist(bins=25);

# In[51]:

flu_unique[flu_unique.AgeYears < 1].AgeDays.hist(bins=15)

# In[52]:

flu_unique.flu_staph_org.sum()

# In[53]:

flu_unique.s_aureus_icd9.sum()

# In[55]:

(flu_unique.flu_staph_org | flu_unique.s_aureus_icd9).sum()

# ## Logistic regression model

# `flu_complete` was referenced below but never defined in this script.
# pH and VA have imputation sub-models, whereas age enters the GP kernel
# directly, so rows without an age cannot be used.
# NOTE(review): confirm this matches the subset used in the original analysis.
flu_complete = flu_unique.dropna(subset=['AgeYears'])

# Fate of each patient

# In[35]:

died = (flu_complete.fate == 'Died').astype(int).values

# In[36]:

N_obs = len(died)

# Functions for standardizing and centering

# In[37]:


def center(x):
    """Subtract from `x` the mean of its non-missing (non-NaN) entries."""
    return x - np.nanmean(x)


def standardize(x):
    """Center `x`, then divide by the standard deviation of its non-NaN entries."""
    return center(x) / np.nanstd(x)


# In[38]:

covariates

# In[39]:

# Cast to float first: assigning None into the boolean VA column can leave an
# object-dtype matrix, on which np.isnan raises TypeError.
AgeYears, pH, OI, va = flu_complete[covariates].astype(float).values.T

AgeYears_center = center(AgeYears).round(1)
# Center pH at 7.4
pH_center = pH - 7.4
OI_std = standardize(OI)

# In[63]:

no_coinfection = flu_complete.flu_only.values
with_staph = flu_complete.flu_staph.values

# In[101]:

flu_unique.flu_staph.sum()

# ## Model

# In[40]:

# Values for missing elements, to be replaced in model
OI_std[np.isnan(OI_std)] = -999
pH_center[np.isnan(pH_center)] = -999
va[np.isnan(va)] = -999

# In[90]:

from pymc3 import Model, find_MAP, NUTS, Metropolis, sample, forestplot, traceplot, Slice
from pymc3 import Normal, HalfCauchy, Beta, Bernoulli, MvNormal
from pymc3.distributions.timeseries import GaussianRandomWalk
from numpy.ma import masked_values

# In[56]:


def interpolate(x0, y0, x):
    """Linearly interpolate the values `y0`, given at sorted points `x0`, at `x`.

    Each `x` is bracketed with np.searchsorted and the two neighbouring values
    are blended with weights proportional to the distance to the opposite
    neighbour. Only used by the disabled random-walk age effect below.
    """
    idx = np.searchsorted(x0, x)
    dl = np.array(x - x0[idx - 1])   # distance to the left neighbour
    dr = np.array(x0[idx] - x)       # distance to the right neighbour
    d = dl + dr
    wl = dr / d                      # weight on the left neighbour
    return wl * y0[idx - 1] + (1 - wl) * y0[idx]


# In[60]:

import theano.tensor as T


def invlogit(x):
    """Inverse-logit (logistic) transform built from theano ops."""
    return 1. / (1 + T.exp(-x))


# In[86]:

# Pairwise age differences: age_d[i, j] = AgeYears_center[i] - AgeYears_center[j]
age_d = AgeYears_center[:, None] - AgeYears_center[None, :]

# In[97]:

with Model() as model:

    # Impute missing values.
    # Precisions are passed by keyword: releases of PyMC3 disagree on whether
    # the second positional parameter of Normal/MvNormal is tau or sd/cov.
    μ_pH = Normal('μ_pH', mu=0, tau=0.001, testval=0)
    σ_pH = HalfCauchy('σ_pH', 5, testval=1)
    pH_imputed = Normal('pH_imputed', mu=μ_pH, tau=σ_pH**-2,
                        observed=masked_values(pH_center, value=-999))

    p_va = Beta('p_va', 1, 1)
    va_imputed = Bernoulli('va_imputed', p_va,
                           observed=masked_values(va, value=-999))

    # Non-linear age effect.
    # Disabled alternative (piecewise-linear Gaussian random walk over knots):
    #   nknots = 10
    #   knots = np.linspace(AgeYears_center.min(), AgeYears_center.max(), nknots)
    #   age_sd = HalfCauchy('age_sd', 5, testval=1)
    #   age_knots = GaussianRandomWalk('age_knots', sd=age_sd, shape=nknots)
    #   age_effects = interpolate(knots, age_knots, AgeYears_center)

    # Parameters of kernel function
    η2 = pm.HalfCauchy('η2', 5, testval=1)
    ρ2 = pm.HalfCauchy('ρ2', 5, testval=1)
    σ2 = pm.HalfCauchy('σ2', 5, testval=1)

    # Construct precision matrix: kernel covariance plus σ2 on the diagonal,
    # then inverted.
    M = η2 * T.exp(-ρ2 * age_d**2)
    S = M + T.identity_like(M) * σ2
    τ = T.nlinalg.matrix_inverse(S)

    # GP of age effect
    age_effects = MvNormal('age_effects', mu=np.zeros(N_obs), tau=τ, shape=N_obs)

    μ = Normal('μ', mu=0, tau=0.01, testval=0)

    # Linear covariates
    β = Normal('β', mu=0, tau=0.01, shape=4)

    # Probability of mortality effect
    π = invlogit(μ + β[0]*no_coinfection
                 + β[1]*with_staph
                 + β[2]*pH_imputed
                 + β[3]*va_imputed
                 + age_effects)

    deaths = Bernoulli('deaths', π, observed=died)

# In[98]:

with model:
    start = find_MAP()
    # Sampler for the random-walk variant of the model:
    # step1 = NUTS([β, μ_pH, σ_pH, p_va, age_sd, age_knots, μ], scaling=start)
    step1 = NUTS([β, μ_pH, σ_pH, p_va, η2, ρ2, σ2, μ], scaling=start)
    # The imputed (missing) pH and VA entries get their own Slice sampler.
    step2 = Slice([va_imputed.missing_values, pH_imputed.missing_values])
    trace = sample(5000, [step1, step2], start)

# In[76]:

forestplot(trace, vars=['β'], ylabels=['no coinfection', 'staph coinfection', 'pH', 'VA'])

# In[78]:

# 'age_sd' exists only in the disabled random-walk variant; show the GP
# hyperparameters that this model actually defines.
traceplot(trace, vars=['η2', 'ρ2', 'σ2', 'σ_pH'])

# In[ ]: