In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc as pm
import pdb

pm.__version__
Out[1]:
'2.3.5'
In [2]:
flu = pd.read_csv('data/flu_organism.csv', index_col=0)
In [3]:
flu.shape
Out[3]:
(2937, 64)
In [4]:
flu.PatientID.unique().shape
Out[4]:
(1654,)

Turn into table of unique patients, with appropriate organism columns

Create indicator for flu-only.

In [5]:
flu_only = flu.groupby('PatientID')['OrganismName'].apply(lambda s: 
                        len([x for x in s if str(x).startswith('Influenza')])==len(s)).astype(bool)
In [6]:
flu_only.head()
Out[6]:
PatientID
001B0EC6-66F5-48C9-8516-52D2ABFE28AA    False
0034181F-59AB-4B2C-A06A-A83EE1DF1A17     True
004AE387-1096-410D-A7CB-F9A9086AA385    False
0084DCF0-8E3C-4F1C-8650-BE795A57EF2F    False
0084E706-90D0-4314-8C99-496171AB639D     True
Name: OrganismName, dtype: bool
In [7]:
flu_only.name = 'flu_only'
In [8]:
flu = flu.merge(pd.DataFrame(flu_only), left_on='PatientID', right_index=True)
In [9]:
# flu_staph_organism = flu.groupby('PatientID')['OrganismName'].apply(lambda s: 
#             len([x for x in s if str(x).startswith('Influenza') or 
#                  str(x).startswith('Staphylococcus aureus')])>1).astype(bool)
In [10]:
# flu_staph_organism.name = 'flu_staph_org'
In [11]:
# flu = flu.merge(pd.DataFrame(flu_staph_organism), left_on='PatientID', right_index=True)

See if there are any flu-only or flu-other individuals with an ICD9 diagnosis of S. Aureus

In [12]:
flu_other = flu.groupby('PatientID')['OrganismName'].apply(lambda s: 
            len([x for x in s if str(x).startswith('Influenza') or 
                 str(x).startswith('Other')])>1).astype(bool)
In [13]:
flu_other.name = 'flu_other'
In [14]:
flu = flu.merge(pd.DataFrame(flu_other), left_on='PatientID', right_index=True)

There do not appear to be cases of flu-only or flu-other via organism that have a S.Aureus diagnosis via ICD9

In [15]:
(flu_other & flu.s_aureus_icd9).sum()
Out[15]:
0
In [16]:
(flu_only & flu.s_aureus_icd9).sum()
Out[16]:
0

Same for pneumo

In [17]:
(flu_other & flu.pneumo_icd9).sum()
Out[17]:
0
In [18]:
flu.groupby('PatientID')['OrganismName'].value_counts()
Out[18]:
PatientID                                                                       
001B0EC6-66F5-48C9-8516-52D2ABFE28AA  Enterococcus                                  1
                                      Klebsiella pneumoniae                         1
0034181F-59AB-4B2C-A06A-A83EE1DF1A17  Influenza A                                   1
004AE387-1096-410D-A7CB-F9A9086AA385  Influenza A                                   1
                                      Pseudomonas aerugenosa                        1
                                      Candida albicans                              1
0084DCF0-8E3C-4F1C-8650-BE795A57EF2F  Influenza A                                   1
                                      Staphylococcus aureus                         1
0084E706-90D0-4314-8C99-496171AB639D  Influenza B                                   1
00EFD11D-650A-4317-8C4F-4B7378AD7946  Other                                         2
0118B486-8ACE-4BFF-8ED6-753D96E6B441  Influenza A                                   1
011B5A5D-9967-44A7-92F9-02F82D31216D  Candida albicans                              2
                                      Influenza A                                   1
                                      Epstein-Barr virus (EBV)                      1
01AC9644-7CDF-4E28-A5F4-FE5B7F9FA22A  Streptococcus pneumoniae                      1
                                      Staphylococcus, coag neg                      1
                                      Influenza A                                   1
                                      Moraxella catarrhalis (Branhamella)           1
01F84F52-92D3-403B-B8AC-21A72A4BC771  Influenza A                                   1
02373407-74E6-4A1E-9629-F4DCBF707B0C  Influenza A                                   1
025D87C9-CF70-4B9A-94BE-B15EA2185472  Eschericha coli                               2
                                      Influenza A                                   1
                                      Enterococcus                                  1
