In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc as pm
import pdb

pm.__version__
Out[1]:
'2.3.5'
In [2]:
flu = pd.read_csv('data/flu_organism.csv', index_col=0)
In [3]:
flu.shape
Out[3]:
(1450, 61)
In [4]:
flu.PatientID.unique().shape
Out[4]:
(922,)

Turn into table of unique patients, with appropriate organism columns

Create indicator for flu-only.

In [5]:
flu_only = flu.groupby('PatientID')['OrganismName'].apply(lambda s: 
                        len([x for x in s if str(x).startswith('Influenza')])==len(s)).astype(bool)
In [6]:
flu_only.head()
Out[6]:
PatientID
0034181F-59AB-4B2C-A06A-A83EE1DF1A17     True
0084E706-90D0-4314-8C99-496171AB639D     True
00EFD11D-650A-4317-8C4F-4B7378AD7946    False
0118B486-8ACE-4BFF-8ED6-753D96E6B441     True
014A4039-B817-4E96-9FB0-7034F341505B    False
Name: OrganismName, dtype: bool
In [7]:
flu_only.name = 'flu_only'
In [8]:
flu = flu.merge(pd.DataFrame(flu_only), left_on='PatientID', right_index=True)
In [9]:
flu_staph_organism = flu.groupby('PatientID')['OrganismName'].apply(lambda s: 
            len([x for x in s if str(x).startswith('Influenza') or 
                 str(x).startswith('Staphylococcus aureus')])>1).astype(bool)
In [10]:
flu_staph_organism.name = 'flu_staph'
In [11]:
flu = flu.merge(pd.DataFrame(flu_staph_organism), left_on='PatientID', right_index=True)
In [12]:
flu_staph_organism.mean()
Out[12]:
0.10845986984815618

See if there are any flu-only or flu-other individuals with an ICD9 diagnosis of S. Aureus

In [13]:
flu_other = flu.groupby('PatientID')['OrganismName'].apply(lambda s: 
            len([x for x in s if str(x).startswith('Influenza') or 
                 str(x).startswith('Other')])>1).astype(bool)
In [14]:
flu_other.name = 'flu_other'
In [15]:
flu = flu.merge(pd.DataFrame(flu_other), left_on='PatientID', right_index=True)

There do not appear to be cases of flu-only or flu-other via organism that have a S.Aureus diagnosis via ICD9

In [16]:
(flu_other & flu.s_aureus_icd9).sum()
Out[16]:
0
In [17]:
(flu_only & flu.s_aureus_icd9).sum()
Out[17]:
0
In [18]:
flu.groupby('PatientID')['OrganismName'].value_counts()
Out[18]:
PatientID                                                                       
0034181F-59AB-4B2C-A06A-A83EE1DF1A17  Influenza A                                   1
0084E706-90D0-4314-8C99-496171AB639D  Influenza B                                   1
00EFD11D-650A-4317-8C4F-4B7378AD7946  Other                                         2
0118B486-8ACE-4BFF-8ED6-753D96E6B441  Influenza A                                   1
02C9FF5A-A4D6-461A-A650-E3EB405D62AE  Influenza A                                   2
031C1231-B585-478F-A98F-35E017B788CD  Other                                         2
035A053E-91B1-4B68-87D4-F92C675FE6E3  Other                                         2
03AA4AD0-1B0E-4CAD-BD76-03FE52AA580D  Staphylococcus, coag neg                      1
                                      Epstein-Barr virus (EBV)                      1
                                      Hemophilus parainfluenzae                     1
                                      Influenza A                                   1
                                      Pneumocystis carinii                          1
03FA6416-FE28-484B-AC83-60A7442035BE  Influenza A                                   1
040710D1-2D29-48F4-B14C-ABAD7F2DC3B4  Influenza A                                   1
04B6544B-76BD-475E-83B4-782E2387093C  Influenza A                                   1
                                      Candida albicans                              1
