#!/usr/bin/env python
# coding: utf-8

# In[2]:

get_ipython().run_line_magic('matplotlib', 'inline')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import seaborn as sb
import pymc as pm

sb.set_style("white")

# Import data

# In[3]:

# The first CSV column becomes the row index.
hospitalized = pd.read_csv('data/hospitalized.csv', index_col=0)
hospitalized.head()

# Breakdown by:
#
# - month of age (and greater than 1 year)
# - sex
# - birthweight
# - prematurity
# - feeding: % with any breastfeeding
# - maternal ed
# - number in household
# - daycare
# - smoking exposure: cigarette, maternal during pregnancy, nargila
# - % with comorbidity: past medical history
# - nationality: Jordanian, Palestinian, Syrian, Egyptian, other
# - diagnosis: pneumo, bronchiolitis, bronchopneumonia
#
# Severity measures:
#
# - length of stay
# - ICU
# - O2
# - Mechanical vent.
# - days of symptoms before admission
# - % with antibiotics prior
# - % with antibiotics during
#
# Columns:
#
# - RSV, HMPV, Rhino, Influenza A/B/C, Adeno, Parainfluenza, Hospitalized pop, no virus

# In[4]:

# List the history columns (names ending in 'hx').
[column for column in hospitalized.columns if column.endswith('hx')]

# In[5]:

# A child is flagged with asthma when either child_asthma equals 1 or the
# asthma_hx column is set.
reported_asthma = hospitalized['child_asthma'] == 1
hospitalized['asthma'] = reported_asthma | hospitalized['asthma_hx']
hospitalized['asthma'].sum()

# In[6]:

hospitalized['no_hx'].value_counts()

# In[7]:

# Number of records that have no_hx equal to 1 and yet carry the asthma flag.
denies_history = hospitalized['no_hx'] == 1
(denies_history & hospitalized['asthma'].astype(bool)).sum()

# Past medical history is calculated as those who claim to have a past medical
# history, plus those who claim to have no past medical history but reported
# asthma.
# In[8]:

hospitalized['past medical history'] = (
    ~hospitalized.no_hx.astype(bool)
    | ((hospitalized.no_hx==1) & (hospitalized.asthma.astype(bool)))
)
hospitalized['past medical history'].sum()

# In[9]:

# Positive culture lookup
pcr_lookup = {'pcr_result___1': 'RSV',
              'pcr_result___2': 'HMPV',
              'pcr_result___3': 'flu A',
              'pcr_result___4': 'flu B',
              'pcr_result___5': 'rhino',
              'pcr_result___6': 'PIV1',
              'pcr_result___7': 'PIV2',
              'pcr_result___8': 'PIV3',
              'pcr_result___13': 'H1N1',
              'pcr_result___14': 'H3N2',
              'pcr_result___15': 'Swine',
              'pcr_result___16': 'Swine H1',
              'pcr_result___17': 'flu C',
              'pcr_result___18': 'Adeno'}

# In[10]:

hospitalized['RSV'] = hospitalized.pcr_result___1.astype(bool)
hospitalized['HMPV'] = hospitalized.pcr_result___2.astype(bool)
hospitalized['Rhino'] = hospitalized.pcr_result___5.astype(bool)
# Influenza and PIV combine several PCR targets; cast to bool so that every
# virus flag column has the same dtype as the single-target flags.
hospitalized['Influenza'] = (hospitalized.pcr_result___3
                             | hospitalized.pcr_result___4
                             | hospitalized.pcr_result___17).astype(bool)
hospitalized['Adeno'] = hospitalized.pcr_result___18.astype(bool)
hospitalized['PIV'] = (hospitalized.pcr_result___6
                       | hospitalized.pcr_result___7
                       | hospitalized.pcr_result___8).astype(bool)
hospitalized['No virus'] = hospitalized[list(pcr_lookup.keys())].sum(1) == 0
hospitalized['All'] = True

# In[11]:

viruses = ['RSV', 'HMPV', 'Rhino', 'Influenza', 'Adeno', 'PIV', 'No virus', 'All']

# In[12]:

non_rsv_lookup = pcr_lookup.copy()
non_rsv_lookup.pop('pcr_result___1')

# Identify individuals with coinfection

# In[13]:

# More than one positive PCR target counts as a coinfection.
hospitalized['coinfection'] = hospitalized[list(pcr_lookup.keys())].sum(1) > 1

# Virus frequency by coinfection status

# In[14]:

hospitalized[(~hospitalized.coinfection) & (hospitalized.RSV)].shape[0]/hospitalized.shape[0]

# In[15]:

# viruses[:-2] drops the 'No virus' and 'All' pseudo-categories.
virus_by_coinf = hospitalized.groupby('coinfection')[viruses[:-2]].sum()/hospitalized.shape[0]
virus_by_coinf.T.plot(kind='bar', stacked=True, grid=False,
                      color=sb.color_palette('Paired', 2)[::-1])

# In[16]:

hospitalized[viruses].sum(0)

# In[17]:

other_virus_index = (hospitalized[list(non_rsv_lookup.keys())].sum(1) > 0).astype(int)

# In[18]:

# From here on 'RSV' holds the raw PCR column rather than the boolean flag.
hospitalized['RSV'] = hospitalized['pcr_result___1']
hospitalized['non-RSV virus'] = (hospitalized['pcr_result___1']==0) & other_virus_index
hospitalized['no virus'] = (hospitalized['pcr_result___1']==0) & (other_virus_index==0)

# In[19]:

# Any '*hx' history column set, excluding the 'no_*' columns.
hospitalized["prev_cond"] = (hospitalized[[c for c in hospitalized.columns
                                           if c.endswith('hx') and not c.startswith('no_')]].sum(1) > 0)

# In[20]:

hospitalized['male'] = (hospitalized.sex=='M').astype(int)
# include_lowest=True keeps infants with age_months == 0 in the first bin;
# without it they fall outside every interval and belong to no age group.
age_groups = pd.get_dummies(pd.cut(hospitalized.age_months, [0,1,11,23], include_lowest=True))
age_groups.index = hospitalized.index
age_groups.columns = 'under 2 months', '2-11 months', '12-23 months'
hospitalized = hospitalized.join(age_groups)

# In[21]:

nationality_lookup = {1: 'Jordanian',
                      2: 'Egyptian',
                      3: 'Palestinian',
                      4: 'Iraqi',
                      5: 'Syrian',
                      6: 'Sudanese',
                      7: 'Russian',
                      8: 'Asian',
                      9: 'Other'}
hospitalized['nationality'] = hospitalized.mother_nationality.replace(nationality_lookup)
hospitalized['Jordanian'] = (hospitalized.nationality=='Jordanian').astype(int)
hospitalized['Palestinian'] = (hospitalized.nationality=='Palestinian').astype(int)

# In[22]:

# Threshold indicators; rows with a missing vitamin D value are set back to NaN
# so that "missing" is not counted as "not below threshold".
hospitalized['vitamin D < 20'] = (hospitalized.hospitalized_vitamin_d < 20).astype(int)
hospitalized.loc[hospitalized.hospitalized_vitamin_d.isnull(), 'vitamin D < 20'] = np.nan
hospitalized['vitamin D < 11'] = (hospitalized.hospitalized_vitamin_d < 11).astype(int)
hospitalized.loc[hospitalized.hospitalized_vitamin_d.isnull(), 'vitamin D < 11'] = np.nan

# In[23]:

hospitalized['any_cigarette'] = (hospitalized.cigarette_smokers > 0).astype(int)

# In[24]:

# Use "> 0" as any_cigarette does: astype(bool) would turn a missing (NaN)
# smoker count into True and inflate the exposure rate.
hospitalized['any_smoke'] = ((hospitalized.cigarette_smokers > 0)
                             | (hospitalized.nargila_smokers > 0))

# In[25]:

hospitalized.any_smoke.mean()

# In[26]:

hospitalized['premature'] = (hospitalized.gest_age < 37).astype(int)

# In[27]:

hospitalized.length_of_stay.hist(bins=20)

# Diagnosis

# In[28]:

hospitalized['ros'] = hospitalized.adm_sepsis | hospitalized.adm_febrile

# In[29]:

hospitalized['pertussis-like cough'] = hospitalized.adm_pertussis | hospitalized.adm_cough

# In[30]:

# NOTE(review): the column name is spelled 'brochopneumonia'; kept unchanged
# because other code may look it up under this exact name.
hospitalized['brochopneumonia'] = hospitalized.adm_bronchopneumo

# In[31]:

hospitalized['bronchiolitis'] = hospitalized.adm_bronchiolitis

# In[32]:

hospitalized['pneumonia'] = hospitalized.adm_pneumo

# ## Rate Estimation

# In[33]:

hospitalized.admission_date = pd.to_datetime(hospitalized.admission_date)
hospitalized.admission_date.describe()

# Age groups:
#
# - Under 2 months
# - 2-5 mo.
# - 6-11 mo.
# - Over 11 mo.

# In[34]:

age_groups = pd.get_dummies(pd.cut(hospitalized.age_months, [0,1,5,11,24], include_lowest=True))
age_groups.index = hospitalized.index
age_groups.columns = 'age_under_2', 'age_2_5', 'age_6_11', 'age_over_11'
hospitalized = hospitalized.join(age_groups)

# In[35]:

age_group_lookup = {'age_under_2': '<2',
                    'age_2_5': '2-5',
                    'age_6_11': '6-11',
                    'age_over_11': '>11'}

# In[36]:

customcmap = sb.color_palette("coolwarm", 6)

# In[37]:

# One panel per age group, showing the share of admissions positive for each virus.
fig, axes = plt.subplots(2,2, figsize=(10,6), sharey=True)
for age_group, ax in zip(age_groups.columns, np.ravel(axes)):
    age_subset = hospitalized[hospitalized[age_group].astype(bool)]
    age_virus = age_subset[viruses[:-2]].mean()
    age_virus.plot(kind='bar', grid=False, ax=ax, color=customcmap)
    ax.set_title(age_group_lookup[age_group])

# ## Population rate estimation
#
# Recode year to virus season:
#
# - 2011: March 2010 - March 2011
# - 2012: Apr 2011 - Mar 2012
# - 2013: April 2012 - Mar 2013

# In[38]:

hospitalized['virus_year'] = 2011
# Season 2012 starts strictly after 31 March 2011, so that admissions on the
# last day of March 2011 stay in season 2011 as described above (the upper
# bound already treats 31 March 2012 as the last day of season 2012).
hospitalized.loc[(hospitalized.admission_date > '2011-03-31')
                 & (hospitalized.admission_date <= '2012-03-31'), 'virus_year'] = 2012
hospitalized.loc[hospitalized.admission_date > '2012-03-31', 'virus_year'] = 2013
hospitalized.virus_year.value_counts()

# Recode zones

# In[39]:

# "<city code>, <city name> Zone <zone number>" entries separated by "|".
zone_string = (
    "1, Amman Zone 1 | 2, Abdoun Zone 1 | 3, Abu Alanda Zone 3 | "
    "4, Abu Nusair Zone 2 | 5, Airport street Zone 4 | 6, Al.Ashrafeyeh Zone 5 | "
    "7, Al.Badya Zone 31 | 8, Al.Baqa'a Zone 22 | 9, Al.Ghour Zone 32 | "
    "10, Al.Hashmi Zone 6 | 11, Al.Hezam Zone 29 | 12, Al.Hussein Camping Zone 1 | "
    "13, Al.Istiklal Zone 1 | 14, Al.Jeeza Zone 8 | 15, Al.Joufeh Zone 5 | "
    "16, Al.Karak Zone 35 | 17, Al.Lubban Zone 16 | 18, Al.Madenah Al.Reyadeyeh Zone 1 | "
    "19, Al.Mahatta Zone 6 | 20, Al.Manarah Zone 5 | 21, Al.Mareikh Zone 5 | "
    "22, Al.Marqab Zone 6 | 23, Al.Muhajreen Zone 5 | 24, Al.Musdar Zone 5 | "
    "25, Al.Musherfeh Zone 15 | 26, Al.Naser Zone 6 | 27, Al.Natheef Zone 5 | "
    "28, Al.Nuzha Zone 1 | 29, Al.Qastal Zone 10 | 30, Al.Qwesmeh Zone 5 | "
    "31, Al.Shouneh Zone 32 | 32, Al.Taj Zone 5 | 33, Al.Taybeh Zone 7 | "
    "34, Aqaba Zone 30 | 35, Arjan Zone 1 | 36, Bayader Wadi AL.Seer Zone 20 | "
    "37, Bnayyat Zone 12 | 38, D.Al.Ameer Ali Zone 13 | 39, D.Al.Aqsa Zone 9 | "
    "40, D.Haj Hasan Zone 9 | 41, Daheyet Al.Rasheed Zone 17 | 42, Daheyet al.Yasmeen Zone 1 | "
    "43, Dead Sea Zone 32 | 44, Deer.Al.Ghbar Zone 1 | 45, Down Town (Al.Balad) Zone 1 | "
    "46, Dra'a al.Qarbi Zone 1 | 47, Ein Al.Basha Zone 22 | 48, Ein Ghazal Zone 6 | "
    "49, Eskan Al.Ameer Hashem Zone 29 | 50, Eskan Al.Ameer Talal Zone 29 | 51, Etha'a wal Telvesion Zone 9 | "
    "52, Hay Al.Dabaybeh Zone 5 | 53, Hay Al.Tafayleh Zone 5 | 54, Hay Nazzal Zone 1 | "
    "55, Huttein (shneler ) Refugee camping Zone 6 | 56, Iraq Al.Ameer Zone 23 | 57, Jabal AL.Akhdar Zone 1 | "
    "58, Jabal Al.Ameer Faisal Zone 29 | 59, Jabal AL.Hadeed Zone 5 | 60, Jabal Al.Hussein Zone 1 | "
    "61, Jabal Al.Qosoor Zone 1 | 62, Jabal Amman Zone 1 | 63, Jarash Zone 36 | "
    "64, Jawa Zone 7 | 65, Juwaydeh Zone 14 | 66, Khalda Zone 18 | "
    "67, Khreibt Al.Souk Zone 7 | 68, Ma'an Zone 31 | 69, Madaba Zone 34 | "
    "70, Mafraq Zone 33 | 71, Marj Al.Hamam Zone 19 | 72, Marka Zone 6 | "
    "73, Muqableen Zone 9 | 74, Muwaqqar Zone 28 | 75, Nadi Al.Sebaq Zone 6 | "
    "76, Naur Zone 19 | 77, Petra Zone 30 | 78, Qatranah Zone 30 | "
    "79, Raghadan Zone 1 | 80, Ras Al.Ein Zone 1 | 81, Rusayfah Zone 29 | "
    "82, Saffout Zone 22 | 83, Sahab Zone 11 | 84, Salheyet Al.Abed Zone 6 | "
    "85, Shafa Badran Zone 25 | 86, Sharq Al.Awsat Zone 11 | 87, Shemasani Zone 1 | "
    "88, Street 30 Zone 5 | 89, Summaya Street Zone 5 | 90, Suweileh Zone 27 | "
    "91, Tabarbour Zone 21 | 92, Tla'a al Ali Zone 17 | 93, Um Al.Heran Zone 11 | "
    "94, Um Al.Summaq Zone 17 | 95, Um Nuwwara and Adan Zone 5 | 96, Um Uthayna Zone 1 | "
    "97, Wadi Abdoun Zone 1 | 98, Wadi AL.Haddadeh Zone 1 | 99, Wadi AL.Hajar Zone 29 | "
    "100, Wadi Al.Remam Zone 5 | 101, Wadi AL.Seer Zone 20 | 102, Wadi Saqra Zone 1 | "
    "103, Wehdat Zone 11 | 104, Yadoudeh Zone 13 | 105, Yajouz Zone 26 | "
    "106, Zarqa Zone 29 | 107, Zezya Zone 24"
)

