%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc3 as pm
import pdb
pm.__version__
'3.0'
flu = pd.read_csv('data/flu_organism.csv', index_col=0)
flu.shape
(2933, 63)
flu.PatientID.unique().shape
(1654,)
Create indicator for flu-only.
flu_only = flu.groupby('PatientID')['OrganismName'].apply(lambda s:
len([x for x in s if str(x).startswith('Influenza')])==len(s)).astype(bool)
flu_only.head()
PatientID 001B0EC6-66F5-48C9-8516-52D2ABFE28AA False 0034181F-59AB-4B2C-A06A-A83EE1DF1A17 True 004AE387-1096-410D-A7CB-F9A9086AA385 False 0084DCF0-8E3C-4F1C-8650-BE795A57EF2F False 0084E706-90D0-4314-8C99-496171AB639D True Name: OrganismName, dtype: bool
flu_only.name = 'flu_only'
flu = flu.merge(pd.DataFrame(flu_only), left_on='PatientID', right_index=True)
flu_staph_organism = flu.groupby('PatientID')['OrganismName'].apply(lambda s:
len([x for x in s if str(x).startswith('Influenza') or
str(x).startswith('Staphylococcus aureus')])>1).astype(bool)
flu_staph_organism.name = 'flu_staph_org'
flu = flu.merge(pd.DataFrame(flu_staph_organism), left_on='PatientID', right_index=True)
flu_staph_organism.mean()
0.13059250302297462
See if there are any flu-only or flu-other individuals with an ICD9 diagnosis of S. Aureus
flu_other = flu.groupby('PatientID')['OrganismName'].apply(lambda s:
len([x for x in s if str(x).startswith('Influenza') or
str(x).startswith('Other')])>1).astype(bool)
flu_other.name = 'flu_other'
flu = flu.merge(pd.DataFrame(flu_other), left_on='PatientID', right_index=True)
There do not appear to be cases of flu-only or flu-other via organism that have a S.Aureus diagnosis via ICD9
(flu_other & flu.s_aureus_icd9).sum()
0
(flu_only & flu.s_aureus_icd9).sum()
0
Same for pneumo
(flu_other & flu.pneumo_icd9).sum()
0
flu.groupby('PatientID')['OrganismName'].value_counts()
PatientID 001B0EC6-66F5-48C9-8516-52D2ABFE28AA Klebsiella pneumoniae 1 Enterococcus 1 0034181F-59AB-4B2C-A06A-A83EE1DF1A17 Influenza A 1 004AE387-1096-410D-A7CB-F9A9086AA385 Candida albicans 1 Influenza A 1 Pseudomonas aerugenosa 1 0084DCF0-8E3C-4F1C-8650-BE795A57EF2F Influenza A 1 Staphylococcus aureus 1 0084E706-90D0-4314-8C99-496171AB639D Influenza B 1 00EFD11D-650A-4317-8C4F-4B7378AD7946 Other 2 0118B486-8ACE-4BFF-8ED6-753D96E6B441 Influenza A 1 011B5A5D-9967-44A7-92F9-02F82D31216D Candida albicans 2 Influenza A 1 Epstein-Barr virus (EBV) 1 01AC9644-7CDF-4E28-A5F4-FE5B7F9FA22A Influenza A 1 Streptococcus pneumoniae 1 Moraxella catarrhalis (Branhamella) 1 Staphylococcus, coag neg 1 01F84F52-92D3-403B-B8AC-21A72A4BC771 Influenza A 1 02373407-74E6-4A1E-9629-F4DCBF707B0C Influenza A 1 025D87C9-CF70-4B9A-94BE-B15EA2185472 Eschericha coli 2 Influenza A 1 Enterococcus 1 02A94D2E-768A-4ABB-86BF-C8A188D698CC Influenza A 1 02C9FF5A-A4D6-461A-A650-E3EB405D62AE Influenza A 2 02D33150-FDDB-4F12-AAAB-D6F4B8765EBF Influenza A 1 031C1231-B585-478F-A98F-35E017B788CD Other 2 035A053E-91B1-4B68-87D4-F92C675FE6E3 Other 2 03AA4AD0-1B0E-4CAD-BD76-03FE52AA580D Influenza A 1 Pneumocystis carinii 1 .. FC9F84B4-B2CD-4C28-B824-D5CE1AF7BEF1 Gram positive, other 2 Aspergillus fumigatus 1 Candida albicans 1 Enterococcus 1 Pseudomonas fluorescens 1 Other 1 Yeast sp. 1 Stenotrophomonas maltophilia (Xanthomonas) 1 FCBAEC19-EFA1-4A95-A98D-1D8FE65B53CE Influenza A 1 FCD795A0-9FD1-4C20-961B-04BEF1B23E96 Influenza A 1 Staphylococcus aureus 1 FCDF1967-C307-4FF1-8CF1-7AA02FE0BE47 Influenza A 1 FD2FA339-7417-4F13-811B-2E1873E529C7 Influenza A 1 FDC32D1D-0D67-4752-A2E1-72E790FCB797 Pseudomonas aerugenosa 1 Influenza A 1 FDED68D1-1DFF-4899-9CF9-797E6AE70D18 Influenza A 1 FDFF23E3-C0CB-468D-98D9-998192CA7590 Eschericha coli 1 Candida albicans 1 Influenza A 1 Yeast sp. 1 FEA94F00-083A-479C-B862-05274A9842B5 Streptococcus, group A 1 Influenza A 1 Staphylococcus aureus 1 FF0D88BF-E131-4E56-9E78-CC6B60E461FD Influenza A 1 Staphylococcus, coag neg 1 Yeast sp. 1 FF0FB831-9804-4506-AE0B-B447AB54D4D5 Hemophilus influenza 1 FF528500-5274-4225-96C2-CF1153EE3AB0 Other 1 FF6CD7E8-8956-4C15-AF6C-3033A38C48D7 Influenza A 1 Staphylococcus epidermidis 1 dtype: int64
Create indictors for coinfection type
for org in flu.Type.unique():
flu[org] = (flu.Type==org).astype(int)
Create data frame of unique patients
flu_unique = flu.drop_duplicates(subset=['PatientID']).set_index('PatientID')
flu_unique.s_aureus_icd9.sum()
99
flu_unique['flu_only'] = flu_only
flu_unique.flu_only.mean()
0.34159613059250304
flu_unique.flu_staph.mean()
0.13059250302297462
Several missing values for admission to time on ECMO
flu_unique.AdmitToTimeOnHours.isnull().mean()
0.16686819830713423
Since we need this field to calculate event time, we will have to drop individuals with missing values from the survival analysis.
flu_unique = flu_unique.dropna(subset=['AdmitToTimeOnHours', 'HoursECMO'])
assert not flu_unique.AdmitToTimeOnHours.isnull().sum()
Create variables for use in analysis
# ECMO Type
flu_unique['VA'] = flu_unique.Mode.isin(['VA', 'VV-VA'])
# Set "Other" type to NA (there are only a couple)
flu_unique.loc[flu_unique.Mode=='Other', 'VA'] = None
Create oxygen index
flu_unique['OI'] = flu_unique.FiO2 * flu_unique.MAP / flu_unique.PO2
flu_unique.OI.hist(bins=20)
<matplotlib.axes._subplots.AxesSubplot at 0x117075f60>
covariates = ['AgeYears', 'pH', 'OI', 'VA']
Get counts of missing values
flu_unique[covariates].isnull().mean()
AgeYears 0.000000 pH 0.077323 OI 0.442379 VA 0.011152 dtype: float64
flu_unique[covariates].hist(bins=25);
flu_unique[flu_unique.AgeYears< 1].AgeDays.hist(bins=15)
<matplotlib.axes._subplots.AxesSubplot at 0x116eb6f28>
flu_unique.flu_staph_org.sum()
172
flu_unique.s_aureus_icd9.sum()
84
(flu_unique.flu_staph_org | flu_unique.s_aureus_icd9).sum()
218
Fate of each patient
died = (flu_complete.fate=='Died').astype(int).values
N_obs = len(died)
Functions for standarizing and centering
center = lambda x: (x - x[np.isnan(x) ^ True].mean())
standardize = lambda x: center(x) / x[np.isnan(x) ^ True].std()
covariates
['AgeYears', 'pH', 'OI', 'VA']
AgeYears, pH, OI, va = flu_complete[covariates].values.T
AgeYears_center = center(AgeYears).round(1)
# Center pH at 7.4
pH_center = pH - 7.4
OI_std = standardize(OI)
no_coinfection = flu_complete.flu_only.values
with_staph = flu_complete.flu_staph.values
flu_unique.flu_staph.sum()
172
# Values for missing elements, to be replaced in model
OI_std[np.isnan(OI_std)] = -999
pH_center[np.isnan(pH_center)] = -999
va[np.isnan(va)] = -999
from pymc3 import Model, find_MAP, NUTS, Metropolis, sample, forestplot, traceplot, Slice
from pymc3 import Normal, HalfCauchy, Beta, Bernoulli, MvNormal
from pymc3.distributions.timeseries import GaussianRandomWalk
from numpy.ma import masked_values
def interpolate(x0, y0, x):
idx = np.searchsorted(x0, x)
dl = np.array(x - x0[idx - 1])
dr = np.array(x0[idx] - x)
d = dl + dr
wl = dr / d
return wl*y0[idx-1] + (1-wl)*y0[idx]
import theano.tensor as T
def invlogit(x):
return 1. / (1 + T.exp(-x))
age_d = np.array([[xi - yi for yi in AgeYears_center] for xi in AgeYears_center])
with Model() as model:
# Impute missing values
μ_pH = Normal('μ_pH', 0, 0.001, testval=0)
σ_pH = HalfCauchy('σ_pH', 5, testval=1)
pH_imputed = Normal('pH_imputed', μ_pH, σ_pH**-2, observed=masked_values(pH_center, value=-999))
p_va = Beta('p_va', 1, 1)
va_imputed = Bernoulli('va_imputed', p_va, observed=masked_values(va, value=-999))
# Non-linear age effect
'''
nknots = 10
knots = np.linspace(AgeYears_center.min(), AgeYears_center.max(), nknots)
age_sd = HalfCauchy('age_sd', 5, testval=1)
age_knots = GaussianRandomWalk('age_knots', sd=age_sd, shape=nknots)
age_effects = interpolate(knots, age_knots, AgeYears_center)
'''
# Parameters of kernel function
η2 = pm.HalfCauchy('η2', 5, testval=1)
ρ2 = pm.HalfCauchy('ρ2', 5, testval=1)
σ2 = pm.HalfCauchy('σ2', 5, testval=1)
# Construct precision matrix
M = η2 * T.exp(-ρ2 * age_d**2)
S = M + T.identity_like(M)*σ2
τ = T.nlinalg.matrix_inverse(S)
# GP of age effect
age_effects = MvNormal('age_effects', np.zeros(N_obs), τ, shape=N_obs)
μ = Normal('μ', 0, 0.01, testval=0)
# Linear covariates
β = Normal('β', 0, 0.01, shape=4)
# Probability of mortality effect
π = invlogit(μ + β[0]*no_coinfection + β[1]*with_staph + β[2]*pH_imputed + β[3]*va_imputed + age_effects)
deaths = Bernoulli('deaths', π, observed=died)
with model:
start = find_MAP()
#step1 = NUTS([β, μ_pH, σ_pH, p_va, age_sd, age_knots, μ], scaling=start)
step1 = NUTS([β, μ_pH, σ_pH, p_va, η2, ρ2, σ2, μ], scaling=start)
step2 = Slice([va_imputed.missing_values,
pH_imputed.missing_values])
trace = sample(5000, [step1, step2], start)
--------------------------------------------------------------------------- LinAlgError Traceback (most recent call last) /usr/local/lib/python3.4/site-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs) 594 try: --> 595 outputs = self.fn() 596 except Exception: /usr/local/lib/python3.4/site-packages/theano/gof/op.py in rval(p, i, o, n) 767 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): --> 768 r = p(n, [x[0] for x in i], o) 769 for o in node.outputs: /usr/local/lib/python3.4/site-packages/theano/tensor/nlinalg.py in perform(self, node, xxx_todo_changeme2, xxx_todo_changeme3) 76 (z, ) = xxx_todo_changeme3 ---> 77 z[0] = numpy.linalg.inv(x).astype(x.dtype) 78 /usr/local/lib/python3.4/site-packages/numpy/linalg/linalg.py in inv(a) 519 extobj = get_linalg_error_extobj(_raise_linalgerror_singular) --> 520 ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj) 521 return wrap(ainv.astype(result_t)) /usr/local/lib/python3.4/site-packages/numpy/linalg/linalg.py in _raise_linalgerror_singular(err, flag) 89 def _raise_linalgerror_singular(err, flag): ---> 90 raise LinAlgError("Singular matrix") 91 LinAlgError: Singular matrix During handling of the above exception, another exception occurred: LinAlgError Traceback (most recent call last) <ipython-input-98-ffaed00ebadd> in <module>() 1 with model: ----> 2 start = find_MAP() 3 #step1 = NUTS([β, μ_pH, σ_pH, p_va, age_sd, age_knots, μ], scaling=start) 4 step1 = NUTS([β, μ_pH, σ_pH, p_va, η2, ρ2, σ2, μ], scaling=start) 5 /Users/fonnescj/GitHub/pymc3/pymc3/tuning/starting.py in find_MAP(start, vars, fmin, return_raw, disp, model, *args, **kwargs) 79 if 'fprime' in getargspec(fmin).args: 80 r = fmin(logp_o, bij.map( ---> 81 start), fprime=grad_logp_o, disp=disp, *args, **kwargs) 82 else: 83 r = fmin(logp_o, bij.map(start), disp=disp, *args, **kwargs) /usr/local/lib/python3.4/site-packages/scipy/optimize/optimize.py in fmin_bfgs(f, x0, fprime, args, gtol, norm, epsilon, maxiter, full_output, disp, retall, callback) 782 'return_all': retall} 783 --> 784 res = _minimize_bfgs(f, x0, args, fprime, callback=callback, **opts) 785 786 if full_output: /usr/local/lib/python3.4/site-packages/scipy/optimize/optimize.py in _minimize_bfgs(fun, x0, args, jac, callback, gtol, norm, eps, maxiter, disp, return_all, **unknown_options) 855 alpha_k, fc, gc, old_fval, old_old_fval, gfkp1 = \ 856 _line_search_wolfe12(f, myfprime, xk, pk, gfk, --> 857 old_fval, old_old_fval) 858 except _LineSearchError: 859 # Line search failed to find a better solution. /usr/local/lib/python3.4/site-packages/scipy/optimize/optimize.py in _line_search_wolfe12(f, fprime, xk, pk, gfk, old_fval, old_old_fval, **kwargs) 690 ret = line_search_wolfe1(f, fprime, xk, pk, gfk, 691 old_fval, old_old_fval, --> 692 **kwargs) 693 694 if ret[0] is None: /usr/local/lib/python3.4/site-packages/scipy/optimize/linesearch.py in line_search_wolfe1(f, fprime, xk, pk, gfk, old_fval, old_old_fval, args, c1, c2, amax, amin, xtol) 94 stp, fval, old_fval = scalar_search_wolfe1( 95 phi, derphi, old_fval, old_old_fval, derphi0, ---> 96 c1=c1, c2=c2, amax=amax, amin=amin, xtol=xtol) 97 98 return stp, fc[0], gc[0], fval, old_fval, gval[0] /usr/local/lib/python3.4/site-packages/scipy/optimize/linesearch.py in scalar_search_wolfe1(phi, derphi, phi0, old_phi0, derphi0, c1, c2, amax, amin, xtol) 165 if task[:2] == b'FG': 166 alpha1 = stp --> 167 phi1 = phi(stp) 168 derphi1 = derphi(stp) 169 else: /usr/local/lib/python3.4/site-packages/scipy/optimize/linesearch.py in phi(s) 80 def phi(s): 81 fc[0] += 1 ---> 82 return f(xk + s*pk, *args) 83 84 def derphi(s): /usr/local/lib/python3.4/site-packages/scipy/optimize/optimize.py in function_wrapper(*wrapper_args) 280 def function_wrapper(*wrapper_args): 281 ncalls[0] += 1 --> 282 return function(*(wrapper_args + args)) 283 284 return ncalls, function_wrapper /Users/fonnescj/GitHub/pymc3/pymc3/tuning/starting.py in logp_o(point) 71 72 def logp_o(point): ---> 73 return nan_to_high(-logp(point)) 74 75 def grad_logp_o(point): /Users/fonnescj/GitHub/pymc3/pymc3/blocking.py in __call__(self, x) 117 118 def __call__(self, x): --> 119 return self.fa(self.fb(x)) /Users/fonnescj/GitHub/pymc3/pymc3/model.py in __call__(self, state) 312 313 def __call__(self, state): --> 314 return self.f(**state) 315 316 class LoosePointFunc(object): /usr/local/lib/python3.4/site-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs) 604 self.fn.nodes[self.fn.position_of_error], 605 self.fn.thunks[self.fn.position_of_error], --> 606 storage_map=self.fn.storage_map) 607 else: 608 # For the c linker We don't have access from /usr/local/lib/python3.4/site-packages/theano/gof/link.py in raise_with_op(node, thunk, exc_info, storage_map) 204 exc_value = exc_type(str(exc_value) + detailed_err_msg + 205 '\n' + '\n'.join(hints)) --> 206 raise exc_type(exc_value).with_traceback(exc_trace) 207 208 /usr/local/lib/python3.4/site-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs) 593 t0_fn = time.time() 594 try: --> 595 outputs = self.fn() 596 except Exception: 597 if hasattr(self.fn, 'position_of_error'): /usr/local/lib/python3.4/site-packages/theano/gof/op.py in rval(p, i, o, n) 766 # default arguments are stored in the closure of `rval` 767 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): --> 768 r = p(n, [x[0] for x in i], o) 769 for o in node.outputs: 770 compute_map[o][0] = True /usr/local/lib/python3.4/site-packages/theano/tensor/nlinalg.py in perform(self, node, xxx_todo_changeme2, xxx_todo_changeme3) 75 (x,) = xxx_todo_changeme2 76 (z, ) = xxx_todo_changeme3 ---> 77 z[0] = numpy.linalg.inv(x).astype(x.dtype) 78 79 def grad(self, inputs, g_outputs): /usr/local/lib/python3.4/site-packages/numpy/linalg/linalg.py in inv(a) 518 signature = 'D->D' if isComplexType(t) else 'd->d' 519 extobj = get_linalg_error_extobj(_raise_linalgerror_singular) --> 520 ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj) 521 return wrap(ainv.astype(result_t)) 522 /usr/local/lib/python3.4/site-packages/numpy/linalg/linalg.py in _raise_linalgerror_singular(err, flag) 88 89 def _raise_linalgerror_singular(err, flag): ---> 90 raise LinAlgError("Singular matrix") 91 92 def _raise_linalgerror_nonposdef(err, flag): LinAlgError: Singular matrix Apply node that caused the error: MatrixInverse(Elemwise{Composite{((i0 * exp((i1 * i2 * i3))) + (i4 * i5))}}.0) Inputs types: [TensorType(float64, matrix)] Inputs shapes: [(1345, 1345)] Inputs strides: [(10760, 8)] Inputs values: ['not shown'] Backtrace when the node is created: File "<ipython-input-97-90b626f0aaef>", line 29, in <module> τ = T.nlinalg.matrix_inverse(S) HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
forestplot(trace, vars=['β'], ylabels=['no coinfection', 'staph coinfection', 'pH', 'VA'])
<matplotlib.gridspec.GridSpec at 0x11c9389b0>
traceplot(trace, vars=['age_sd', 'σ_pH'])
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x1157bbc18>, <matplotlib.axes._subplots.AxesSubplot object at 0x11cb93b70>], [<matplotlib.axes._subplots.AxesSubplot object at 0x11cbde5c0>, <matplotlib.axes._subplots.AxesSubplot object at 0x11cc1f438>]], dtype=object)