02A94D2E-768A-4ABB-86BF-C8A188D698CC  Influenza A                                   1
02C9FF5A-A4D6-461A-A650-E3EB405D62AE  Influenza A                                   2
02D33150-FDDB-4F12-AAAB-D6F4B8765EBF  Influenza A                                   1
031C1231-B585-478F-A98F-35E017B788CD  Other                                         2
035A053E-91B1-4B68-87D4-F92C675FE6E3  Other                                         2
03AA4AD0-1B0E-4CAD-BD76-03FE52AA580D  Influenza A                                   1
                                      Epstein-Barr virus (EBV)                      1
                                                                                   ..
FC9F84B4-B2CD-4C28-B824-D5CE1AF7BEF1  Gram positive, other                          2
                                      Yeast sp.                                     1
                                      Aspergillus fumigatus                         1
                                      Pseudomonas fluorescens                       1
                                      Enterococcus                                  1
                                      Candida albicans                              1
                                      Stenotrophomonas maltophilia (Xanthomonas)    1
                                      Other                                         1
FCBAEC19-EFA1-4A95-A98D-1D8FE65B53CE  Influenza A                                   1
FCD795A0-9FD1-4C20-961B-04BEF1B23E96  Influenza A                                   1
                                      Staphylococcus aureus                         1
FCDF1967-C307-4FF1-8CF1-7AA02FE0BE47  Influenza A                                   1
FD2FA339-7417-4F13-811B-2E1873E529C7  Influenza A                                   1
FDC32D1D-0D67-4752-A2E1-72E790FCB797  Influenza A                                   1
                                      Pseudomonas aerugenosa                        1
FDED68D1-1DFF-4899-9CF9-797E6AE70D18  Influenza A                                   1
FDFF23E3-C0CB-468D-98D9-998192CA7590  Influenza A                                   1
                                      Yeast sp.                                     1
                                      Eschericha coli                               1
                                      Candida albicans                              1
FEA94F00-083A-479C-B862-05274A9842B5  Influenza A                                   1
                                      Streptococcus, group A                        1
                                      Staphylococcus aureus                         1
FF0D88BF-E131-4E56-9E78-CC6B60E461FD  Influenza A                                   1
                                      Yeast sp.                                     1
                                      Staphylococcus, coag neg                      1
FF0FB831-9804-4506-AE0B-B447AB54D4D5  Hemophilus influenza                          1
FF528500-5274-4225-96C2-CF1153EE3AB0  Other                                         1
FF6CD7E8-8956-4C15-AF6C-3033A38C48D7  Influenza A                                   1
                                      Staphylococcus epidermidis                    1
dtype: int64

Create indictors for coinfection type

In [19]:
for org in flu.Type.unique():
    flu[org] = (flu.Type==org).astype(int)

Create data frame of unique patients

In [20]:
flu_unique = flu.drop_duplicates(subset=['PatientID']).set_index('PatientID')
In [21]:
flu_unique.s_aureus_icd9.sum()
Out[21]:
100
In [22]:
flu_unique['flu_only'] = flu_only
In [23]:
flu_unique.flu_only.mean()
Out[23]:
0.34159613059250304

Several missing values for admission to time on ECMO

In [24]:
flu_unique.AdmitToTimeOnHours.isnull().mean()
Out[24]:
0.16686819830713423

Create variables for use in analysis

In [25]:
# ECMO Type
flu_unique['VA'] = flu_unique.Mode.isin(['VA', 'VV-VA'])
# Set "Other" type to NA (there are only a couple)
flu_unique.loc[flu_unique.Mode=='Other', 'VA'] = None

Create oxygen index

In [26]:
flu_unique['OI'] = flu_unique.FiO2 * flu_unique.MAP / flu_unique.PO2
In [27]:
flu_unique.OI.hist(bins=20)
Out[27]:
<matplotlib.axes._subplots.AxesSubplot at 0x10dbac6a0>
In [28]:
covariates = ['AgeYears', 'pH', 'PO2', 'VA']

Get counts of missing values