04BB440E-3A59-4453-8648-72AA1FB7CE22  Influenza A                                   1
057101E0-8C51-441B-AA05-B52D0092D820  Influenza A                                   1
0630B500-DDCD-4882-B3F2-D4B6DCD0DE59  Enterovirus                                   1
063CA301-7E8A-47EE-B8E3-6F01E734FEB7  Other                                         2
06522594-1B6F-492D-B47C-D94460052212  Influenza A                                   1
06799A45-93D6-4AAB-AA45-02DACB2C0D27  Influenza A                                   1
06AAFC52-B156-41FD-83C4-B2F6FC77E70A  Acinetobacter sp.                             3
07233E93-332D-4766-9696-A352A852856D  Candida albicans                              2
07ACDC42-6DCB-48D1-8458-F31097B95330  Influenza A                                   1
                                      Streptococcus pneumoniae                      1
07E15E96-706B-4DA1-899B-E5A9213FC864  Influenza A                                   1
07E64A9C-0C5D-4D56-AB4D-BDA56B8F7E2E  Other                                         1
                                      Candida albicans                              1
081ABE0C-3969-496A-B9DA-42536D65299E  Enterovirus                                   1
                                                                                   ..
F8C7118D-C465-4065-8C7D-3DDCE71F6D28  Candida parapsilosis                          1
F9BD3113-6BA4-4417-815E-F7F26B0C3E52  Influenza A                                   1
F9C6E182-48BA-4424-903A-47B59957F1EF  Influenza A                                   1
F9E79DA1-74ED-4036-ABD2-BB7E24684F36  Gram negative, other                          1
F9EE26A2-5A7B-4CB1-8EC4-F56AED36A623  Other                                         2
FAD34086-3A25-4CC2-B1EC-D69FA130446B  Influenza A                                   1
FAE1C068-87EA-4442-9493-13B1621DAE32  Influenza B                                   1
FAF53806-357F-4143-AC8C-7436C6A34D66  Influenza A                                   1
                                      Staphylococcus aureus                         1
FB45A43B-15A3-4F3C-8166-A94B1D943600  Other                                         2
FB6D570C-5150-448F-A0D7-F4D23B1D3A4A  Other                                         2
FB995686-3063-4D23-AF6B-0E69BF18A09F  Influenza A                                   1
FBCDE028-006C-4950-AB37-E1EB88590EEB  Influenza A                                   1
FBD37823-BF97-4418-9733-778E8A917160  Other                                         2
FC9F84B4-B2CD-4C28-B824-D5CE1AF7BEF1  Gram positive, other                          2
                                      Other                                         1
                                      Pseudomonas fluorescens                       1
                                      Candida albicans                              1
                                      Aspergillus fumigatus                         1
                                      Yeast sp.                                     1
                                      Enterococcus                                  1
                                      Stenotrophomonas maltophilia (Xanthomonas)    1
FCBAEC19-EFA1-4A95-A98D-1D8FE65B53CE  Influenza A                                   1
FCD795A0-9FD1-4C20-961B-04BEF1B23E96  Influenza A                                   1
                                      Staphylococcus aureus                         1
FD2FA339-7417-4F13-811B-2E1873E529C7  Influenza A                                   1
FDC32D1D-0D67-4752-A2E1-72E790FCB797  Influenza A                                   1
                                      Pseudomonas aerugenosa                        1
FF0FB831-9804-4506-AE0B-B447AB54D4D5  Hemophilus influenza                          1
FF528500-5274-4225-96C2-CF1153EE3AB0  Other                                         1
dtype: int64

Create indictors for coinfection type

In [19]:
for org in flu.Type.unique():
    flu[org] = (flu.Type==org).astype(int)

Create data frame of unique patients

In [20]:
flu_unique = flu.drop_duplicates(subset=['PatientID']).set_index('PatientID')
In [21]:
flu_unique['flu_only'] = flu_only
In [22]:
flu_unique.flu_only.mean()
Out[22]:
0.35791757049891543
In [23]:
flu_unique.flu_staph.mean()
Out[23]:
0.10845986984815618

Several missing values for admission to time on ECMO

In [24]:
flu_unique.AdmitToTimeOnHours.isnull().mean()
Out[24]:
0.13665943600867678

Since we need this field to calculate event time, we will have to drop individuals with missing values from the survival analysis.

In [25]:
flu_unique = flu_unique.dropna(subset=['AdmitToTimeOnHours', 'HoursECMO'])
assert not flu_unique.AdmitToTimeOnHours.isnull().sum()