# In[40]:

# Each entry becomes [city code, "<city name> Zone <zone number>"].
zones = [z.strip().split(',') for z in zone_string.split('|')]

# In[41]:

# Map city code -> zone number (the last space-separated token of each entry).
zone_dict = {int(n):int(s.strip().split(' ')[-1]) for n,s in zones}

# In[42]:

hospitalized['zone'] = hospitalized.city_zone.replace(zone_dict)

# Define Amman zone

# In[43]:

hospitalized['amman_zone'] = hospitalized.zone<28

# Hospitalized in Amman

# In[44]:

hospitalized_amman = hospitalized[hospitalized.amman_zone]

# In[45]:

hospitalized_amman.shape

# In[46]:

hospitalized_amman.index.is_unique

# In[47]:

# Monthly admissions per age group. resample(...).sum() replaces the `how=`
# keyword, which current pandas no longer accepts.
hosp_age_counts = (hospitalized_amman.groupby('admission_date')[list(age_groups.columns)]
                   .sum().resample('M').sum().values)

# In[48]:

pre_2013 = hospitalized_amman.admission_date < datetime(2013, 1, 1)

# In[49]:

age_counts = (hospitalized_amman[pre_2013].groupby('admission_date')[list(age_groups.columns)]
              .sum().resample('M').sum())
age_counts

# Rates and demographics (via [World Bank](http://databank.worldbank.org)) and 2004 census data (via Jordan [Department of Statistics](http://www.dos.gov.jo/dos_home_e/main/population/census2004/group3/table_31.pdf)).

# In[50]:

# Jordan population, 2008-2012
population = 5786000, 5915000, 6046000, 6181000, 6318000

# Interpolated population by gender and age, 2010-2012
female_0 = 93649, 95739, 96486
male_0 = 98435, 100525, 101160
female_1 = 87941, 93511, 93067
male_1 = 92548, 98306, 97763

kids_0 = np.array(female_0) + np.array(male_0)
kids_1 = np.array(female_1) + np.array(male_1)
kids = kids_0 + kids_1
# Children under one year are split evenly between the two half-year groups.
kids_under6mo = kids_6to12mo = kids_0/2.
# Proportion in Amman
amman_urban_2004 = 1784502.
jordan_2004 = 5103639.
amman_prop = amman_urban_2004 / jordan_2004