In [29]:
flu_unique[covariates].isnull().mean()
Out[29]:
AgeYears    0.000000
pH          0.107618
PO2         0.091898
VA          0.010883
dtype: float64
In [30]:
flu_unique[covariates].hist(bins=25);
In [31]:
flu_unique[flu_unique.AgeYears< 1].AgeDays.hist(bins=15)
Out[31]:
<matplotlib.axes._subplots.AxesSubplot at 0x10dbc1e80>
In [32]:
flu_unique.s_aureus_icd9.sum()
Out[32]:
100
In [33]:
(flu_unique.flu_staph_org | flu_unique.s_aureus_icd9).sum()
Out[33]:
244
In [34]:
foo = '01AC9644-7CDF-4E28-A5F4-FE5B7F9FA22A'
flu_unique.ix[foo][['flu_staph_org', 's_aureus_icd9']]
Out[34]:
flu_staph_org    False
s_aureus_icd9     True
Name: 01AC9644-7CDF-4E28-A5F4-FE5B7F9FA22A, dtype: object
In [35]:
flu_unique['flu_staph'] = (flu_unique.flu_staph_org | flu_unique.s_aureus_icd9).astype(int)
In [36]:
flu_unique.to_csv('Data/flu_unique.csv')

Logistic regression model

Fate of each patient

In [37]:
died = (flu_unique.fate=='Died').astype(int).values
In [38]:
N_obs = len(died)

Functions for standarizing and centering

In [39]:
center = lambda x: (x - x[np.isnan(x) ^ True].mean())
standardize = lambda x: center(x) / x[np.isnan(x) ^ True].std()
In [40]:
covariates
Out[40]:
['AgeYears', 'pH', 'PO2', 'VA']
In [41]:
AgeYears, pH, PO2, va = flu_unique[covariates].values.T

AgeYears_center = center(AgeYears).round(1)

# Center pH at 7.4
pH_center = pH - 7.4

# OI_std = standardize(OI)
# PO2_std = standardize(PO2)
In [42]:
no_coinfection = flu_unique.flu_only.values
with_staph = flu_unique.flu_staph.values
In [43]:
flu_unique.flu_staph.sum()
Out[43]:
244

Model

In [44]:
from pymc import MCMC, Normal, Poisson, Uniform, Lambda, Laplace, Exponential, Beta, HalfCauchy, Uninformative
from pymc import Bernoulli, Lognormal, observed, Matplot, deterministic, Gamma, stochastic
from pymc import AdaptiveMetropolis, normal_like, rnormal, potential, invlogit
from pymc.gp import *
In [45]:
# Values for missing elements, to be replaced in model
# OI_std[np.isnan(OI_std)] = 0
pH_center[np.isnan(pH_center)] = 7.111
va[np.isnan(va)] = 0.5
PO2[np.isnan(PO2)] = 20.111
In [46]:
def spline(name, x, knots, smoothing, interpolation_method='linear', pred_points=10):
    """ Generate PyMC objects for a spline model of age-specific rate

    Parameters
    ----------
    name : str
    knots : array
    x : array, points to interpolate to
    smoothing : pymc.Node, smoothness parameter for smoothing spline
    interpolation_method : str, optional, one of 'linear', 'nearest', 'zero', 'slinear', 'quadratic, 'cubic'

    Results
    -------
    Returns dict of PyMC objects, including 'gamma' (log of rate at
    knots) and 'mu_age' (age-specific rate interpolated at all age
    points)
    """
    assert np.all(np.diff(knots) > 0), 'Spline knots must be strictly increasing'

    # TODO: consider changing this prior distribution to be something more familiar in linear space
    gamma = [Normal('gamma_%s_%d'%(name,k), 0., 10.**-2, value=0.) for k in knots]
    #gamma = [mc.Uniform('gamma_%s_%d'%(name,k), -20., 20., value=-10.) for k in knots]

    import scipy.interpolate
    @deterministic(name='mu_x_%s'%name)
    def mu_x(gamma=gamma, knots=knots, x=x):
        mu = scipy.interpolate.interp1d(knots, np.exp(gamma), kind=interpolation_method, 
                                        bounds_error=False, fill_value=0.)
        return mu(x)
    
    @deterministic(name='pred_%s'%name)
    def pred(gamma=gamma, knots=knots):
        mu = scipy.interpolate.interp1d(knots, np.exp(gamma), kind=interpolation_method, 
                                        bounds_error=False, fill_value=0.)
        return mu(np.linspace(x.min(), x.max(), pred_points))

    vars = dict(gamma=gamma, mu_x=mu_x, pred=pred, x=x, knots=knots)

    if (smoothing > 0) and (smoothing < 1e10):
        #print 'adding smoothing of', smoothing
        @potential(name='smooth_mu_%s'%name)
        def smooth_gamma(gamma=gamma, knots=knots, tau=smoothing**-2):
            # the following is to include a "noise floor" so that level value
            # zero prior does not exert undue influence on age pattern
            # smoothing
            # TODO: consider changing this to an offset log normal
            gamma = np.clip(gamma, np.log(np.exp(gamma).mean()/10.), np.inf)  # only include smoothing on values within 10x of mean

            return normal_like(np.sqrt(np.sum(np.diff(gamma)**2 / np.diff(knots))), 0, tau)
        vars['smooth_gamma'] = smooth_gamma

    return vars