Add time from admission to time on ECMO to get total exposure time.

In [26]:
admit_through_emco = flu_unique.AdmitToTimeOnHours.add(flu_unique.HoursECMO)
flu_unique['admit_through_emco'] = admit_through_emco

See how many have no time information from admission, and drop these

In [27]:
missing = admit_through_emco.isnull()
flu_complete = flu_unique[missing ^ True]
admit_through_emco.dropna(inplace=True)
assert not admit_through_emco.isnull().sum()

Time to event

Time off ECMO through to event (death or discharge). Fill with zeros for TimeOffToDeathDateHours, implying censoring.

In [28]:
flu_unique.TimeOffToDCDateHours.add(flu_unique.TimeOffToDeathDateHours, fill_value=0).isnull().sum()
Out[28]:
45
In [29]:
off_ecmo_to_event = flu_complete.TimeOffToDCDateHours.add(flu_complete.TimeOffToDeathDateHours, 
                                                          fill_value=0)
In [30]:
off_ecmo_to_event.isnull().sum()
Out[30]:
45

Time to event is admit throught ecmo, plus off ecmo to event times.

In [35]:
flu_complete['time_to_event'] = admit_through_emco + off_ecmo_to_event.fillna(0)
assert not flu_complete.time_to_event.isnull().sum()

Create variables for use in analysis

In [36]:
# ECMO Type
flu_complete['VA'] = flu_complete.Mode.isin(['VA', 'VV-VA'])
# Set "Other" type to NA (there are only a couple)
flu_complete.loc[flu_complete.Mode=='Other', 'VA'] = None

Create oxygen index

In [37]:
flu.columns
Out[37]:
Index([                   'PatientID',                        'RunNo',
                            'AgeDays',                    'HoursECMO',
                        'SupportType',                    'PrimaryDx',
                               'Mode',              'Discontinuation',
                    'DischargedAlive',            'DischargeLocation',
                           'YearECLS',                     'VentType',
                               'Rate',                         'FiO2',
                                'PIP',                         'PEEP',
                                'MAP',                  'HandBagging',
                                 'pH',                         'PCO2',
                                'PO2',                         'HCO3',
                               'SaO2',                   'Venttype24',
                             'Rate24',                       'Fio224',
                              'PIP24',                       'PEEP24',
                              'MAP24',                'Handbagging24',
                                'SBP',                          'DBP',
                            'MapHemo',                         'SVO2',
                               'PCWP',                         'SPAP',
                               'DPAP',                         'MPAP',
                                 'CI',                         'Race',
                                'Sex',           'AdmitToTimeOnHours',
       'TimeOffToExtubationDateHours',      'TimeOffToDeathDateHours',
               'TimeOffToDCDateHours',      'ExtubationToDCDateHours',
         'ExtubationToDeathDateHours',                         'year',
                          'HoursVent',                     'AgeYears',
                          'age_class',                        'RunID',
                         'OrganismNo',                 'OrganismName',
                        'CultureSite',     'CultureTimeIsApproximate',
                     'OrganismTiming',                         'Type',
                            'has_flu',                         'fate',
                      's_aureus_icd9',                     'flu_only',
                          'flu_staph',                    'flu_other',
                              'viral',                    'bacterial',
                                  nan,                       'fungal'],
      dtype='object')
In [38]:
flu_complete['OI'] = flu_complete.FiO2 * flu_complete.MAP / flu_complete.PO2
In [39]:
flu_complete.OI.hist(bins=20)
Out[39]:
<matplotlib.axes._subplots.AxesSubplot at 0x10a3b3390>
In [40]:
covariates = ['AgeYears', 'pH', 'OI', 'VA']

Get counts of missing values

In [41]:
flu_complete[covariates].isnull().mean()
Out[41]:
AgeYears    0.000000
pH          0.073140
OI          0.407314
VA          0.016393
dtype: float64
In [42]:
flu_complete[covariates].hist(bins=25);
In [43]:
flu_complete[flu_complete.AgeYears< 1].AgeDays.hist(bins=15)
Out[43]:
<matplotlib.axes._subplots.AxesSubplot at 0x10ab7ad68>

Survival Model

Surivival time, in days