# Birth rates (per 1000)
birth_rate = 29.665, 29.322, 28.869, 28.317, 27.699

# Neonatal mortality (per 1000)
neonatal_mort = 12.7, 12.4, 12.1, 11.8, 11.5

# Infant mortality (per 1000)
infant_mort = 18.4, 17.9, 17.3, 16.8, 16.4

# In[51]:

# Births for the last three years of the population series.
births = np.array(population[-3:])/1000. * birth_rate[-3:]
births

# In[52]:

births_6m = births/2.
births_6m

# In[53]:

deaths_6m = births_6m/1000. * infant_mort[-3:]
deaths_6m

# In[54]:

amman_prop

# In[55]:

(births_6m - deaths_6m)*amman_prop

# In[56]:

# Rows: under 6 mo., 6-12 mo., 1-year-olds, all under 2; columns: years.
n = np.array((kids_under6mo, kids_6to12mo, kids_1, kids))
n

# In[57]:

kids_0

# In[58]:

kids_1/kids_0

# In[59]:

amman_prop

# In[60]:

# Scale the national denominators by the Amman share of the 2004 census.
n_amman = np.floor(n*amman_prop).astype(int)

# In[61]:

n_amman

# Monthly admissions in Amman

# In[62]:

# resample(...).sum() replaces the `how=` keyword, which current pandas no
# longer accepts.
admissions_by_month = (hospitalized_amman.groupby('admission_date')['child_name']
                       .count().resample('1M').sum())

# In[63]:

admissions_by_month.head()

# In[64]:

# Dict to return pcr_result variable corresponding to virus
virus_lookup = {virus_name: pcr_column for pcr_column, virus_name in pcr_lookup.items()}

# In[65]:

# Enrolling 5 days per week
p_enroll = 5./7

# In[66]:

rsv_subset = hospitalized_amman[hospitalized_amman.RSV==1]
rsv_subset.shape

# 'age_under_2', 'age_2_5', 'age_6_11', 'age_over_11'

# In[67]:

viruses

# In[68]:

# Build the mask from hospitalized_amman itself: a mask taken from the full
# `hospitalized` frame has a different index than the frame being filtered.
virus_subset = hospitalized_amman[hospitalized_amman['RSV']==1].copy()
virus_subset['admission_year'] = virus_subset.admission_date.dt.year
_under_6 = virus_subset[virus_subset.age_under_2.astype(bool)
                        | virus_subset.age_2_5.astype(bool)].groupby('admission_year')['child_name'].count()

# In[69]:

_6_11 = virus_subset[virus_subset.age_6_11.astype(bool)].groupby('admission_year')['child_name'].count()

# In[70]:

_over_11 = virus_subset[virus_subset.age_over_11.astype(bool)].groupby('admission_year')['child_name'].count()

# In[71]:

_all = virus_subset.groupby('admission_year')['child_name'].count()

# In[72]:
# Monthly RSV record counts. pd.Grouper(freq=...) replaces pd.TimeGrouper,
# which current pandas no longer provides.
virus_subset.set_index(virus_subset.admission_date).groupby(pd.Grouper(freq='M')).count()['mother_name']

# In[73]:

virus_df = pd.concat([_under_6, _6_11, _over_11, _all], axis=1)
virus_df.columns = ('under 6 mo.', '6-11 mo.', '12-23 mo.', 'all under 2 yr.')
virus_df

# In[74]:

virus_df.values.ravel()

# Use proportion in Amman to calculate proportion of kids in each age group and year in Amman

# In[79]:

def rate_model(virus):
    """Build the PyMC nodes for the population rate model of one virus.

    Parameters
    ----------
    virus : str
        Name of the column of ``hospitalized_amman`` that equals 1 (or True)
        for children positive for the virus, e.g. ``'RSV'``.

    Returns
    -------
    dict
        The function's local namespace. It is passed straight to
        ``pm.MCMC``, and callers also read ``per_1000`` and ``virus_df``
        from the resulting sampler.

    Raises
    ------
    ValueError
        If the season x age-group count table does not have as many cells as
        the ``n_amman`` denominator array.
    """
    # Extract data subset for passed virus
    virus_subset = hospitalized_amman[hospitalized_amman[virus]==1].copy()

    # Create data frame of age x year counts
    u6 = virus_subset.age_under_2.astype(bool) | virus_subset.age_2_5.astype(bool)
    _under_6 = virus_subset[u6].groupby('virus_year')['child_name'].count()
    _6_11 = virus_subset[virus_subset.age_6_11.astype(bool)].groupby('virus_year')['child_name'].count()
    _over_11 = virus_subset[virus_subset.age_over_11.astype(bool)].groupby('virus_year')['child_name'].count()
    _all = virus_subset.groupby('virus_year')['child_name'].count()
    virus_df = pd.concat([_under_6, _6_11, _over_11, _all], axis=1)
    virus_df.columns = ('under 6 mo.', '6-11 mo.', '12-23 mo.', 'all ages')
    # The column-wise concat leaves NaN where an age group had no cases in a
    # season, and drops seasons with no cases at all. Restore every season of
    # the Amman cohort and record "no cases" as an observed count of zero.
    virus_df = virus_df.reindex(sorted(hospitalized_amman.virus_year.unique())).fillna(0).astype(int)
    if virus_df.size != n_amman.size:
        raise ValueError('Expected %i season x age-group counts for %s, got %i'
                         % (n_amman.size, virus, virus_df.size))

    # Al Bashir hospital market share
    market_share = pm.Uniform('market_share', 0.5, 0.6)

    # Prior probability
    prev_virus = [pm.Beta('prev_virus_%i' %i, 1, 5) for i in range(virus_df.size)]
    per_1000 = pm.Lambda('per_1000', lambda p=prev_virus: np.array(p)*1000)

    # Correct for 5 days of enrollment per week
    p_hosp = market_share * p_enroll

    # Children hospitalized in Amman; n_amman is transposed so that its
    # flattened order (season-major) matches virus_df.values.ravel().
    n_hosp = pm.Binomial('n_hosp', n_amman.T.ravel(), p_hosp, value=n_amman.T.ravel()*0.2)

    # Likelihood for number with the virus in hospital (assumes Pr(hosp | virus) = 1)
    y_hosp = pm.Binomial('y_hosp', n_hosp, prev_virus, value=virus_df.values.ravel(), observed=True)

    return locals()

# ### Influenza rates

# In[80]:

M = pm.MCMC(rate_model('Influenza'))

# In[81]:

M.sample(100000, 90000)

# In[82]:

M.sample(100000, 90000)

# In[94]:

prevalence_labels = ['%s %s' % (year, age)
                     for year in ('2011', '2012', '2013')
                     for age in ('under 6 mo.', '6-11 mo.', '12-23 mo.', 'all ages')]

# In[95]:

pm.Matplot.summary_plot(M.per_1000, custom_labels=prevalence_labels, main='Influenza (per 1000)')

# In[96]:

M.per_1000.summary()

# In[116]:

M.write_csv('Influenza_per_1000', variables=['per_1000'])

# ### HMPV rates

# In[83]:

M_hmpv = pm.MCMC(rate_model('HMPV'))
M_hmpv.sample(100000, 90000)

# In[84]:

M_hmpv.sample(100000, 90000)

# In[97]:

pm.Matplot.summary_plot(M_hmpv.per_1000, custom_labels=prevalence_labels, main='HMPV (per 1000)')

# In[98]:

M_hmpv.per_1000.summary()

# In[115]:

M_hmpv.write_csv('HMPV_per_1000', variables=['per_1000'])

# ### Rhino rates

# In[104]:

M_rhino = pm.MCMC(rate_model('Rhino'))
M_rhino.sample(150000, 140000)

# In[105]:

M_rhino.sample(150000, 140000)

# In[106]:

pm.Matplot.summary_plot(M_rhino.per_1000, custom_labels=prevalence_labels, main='Rhino (per 1000)')

# In[108]:

M_rhino.per_1000.summary()

# In[114]:

M_rhino.write_csv('Rhino_per_1000', variables=['per_1000'])

# In[96]:

# The data
M_rhino.virus_df

# ### RSV rates

# In[87]:

M_rsv = pm.MCMC(rate_model('RSV'))
M_rsv.sample(100000, 90000)

# In[88]:

M_rsv.sample(100000, 90000)

# In[102]:

pm.Matplot.summary_plot(M_rsv.per_1000, custom_labels=prevalence_labels, main='RSV (per 1000)')

# In[100]:

M_rsv.per_1000.summary()

# In[117]:

M_rsv.write_csv('RSV_per_1000', variables=['per_1000'])

# In[101]:

M_rsv.virus_df

# ### Adeno rates

# In[89]:

M_adeno = pm.MCMC(rate_model('Adeno'))
M_adeno.sample(100000, 90000)

# In[90]:

M_adeno.sample(100000, 90000)

# In[119]:

pm.Matplot.summary_plot(M_adeno.per_1000, custom_labels=prevalence_labels, main='Adeno (per 1000)')