In [47]:
np.mean(PO2)
Out[47]:
58.167999999999999
In [48]:
plt.hist(PO2-80, bins=20)
Out[48]:
(array([ 500.,  886.,  182.,   25.,   16.,   11.,    9.,    5.,    5.,
           2.,    3.,    1.,    1.,    0.,    3.,    0.,    2.,    2.,
           0.,    1.]),
 array([ -76.  ,  -40.85,   -5.7 ,   29.45,   64.6 ,   99.75,  134.9 ,
         170.05,  205.2 ,  240.35,  275.5 ,  310.65,  345.8 ,  380.95,
         416.1 ,  451.25,  486.4 ,  521.55,  556.7 ,  591.85,  627.  ]),
 <a list of 20 Patch objects>)
In [59]:
def survival_model(age_mask, n_knots=5):

    # Imputation of missing values
    p_va = Beta('p_VA', 1, 1, value=0.5)
    va_masked = np.ma.masked_values(va[age_mask], value=0.5)
    x_va = Bernoulli('x_va', p_va, value=va_masked, observed=True)
    
    mu_pH = Normal('mu_pH', 0, 0.0001, value=7)
    sigma_pH = Uniform('sigma_pH', 0, 500, value=10)
    tau_pH = sigma_pH**-2
    pH_masked = np.ma.masked_values(pH_center[age_mask], value=7.111)
    x_pH = Normal('x_pH', mu_pH, tau_pH, value=pH_masked, observed=True)
    
    alpha_PO2 = Exponential('alpha_PO2', 100, value=0.5)
    beta_PO2 = Exponential('beta_PO2', 100, value=200)
    PO2_masked = np.ma.masked_values(PO2[age_mask], value=20.111)
    x_PO2 = Gamma('x_PO2', alpha_PO2, beta_PO2, value=PO2_masked, observed=True)
    x_PO2_std = Lambda('x_PO2_std', lambda x=x_PO2: (x - 80)/np.std(x))

    X = [x_pH, x_PO2_std, x_va, flu_unique.flu_only.values[age_mask], flu_unique.flu_staph.values[age_mask]]
        
    # Intercept for survival rate
    beta0 = Normal('beta0', 0.0, 0.001, value=0)
    # Covariates
    beta = Normal('beta', 0, 0.001, value=np.zeros(len(X)))
    
    odds = Lambda('odds', lambda b=beta: np.exp(b))

    if AgeYears[age_mask].max()>1:
        age_knots = np.linspace(AgeYears_center[age_mask].min(), AgeYears_center[age_mask].max(), n_knots)
        α_age = Exponential('α_age', 1, value=0.1)
        spline_age = spline('spline_age', AgeYears_center[age_mask], age_knots, α_age)

        # Event rates
        @deterministic
        def π(b0=beta0, b=beta, x=X, γ=spline_age['mu_x']):
            return invlogit(b0 + np.dot(np.transpose(x), b) + γ)
    
    else:
        # Event rates
        @deterministic
        def π(b0=beta0, b=beta, x=X):
            return invlogit(b0 + np.dot(np.transpose(x), b))
    
    deaths = Bernoulli('deaths', π, value=died[age_mask], observed=True)
    
    return locals()

Adults model

In [60]:
adult_mask = AgeYears>=18
In [61]:
M_adult = MCMC(survival_model(adult_mask))
In [62]:
M_adult.sample(100000, 90000)
 [-----------------100%-----------------] 100000 of 100000 complete in 332.3 sec