In [44]:
obs_t = flu_complete.time_to_event.values/24.
In [45]:
plt.hist(obs_t, bins=25);

Fate of each patient

In [46]:
died = (flu_complete.fate=='Died').astype(int).values

Unique failure times

In [47]:
times = np.unique((flu_complete.time_to_event[flu_complete.fate=='Died'].values/24).astype(int))
times
Out[47]:
array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  43,  44,  45,  46,  47,  48,  50,  53,  54,  55,
        59,  61,  62,  63,  67,  68,  73,  77, 108, 121, 124, 125, 126,
       133, 142, 150, 201, 227, 235])
In [48]:
N_obs = len(obs_t)
T = len(times) - 1

Calculate risk set (the number at risk at each event time)

In [49]:
Y = np.array([[int(obs >= t) for t in times] for obs in obs_t])
Y
Out[49]:
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ..., 
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]])

Counting process. Jump = 1 if $\text{obs}_t \in [ t_j, t_{j+1} )$

In [50]:
dN = np.array([[Y[i,j]*(times[j+1] >= obs_t[i])*died[i] for i in range(N_obs)] for j in range(T)])
In [51]:
# Sample process for one patient
dN[:,np.random.randint(len(dN))]
Out[51]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0])

Functions for standarizing and centering

In [52]:
center = lambda x: (x - x[np.isnan(x) ^ True].mean())
standardize = lambda x: center(x) / x[np.isnan(x) ^ True].std()
In [53]:
covariates
Out[53]:
['AgeYears', 'pH', 'OI', 'VA']
In [54]:
AgeYears, pH, OI, va = flu_complete[covariates].values.T

AgeYears_center = center(AgeYears).round(1)

# Center pH at 7.4
pH_center = pH - 7.4

OI_std = standardize(OI)

Survival curves

In [55]:
from lifelines import KaplanMeierFitter

km = KaplanMeierFitter()
km.fit(obs_t, event_observed=died)
ax = km.plot()
In [56]:
flu_only_ind = flu_complete.flu_only.values
flu_staph_ind = flu_complete.flu_staph.values
flu_other_ind = flu_complete.flu_other.values
In [57]:
km = KaplanMeierFitter()
km.fit(obs_t[flu_staph_ind], event_observed=died[flu_staph_ind], label='flu+staph')
ax = km.plot()

km.fit(obs_t[flu_only_ind], event_observed=died[flu_only_ind], label='flu only')
km.plot(ax=ax)

km.fit(obs_t[flu_other_ind], event_observed=died[flu_other_ind], label='flu+other')
km.plot(ax=ax)
Out[57]:
<matplotlib.axes._subplots.AxesSubplot at 0x10af54668>
In [58]:
age_complete = flu_complete.AgeYears.values
infant_ind = age_complete<1
kids_ind = (age_complete>=1) & (age_complete<19)
adult_ind = age_complete>=19
In [59]:
km = KaplanMeierFitter()
km.fit(obs_t[infant_ind], event_observed=died[infant_ind], label='infants')
ax = km.plot()

km.fit(obs_t[kids_ind], event_observed=died[kids_ind], label='children')
km.plot(ax=ax)

km.fit(obs_t[adult_ind], event_observed=died[adult_ind], label='adults')
km.plot(ax=ax)
Out[59]:
<matplotlib.axes._subplots.AxesSubplot at 0x10ac7f6a0>
In [60]:
va_ind = flu_complete.VA.astype(bool).values
In [61]:
km = KaplanMeierFitter()
km.fit(obs_t[va_ind], event_observed=died[va_ind], label='VA')
ax = km.plot()

km.fit(obs_t[~va_ind], event_observed=died[~va_ind], label='VV')
km.plot(ax=ax)
Out[61]:
<matplotlib.axes._subplots.AxesSubplot at 0x10ad96dd8>

Model