# In[120]:

M_adeno.per_1000.summary()

# In[118]:

M_adeno.write_csv('Adeno_per_1000', variables=['per_1000'])

# In[106]:

M_adeno.virus_df

# ### PIV rates

# In[91]:

M_piv = pm.MCMC(rate_model('PIV'))
M_piv.sample(100000, 90000)

# In[92]:

M_piv.sample(100000, 90000)

# In[121]:

pm.Matplot.summary_plot(M_piv.per_1000, custom_labels=prevalence_labels, main='PIV (per 1000)')

# In[110]:
M_piv.per_1000.summary()

# In[122]:

M_piv.write_csv('PIV_per_1000', variables=['per_1000'])

# In[111]:

M_piv.virus_df

# ## Hospitalization
#
# Import hospitalization data

# In[123]:

spreadsheets = get_ipython().getoutput('ls data/Al-Bashir*')

data = []
col_names = 'date', 'admissions', 'admissions_2', 'resp_admissions', 'resp_admissions_2'
for spreadsheet in spreadsheets:
    this_file = pd.ExcelFile(spreadsheet)

    # Need to fix data entry error in years > 2011.
    # The year is read from the four characters before the first '.' of the path.
    dot = spreadsheet.find('.')
    fix_columns = int(spreadsheet[dot-4:dot]) > 2011

    for name in this_file.sheet_names:
        d = this_file.parse(name)
        d.columns = col_names
        if fix_columns:
            # In these files the first respiratory column is moved to
            # 'resp_admissions_2' and 'resp_admissions' becomes the sum of both.
            total = d['resp_admissions'] + d['resp_admissions_2']
            d['resp_admissions_2'] = d['resp_admissions']
            d['resp_admissions'] = total
        data.append(d)

hospitalizations = pd.concat(data).set_index('date')

# In[124]:

hospitalizations.head()

# In[125]:

hospitalizations['virus_year'] = 2011
# Season 2012 starts strictly after 31 March 2011, mirroring the upper bound,
# which treats 31 March 2012 as the last day of season 2012.
hospitalizations.loc[(hospitalizations.index > datetime(2011, 3, 31))
                     & (hospitalizations.index <= datetime(2012, 3, 31)), 'virus_year'] = 2012
hospitalizations.loc[hospitalizations.index > datetime(2012, 3, 31), 'virus_year'] = 2013

# In[126]:

# Select the column before summing so only 'admissions_2' is aggregated.
under_2_hosp = hospitalizations.groupby('virus_year')['admissions_2'].sum()
under_2_hosp

# In[127]:

rsv_by_year = rsv_subset.groupby('virus_year')[['age_under_2', 'age_2_5', 'age_6_11', 'age_over_11']]
rsv_counts = rsv_by_year.sum()

# In[128]:

rsv_counts

# In[129]:

under2_rsv = pd.concat([rsv_counts, under_2_hosp], axis=1).dropna()
under2_rsv

# In[130]:

def extract_virus(virus):
    """Return per-season age-group counts for one virus next to hospital admissions.

    Parameters
    ----------
    virus : str
        Name of the column of ``hospitalized`` that equals 1 (or True) for
        positive children.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``virus_year``, with the four age-group count columns
        followed by ``admissions_2`` taken from ``under_2_hosp``. Seasons
        missing from either source are dropped.
    """
    virus_subset = hospitalized[hospitalized[virus]==1]
    virus_by_year = virus_subset.groupby('virus_year')[['age_under_2', 'age_2_5', 'age_6_11', 'age_over_11']]
    virus_counts = virus_by_year.sum()
    return pd.concat([virus_counts, under_2_hosp], axis=1).dropna()

# In[131]:

def hosp_subset(virus):
    """Build the PyMC nodes for the hospital-based rate model of one virus.

    Parameters
    ----------
    virus : str
        Virus column name, forwarded to ``extract_virus``.

    Returns
    -------
    dict
        The function's local namespace. It is passed straight to
        ``pm.MCMC``, and callers read ``p_age`` and ``p_virus`` from the
        resulting sampler.
    """
    data_subset = extract_virus(virus)
    admissions = data_subset.pop('admissions_2').values
    data_subset = data_subset.values

    # Rows of data_subset are virus seasons (the variable keeps its historical
    # name n_months); columns are age groups.
    n_months, n_ages = data_subset.shape

    # Estimate age distribution of admissions
    p_age = pm.Dirichlet('p_age', np.ones(n_ages), value=data_subset.sum(0)[:-1]/data_subset.sum())
    counts = pm.Multinomial('counts', data_subset.sum(1), p_age, value=data_subset, observed=True)

    # Estimate denominators
    n_hosp = []
    for i,ni in enumerate(admissions):
        d = data_subset[i]
        # Start from the observed age split when there are more than 10 cases,
        # otherwise from an even split; the last group takes the remainder so
        # the initial counts add up to ni.
        n_init = (ni*(d/d.sum() if (d.sum()>10) else np.ones(n_ages)/n_ages)).astype(int)[:-1]
        n_init = np.append(n_init, ni - n_init.sum())
        n_hosp.append(pm.Multinomial('n_hosp_{}'.format(i), ni, p_age, value=n_init))
    n_hosp_t = pm.Lambda('n_hosp_t', lambda x=n_hosp: np.array([[xij for xij in xi] for xi in x]).T)

    # Virus rates by age and time
    mu_alpha = pm.Normal('mu_alpha', 0, 0.01, value=[0]*n_ages)
    sigma_alpha = pm.Uniform('sigma_alpha', 0, 100, value=[10]*n_ages)
    rho = pm.Uniform('rho', -1, 1, value=0)

    # `i` is bound as a default argument in the lambdas below. A plain closure
    # would look `i` up when the node is re-evaluated during sampling, by which
    # time it holds the last age index, so every age group would share the
    # last group's mean and variance.
    mu = [pm.Lambda('mu_{}'.format(i),
                    lambda mu=mu_alpha, i=i: np.array([mu[i]]*n_months))
          for i in range(n_ages)]
    # Tridiagonal covariance: s**2 on the diagonal, rho*s**2 between
    # neighbouring seasons.
    off_diag = np.eye(n_months, k=1)
    Sigma = [pm.Lambda('Sigma_{}'.format(i),
                       lambda s=sigma_alpha, r=rho, i=i:
                           (np.eye(n_months) + off_diag*r + off_diag.T*r)*(s[i]**2))
             for i in range(n_ages)]

    alpha = [pm.MvNormalCov('alpha_{}'.format(i), mu[i], Sigma[i], value=[0]*n_months)
             for i in range(n_ages)]
    p_virus = [pm.Lambda('p_virus_{}'.format(i), lambda a=a: pm.invlogit(a))
               for i,a in enumerate(alpha)]

    # Viral rate likelihood
    @pm.observed
    def rate_like(value=data_subset.T, p=p_virus, n=n_hosp_t):
        return np.sum([pm.binomial_like(value[i], n[i], p[i]) for i in range(n_ages)])

    return(locals())

# In[132]:

M = pm.MCMC(hosp_subset('RSV'))

# In[133]:

M.sample(20000, 10000)

# In[134]:

pm.Matplot.summary_plot(M.p_age, custom_labels=['age_under_2', 'age_2_5', 'age_6_11'])

# In[135]:

date_labels = ['2011', '2012', '2013']
pm.Matplot.summary_plot(M.p_virus[0], custom_labels=date_labels, main='RSV rates')

# In[136]:

import seaborn as sb

# In[137]:

age_group_labels = ['Under 2 mo.', '2-5 mo.', '6-11 mo.', '12-23 mo.']


def _rate_samples(model):
    """Stack the sampled p_virus traces of `model` into a long table.

    Returns one row per posterior sample with columns 'rate', 'age' and 'year'.
    """
    return pd.concat([pd.DataFrame({'rate': p.trace().T[t],
                                    'age': age_group_labels[i],
                                    'year': year})
                      for i, p in enumerate(model.p_virus)
                      for t, year in enumerate(date_labels)])

# In[138]:

rsv_rates = _rate_samples(M)

# In[139]:

sb.factorplot('year', 'rate', 'age', rsv_rates, kind='box', aspect=1.75)

# In[140]:

M_rhino = pm.MCMC(hosp_subset('Rhino'))
M_rhino.sample(20000, 10000)

# In[141]:

rhino_rates = _rate_samples(M_rhino)

# In[142]:

sb.set(style="white", palette="muted")
fg = sb.factorplot('year', 'rate', 'age', rhino_rates, kind='box', aspect=1.75)

# ### Production runs

# In[143]:

# Fit every virus category, then save the box plot and the p_virus summaries
# under the virus name.
for virus in viruses:
    print('\nRunning', virus)
    M = pm.MCMC(hosp_subset(virus))
    M.sample(50000, 40000)

    sb.set(style="white", palette="muted")
    virus_rates = _rate_samples(M)
    fg = sb.factorplot('year', 'rate', 'age', virus_rates, kind='box', aspect=1.75)
    fg.savefig(virus)
    M.write_csv(virus, variables=[x.__name__ for x in M.p_virus])