In [63]:
cov_labels = ['pH', 'PO2', 'VA', 'flu only', 'flu + staph']
In [64]:
Matplot.summary_plot(M_adult.odds, custom_labels=cov_labels, hpd=False, vline_pos=1)
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.
In [65]:
M_adult.odds.summary()
odds:
 
	Mean             SD               MC Error        95% HPD interval
	------------------------------------------------------------------
	0.086            0.038            0.003            [ 0.029  0.162]
	0.927            0.069            0.003            [ 0.795  1.065]
	2.825            0.588            0.036            [ 1.795  4.004]
	1.002            0.151            0.008            [ 0.722  1.295]
	1.069            0.22             0.012            [ 0.625  1.472]
	
	
	Posterior quantiles:
	
	2.5             25              50              75             97.5
	 |---------------|===============|===============|---------------|
	0.034            0.056           0.078          0.106         0.18
	0.789            0.88            0.927          0.974         1.06
	1.894            2.4             2.759          3.173         4.191
	0.734            0.896           0.993          1.1           1.322
	0.68             0.916           1.05           1.206         1.55
	
In [66]:
Matplot.summary_plot(M_adult.__dict__['spline_age']['pred'], 
                     custom_labels=np.linspace(AgeYears[adult_mask].min(), AgeYears[adult_mask].max(), 10).round(1).astype(str),
                    xlab='Effect size', main='Age effect')
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.

Child model

In [67]:
child_mask = (AgeYears>=1) & (AgeYears<18)

M_child = MCMC(survival_model(child_mask))
M_child.sample(100000, 90000)
 [-----------------100%-----------------] 100000 of 100000 complete in 253.8 sec
In [68]:
Matplot.summary_plot(M_child.odds, custom_labels=cov_labels, 
                     hpd=False, vline_pos=1)
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.
In [69]:
Matplot.summary_plot(M_child.__dict__['spline_age']['pred'], 
                     custom_labels=np.linspace(AgeYears[child_mask].min(), AgeYears[child_mask].max(), 10).round(1).astype(str),
                    xlab='Effect size', main='Age effect')
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.
In [70]:
M_child.odds.summary()
odds:
 
	Mean             SD               MC Error        95% HPD interval
	------------------------------------------------------------------
	0.554            0.369            0.029            [ 0.068  1.244]
	0.857            0.108            0.004            [ 0.653  1.075]
	1.525            0.331            0.018            [ 0.888  2.138]
	0.947            0.226            0.011            [ 0.549  1.435]
	2.513            0.814            0.053            [ 1.173  4.082]
	
	
	Posterior quantiles:
	
	2.5             25              50              75             97.5
	 |---------------|===============|===============|---------------|
	0.152            0.319           0.458          0.687         1.548
	0.653            0.784           0.855          0.927         1.078
	0.997            1.3             1.485          1.715         2.33
	0.57             0.788           0.922          1.077         1.476
	1.311            1.956           2.378          2.904         4.526
	

Infant model

In [71]:
infant_mask = (AgeYears<1)

M_infant = MCMC(survival_model(infant_mask))
M_infant.sample(100000, 90000)
 [-----------------100%-----------------] 100000 of 100000 complete in 57.5 sec
In [72]:
Matplot.summary_plot(M_infant.odds, custom_labels=cov_labels, hpd=False, vline_pos=1)
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.
In [73]:
M_infant.odds.summary()
odds:
 
	Mean             SD               MC Error        95% HPD interval
	------------------------------------------------------------------
	1.344            2.017            0.134            [ 0.034  4.176]
	1.128            0.246            0.013            [ 0.67   1.624]
	3.313            1.821            0.133            [ 0.833  6.437]
	1.599            0.63             0.034            [ 0.606  2.866]
	5.411            4.469            0.333          [  0.509  14.367]
	
	
	Posterior quantiles:
	
	2.5             25              50              75             97.5
	 |---------------|===============|===============|---------------|
	0.093            0.409           0.784          1.461         6.261
	0.725            0.959           1.092          1.262         1.719
	1.247            2.144           2.837          4.017         7.654
	0.739            1.153           1.486          1.871         3.276
	1.169            2.625           4.044          6.681         19.302