In [95]:
from pymc import MCMC, Normal, Poisson, Uniform, Lambda, Laplace, Exponential, Beta, HalfCauchy, Uninformative
from pymc import Bernoulli, Lognormal, observed, Matplot, deterministic, Gamma, stochastic
from pymc import AdaptiveMetropolis, normal_like, rnormal, potential
from pymc.gp import *
In [96]:
# Values for missing elements, to be replaced in model
OI_std[np.isnan(OI_std)] = 0
pH_center[np.isnan(pH_center)] = 7.111
va[np.isnan(va)] = 0.5
In [97]:
gaussian_rbf = lambda x, mu, l: np.exp(-np.abs(x - mu)**2 / l**2)
In [98]:
def rbf(name, x, n=10):
    
    # Number of basis functions for age effect
    h = np.linspace(x.min(), x.max(), n)
    l = HalfCauchy('l_%s' % name, 0, 5, value=1)
    
    # RBF weights for all the points, for all the functions
    w = Lambda('w_%s' % name, lambda l=l: np.array([gaussian_rbf(i, h, l) for i in x]))
    
    τ = Exponential('τ_%s' % name, 1, value=np.ones(n))

    # Coefficients
    β = Normal('β_%s' % name, 0, τ, value=np.zeros(n))
    
    τ_rbf = Exponential('τ_rbf_%s' % name, 1, value=1)
    
    @stochastic(name='rbf_%s' % name)
    def rbf(value=np.zeros(len(x)), β=β, w=w, τ=τ_rbf):
        m = w.dot(β)
        return normal_like(value, m, τ)
    
    @deterministic(name='rbf_eval_%s' % name)
    def rbf_eval(β=β, τ=τ_rbf, l=l):
        
        wts = np.array([gaussian_rbf(i, h, l) for i in h])
        m = wts.dot(β)
        return rnormal(m, τ)
    
    return locals()
In [134]:
def spline(name, x, knots, smoothing, interpolation_method='linear'):
    """ Generate PyMC objects for a spline model of age-specific rate

    Parameters
    ----------
    name : str
    knots : array
    x : array, points to interpolate to
    smoothing : pymc.Node, smoothness parameter for smoothing spline
    interpolation_method : str, optional, one of 'linear', 'nearest', 'zero', 'slinear', 'quadratic, 'cubic'

    Results
    -------
    Returns dict of PyMC objects, including 'gamma' (log of rate at
    knots) and 'mu_age' (age-specific rate interpolated at all age
    points)
    """
    assert np.all(np.diff(knots) > 0), 'Spline knots must be strictly increasing'

    # TODO: consider changing this prior distribution to be something more familiar in linear space
    gamma = [Normal('gamma_%s_%d'%(name,k), 0., 10.**-2, value=0.) for k in knots]
    #gamma = [mc.Uniform('gamma_%s_%d'%(name,k), -20., 20., value=-10.) for k in knots]

    import scipy.interpolate
    @deterministic(name='mu_x_%s'%name)
    def mu_x(gamma=gamma, knots=knots, x=x):
        mu = scipy.interpolate.interp1d(knots, np.exp(gamma), kind=interpolation_method, 
                                        bounds_error=False, fill_value=0.)
        return mu(x)

    vars = dict(gamma=gamma, mu_x=mu_x, x=x, knots=knots)

    if (smoothing > 0) and (smoothing < 1e10):
        #print 'adding smoothing of', smoothing
        @potential(name='smooth_mu_%s'%name)
        def smooth_gamma(gamma=gamma, knots=knots, tau=smoothing**-2):
            # the following is to include a "noise floor" so that level value
            # zero prior does not exert undue influence on age pattern
            # smoothing
            # TODO: consider changing this to an offset log normal
            gamma = np.clip(gamma, np.log(np.exp(gamma).mean()/10.), np.inf)  # only include smoothing on values within 10x of mean

            return normal_like(np.sqrt(np.sum(np.diff(gamma)**2 / np.diff(knots))), 0, tau)
        vars['smooth_gamma'] = smooth_gamma

    return vars
In [135]:
def exponential_survival_model():
    
    # Imputation of missing values
    p_va = Beta('p_VA', 1, 1, value=0.5)
    va_masked = np.ma.masked_values(va, value=0.5)
    x_va = Bernoulli('x_va', p_va, value=va_masked, observed=True)
    
    mu_pH = Normal('mu_pH', 0, 0.0001, value=7)
    sigma_pH = Uniform('sigma_pH', 0, 500, value=10)
    tau_pH = sigma_pH**-2
    pH_masked = np.ma.masked_values(pH_center, value=7.111)
    x_pH = Normal('x_pH', mu_pH, tau_pH, value=pH_masked, observed=True)
    
#     mu_OI = Normal('mu_OI', 0, 0.0001, value=0)
#     sigma_OI = Uniform('sigma_OI', 0, 500, value=10)
#     tau_OI = sigma_OI**-2
#     OI_masked = np.ma.masked_values(OI_std, value=0)
#     x_OI = Normal('x_OI', mu_OI, tau_OI, value=OI_masked, observed=True)
    
    X = [x_va, flu_unique.flu_only, flu_unique.flu_staph]
        
    # Intercept for survival rate
    beta0 = Normal('beta0', 0.0, 0.001, value=0)
    # Covariates
    beta = Normal('beta', 0, 0.001, value=np.zeros(len(X)))

    age_knots = np.linspace(AgeYears_center.min(), AgeYears_center.max(), 7)
    α_age = Exponential('α_age', 1, value=0.1)
    spline_age = spline('spline_age', AgeYears_center, age_knots, α_age)
    
    pH_knots = np.linspace(pH_center.min(), pH_center.max(), 7)
    α_pH = Exponential('α_pH', 1, value=0.1)
    spline_pH = spline('spline_pH', pH_center, pH_knots, α_pH)
    
    # Survival rates
    @deterministic
    def lam(b0=beta0, b=beta, x=X, γ=spline_age['mu_x'], ρ=spline_pH['mu_x']):
        return np.exp(b0 + np.dot(np.transpose(x), b) + γ + ρ)

    @observed
    def survival(value=obs_t, lam=lam, f=died):
        """Exponential survival likelihood, accounting for censoring"""
        return (f*np.log(lam) - lam*value).sum()

    return locals()
In [136]:
M = MCMC(exponential_survival_model())
In [137]:
iterations = 50000
burn = 40000
In [138]:
M.sample(iterations, burn)
 [-----------------100%-----------------] 50000 of 50000 complete in 453.6 sec
In [139]:
M.sample(iterations, burn)
 [-----------------100%-----------------] 50000 of 50000 complete in 445.2 sec
In [140]:
Matplot.summary_plot(M.beta, custom_labels=['VA', 'flu only', 'flu + staph'], hpd=False)
In [141]:
Matplot.summary_plot(M.__dict__['spline_pH']['gamma'], 
                     custom_labels=np.linspace(flu_unique.pH.dropna().min(), 
                                               flu_unique.pH.dropna().max(), 7).round(1).astype(str),
                    xlab='Effect size', main='pH effect')
In [142]:
Matplot.summary_plot(M.__dict__['spline_age']['gamma'], 
                     custom_labels=np.linspace(AgeYears.min(), AgeYears.max(), 7).round(1).astype(str),
                    xlab='Effect size', main='Age effect')

Same model above, but with a Weibull hazard function

In [143]:
def weibull_survival_model():
    
    # Imputation of missing values
    p_va = Beta('p_VA', 1, 1, value=0.5)
    va_masked = np.ma.masked_values(va, value=0.5)
    x_va = Bernoulli('x_va', p_va, value=va_masked, observed=True)
    
    mu_pH = Normal('mu_pH', 0, 0.0001, value=7)
    sigma_pH = Uniform('sigma_pH', 0, 500, value=10)
    tau_pH = sigma_pH**-2
    pH_masked = np.ma.masked_values(pH_center, value=7.111)
    x_pH = Normal('x_pH', mu_pH, tau_pH, value=pH_masked, observed=True)

    X = [x_va, flu_unique.flu_only, flu_unique.flu_staph]
        
    # Intercept for survival rate
    beta0 = Normal('beta0', 0.0, 0.001, value=0)
    # Covariates
    beta = Normal('beta', 0, 0.001, value=np.zeros(len(X)))

    age_knots = np.linspace(AgeYears_center.min(), AgeYears_center.max(), 7)
    α_age = Exponential('α_age', 1, value=0.1)
    spline_age = spline('spline_age', AgeYears_center, age_knots, α_age)
    
    pH_knots = np.linspace(pH_center.min(), pH_center.max(), 7)
    α_pH = Exponential('α_pH', 1, value=0.1)
    spline_pH = spline('spline_pH', pH_center, pH_knots, α_pH)
    
    # Survival rates
    @deterministic
    def lam(b0=beta0, b=beta, x=X, γ=spline_age['mu_x'], ρ=spline_pH['mu_x']):
        return np.exp(b0 + np.dot(np.transpose(x), b) + γ + ρ)
    
    α = Exponential('α', 1, value=1)

    @observed
    def survival(value=obs_t, lam=lam, f=died, α=α):
        # Weibull survival log-likelihood
        return (f*(np.log(α) + (α-1)*np.log(value*lam) + np.log(lam)) - (lam*value)**α).sum()

    return locals()
In [144]:
W = MCMC(weibull_survival_model())
In [145]:
W.sample(50000, 40000)
 [                  2%                  ] 1159 of 50000 complete in 11.1 secHalting at iteration  1170  of  50000

Cox model

In [162]:
def cox_model():
    
    # Imputation of missing values
    p_va = Beta('p_VA', 1, 1, value=0.5)
    va_masked = np.ma.masked_values(va, value=0.5)
    x_va = Bernoulli('x_va', p_va, value=va_masked, observed=True)
    
    mu_pH = Normal('mu_pH', 0, 0.0001, value=7)
    sigma_pH = Uniform('sigma_pH', 0, 500, value=10)
    tau_pH = sigma_pH**-2
    pH_masked = np.ma.masked_values(pH_center, value=7.111)
    x_pH = Normal('x_pH', mu_pH, tau_pH, value=pH_masked, observed=True)

    X = [x_va, flu_unique.flu_only, flu_unique.flu_staph]
        
    # Intercept for survival rate
    beta0 = Normal('beta0', 0.0, 0.001, value=0)
    # Covariates
    beta = Normal('beta', 0, 0.001, value=np.zeros(len(X)))

    age_knots = np.linspace(AgeYears_center.min(), AgeYears_center.max(), 7)
    α_age = Exponential('α_age', 1, value=0.1)
    spline_age = spline('spline_age', AgeYears_center, age_knots, α_age)
    
    pH_knots = np.linspace(pH_center.min(), pH_center.max(), 7)
    α_pH = Exponential('α_pH', 1, value=0.1)
    spline_pH = spline('spline_pH', pH_center, pH_knots, α_pH)
    
    c = Gamma('c', .0001, .00001, value=0.1)
    r = Gamma('r', .001, .0001, value=0.1)
    
    dL0_star = Lambda('dL0_star', lambda r=r: r*np.diff(times))
    
    # prior mean hazard
    mu = Lambda('mu', lambda dL0_star=dL0_star, c=c: dL0_star * c)
    
    dL0 = Gamma('dL0', mu, c, value=np.ones(T))
    
    @deterministic
    def Idt(b0=beta0, b=beta, x=X, dL0=dL0, γ=spline_age['mu_x'], ρ=spline_pH['mu_x']): 
        # Poisson trick: independent log-normal hazard increments
        Xb = np.exp(b0 + np.dot(np.transpose(x), b) + γ + ρ)
        return np.transpose(Y[:,:-1] * np.outer(Xb, dL0))
#         return [[Y[i,j]*Xb[i] for i in range(N_obs)] for j in range(T)]

    dN_like = Poisson('dN_like', Idt, value=dN, observed=True)
    
    return locals()
In [163]:
M_cox = MCMC(cox_model())
In [164]:
M_cox.sample(20000, 10000)
 [-----------------100%-----------------] 20000 of 20000 complete in 483.4 sec
In [165]:
Matplot.summary_plot(M_cox.beta, custom_labels=['VA', 'flu only', 'flu + staph'], hpd=False)
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.
In [166]:
Matplot.summary_plot(M_cox.__dict__['spline_pH']['gamma'], 
                     custom_labels=np.linspace(flu_unique.pH.dropna().min(), 
                                               flu_unique.pH.dropna().max(), 7).round(1).astype(str),
                    xlab='Effect size', main='pH effect')
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.
In [167]:
Matplot.summary_plot(M_cox.__dict__['spline_age']['gamma'], 
                     custom_labels=np.linspace(AgeYears.min(), AgeYears.max(), 7).round(1).astype(str),
                    xlab='Effect size', main='Age effect')
Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.