Predictions of COVID-19 Growth Rates Using Bayesian Modeling
Note: This dashboard contains the results of a predictive model. The author has tried to make it as accurate as possible. But the COVID-19 situation is changing quickly, and these models inevitably include some level of speculation.
#hide
from pathlib import Path
loadpy = Path('load_covid_data.py')
if not loadpy.exists():
! wget https://raw.githubusercontent.com/github/covid19-dashboard/master/_notebooks/load_covid_data.py
#hide
%matplotlib inline
import numpy as np
from IPython.display import display, Markdown
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import seaborn as sns
import arviz as az
import pymc3 as pm
import altair as alt
import load_covid_data
sns.set_context('talk')
plt.style.use('seaborn-whitegrid')
## Set this to true to see legacy charts
debug=False
WARNING (theano.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
#hide
df = load_covid_data.load_data(drop_states=True, filter_n_days_100=2)
# We only have data for China after they already had a significant number of cases.
# They also are not well modeled by the exponential, so we drop them here for simplicity.
df = df.loc[lambda x: x.country != 'China (total)']
countries = df.country.unique()
n_countries = len(countries)
df = df.loc[lambda x: (x.days_since_100 >= 0)]
annotate_kwargs = dict(
s='Based on COVID Data Repository by Johns Hopkins CSSE ({})\nBy Thomas Wiecki'.format(df.index.max().strftime('%B %d, %Y')),
xy=(0.05, 0.01), xycoords='figure fraction', fontsize=10)
These are the countries included in the model:
#hide_input
', '.join(sorted(df.country.unique().tolist()))
'Albania, Algeria, Andorra, Argentina, Armenia, Australia (total), Austria, Bahrain, Belgium, Bosnia and Herzegovina, Brazil, Brunei, Bulgaria, Burkina Faso, Canada (total), Chile, Colombia, Costa Rica, Croatia, Cyprus, Czechia, Denmark, Denmark (total), Diamond Princess, Dominican Republic, Ecuador, Egypt, Estonia, Finland, France, France (total), Germany, Greece, Hong Kong, Hungary, Iceland, India, Indonesia, Iran, Iraq, Ireland, Israel, Italy, Japan, Jordan, Korea, South, Kuwait, Latvia, Lebanon, Lithuania, Luxembourg, Malaysia, Malta, Mexico, Moldova, Morocco, Netherlands, New Zealand, North Macedonia, Norway, Pakistan, Panama, Peru, Philippines, Poland, Portugal, Qatar, Romania, Russia, San Marino, Saudi Arabia, Serbia, Singapore, Slovakia, Slovenia, South Africa, Spain, Sri Lanka, Sweden, Switzerland, Taiwan*, Thailand, Tunisia, Turkey, US, Ukraine, United Arab Emirates, United Kingdom, United Kingdom (total), Uruguay, Vietnam'
#hide
#####################################
##### This Cell Runs The Model ######
#####################################
with pm.Model() as model:
############
# Intercept
# Group mean
a_grp = pm.Normal('a_grp', 100, 50)
# Group variance
a_grp_sigma = pm.HalfNormal('a_grp_sigma', 50)
# Individual intercepts
a_ind = pm.Normal('a_ind',
mu=a_grp, sigma=a_grp_sigma,
shape=n_countries)
########
# Slope
# Group mean
b_grp = pm.Normal('b_grp', 1.33, .5)
# Group variance
b_grp_sigma = pm.HalfNormal('b_grp_sigma', .5)
# Individual slopes
b_ind = pm.Normal('b_ind',
mu=b_grp, sigma=b_grp_sigma,
shape=n_countries)
# Error
sigma = pm.HalfNormal('sigma', 500., shape=n_countries)
# Create likelihood for each country
for i, country in enumerate(countries):
df_country = df.loc[lambda x: (x.country == country)]
# By using pm.Data we can change these values after sampling.
# This allows us to extend x into the future so we can get
# forecasts by sampling from the posterior predictive
x = pm.Data(country + "x_data",
df_country.days_since_100.values)
confirmed = pm.Data(country + "y_data",
df_country.confirmed.astype('float64').values)
# Likelihood
pm.NegativeBinomial(
country,
(a_ind[i] * b_ind[i] ** x), # Exponential regression
sigma[i],
observed=confirmed)
#hide
with model:
# Sample posterior
trace = pm.sample(tune=1500, chains=1, cores=2, target_accept=.9)
# Update data so that we get predictions into the future
for country in countries:
df_country = df.loc[lambda x: (x.country == country)]
x_data = np.arange(0, 30)
y_data = np.array([np.nan] * len(x_data))
pm.set_data({country + "x_data": x_data})
pm.set_data({country + "y_data": y_data})
# Sample posterior predictive
post_pred = pm.sample_posterior_predictive(trace, samples=80)
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... Sequential sampling (1 chains in 1 job) NUTS: [sigma, b_ind, b_grp_sigma, b_grp, a_ind, a_grp_sigma, a_grp] Sampling chain 0, 0 divergences: 100%|██████████| 2000/2000 [15:30<00:00, 2.15it/s] Only one chain was sampled, this makes it impossible to run some convergence checks /opt/hostedtoolcache/Python/3.6.10/x64/lib/python3.6/site-packages/pymc3/sampling.py:1247: UserWarning: samples parameter is smaller than nchains times ndraws, some draws and/or chains may not be represented in the returned posterior predictive sample "samples parameter is smaller than nchains times ndraws, some draws " 100%|██████████| 80/80 [01:24<00:00, 1.06s/it]
Select a country from the drop down list below to toggle the visualization.
#hide
##############################################
#### Pre processing of Data For Altair Viz ###
##############################################
# Flatten predictions & target for each country into a pandas DataFrame
prediction_dfs_list = []
for country in post_pred:
arr = post_pred[country]
preds = arr.flatten().tolist() # get predictions in a flattened array
pred_idx = np.indices(arr.shape)[0].flatten().tolist() # prediction for model (there are many per country, thes are the grey lines)
days_since = np.indices(arr.shape)[1].flatten().tolist() # days since 100 cases
pred_df = pd.DataFrame({'country': country,
'predictions': preds,
'pred_idx': pred_idx,
'days_since_100': days_since}
)
prediction_dfs_list.append(pred_df)
predictionsDF = pd.concat(prediction_dfs_list)
# Compute the maximum value to plot on the y-axis as 15x the last confirmed case
ylims = pd.DataFrame(df.groupby('country').last().confirmed * 15).reset_index()
ylims.columns = ['country', 'ylim']
# Filter out any predictions exceed the y-axis limit
predictionsDF_filtered = (predictionsDF.merge(ylims, on='country', how='left')
.loc[lambda x: x.predictions <= x.ylim])
# Compute a 33% daily growth rate (dashed line) as a reference for visualizations
first_case_count = df.groupby('country').first().confirmed.reset_index()
date_anchor = predictionsDF_filtered[['country', 'days_since_100']].drop_duplicates()
max_pred = predictionsDF_filtered.groupby('country').max()[['predictions']].reset_index()
benchmark = (date_anchor
.merge(first_case_count, on='country', how='left')
.merge(max_pred, on='country', how='left')
)
benchmark['benchmark'] = benchmark.apply(lambda x: x.confirmed * (1.3**(x.days_since_100)),
axis=1)
benchmarkDF_filtered = benchmark.loc[lambda x: x.benchmark <= x.predictions]
# Compute the last known total confirmed case, which is the black dot at the end of the red line in the viz
lastpointDF = df.groupby('country').last().reset_index()
# DataFrame of Chart Titles by country. This a enables a hack to allow Altiar to switch values
titleDF = lastpointDF[['country']]
titleDF['title'] = titleDF.apply(lambda x: x.country + ': Actual vs. Predicted Growth',
axis=1)
/opt/hostedtoolcache/Python/3.6.10/x64/lib/python3.6/site-packages/ipykernel_launcher.py:55: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame. Try using .loc[row_indexer,col_indexer] = value instead See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
#hide_input
##################################
#### Construct The Altair Viz ####
##################################
alt.data_transformers.disable_max_rows()
selectCountry = alt.selection_single(
name='Select',
fields=['country'],
init={'country': 'US'},
bind=alt.binding_select(options=sorted(countries.tolist()))
)
##### Model Predictions (Grey) #####
width = 275
height= 250
pred = (alt.Chart(predictionsDF_filtered)
.mark_line(opacity=.15)
.encode(x=alt.X('days_since_100:Q', axis=alt.Axis(title='Days since 100th confirmed case')),
y=alt.Y('predictions:Q',
axis=alt.Axis(title='Confirmed cases')),
color=alt.Color('pred_idx:Q', legend=None, scale=None),)
.transform_filter(selectCountry)
).properties(
width=width,
height=height
)
predlog = (alt.Chart(predictionsDF_filtered)
.mark_line(opacity=.15)
.encode(x=alt.X('days_since_100:Q', axis=alt.Axis(title='Days since 100th confirmed case')),
y=alt.Y('predictions:Q',
axis=alt.Axis(title=None),
scale=alt.Scale(type='log', base=10)),
color=alt.Color('pred_idx:Q', legend=None, scale=None),)
.transform_filter(selectCountry)
).properties(
width=width,
height=height
)
##### Mark The Last Case Count #####
# Point
last_point = (alt.Chart(lastpointDF)
.mark_circle(color="black", size=40)
.encode(x='days_since_100:Q',
y='confirmed:Q')
.transform_filter(selectCountry)
)
# Label
last_point_label = (alt.Chart(lastpointDF)
.mark_text(align='right', dx=-10, dy=-15, fontSize=15)
.encode(x='days_since_100:Q',
y='confirmed:Q',
text='confirmed')
.transform_filter(selectCountry)
)
##### Place 133% Dotted Line Reference On Graph #####
guide = (alt.Chart(benchmarkDF_filtered)
.mark_line(color='black', opacity=.5, strokeDash=[3,3])
.encode(x='days_since_100:Q',
y='benchmark:Q',
)
.transform_filter(selectCountry)
)
##### Dynamic Chart Title
title_main = alt.Chart(titleDF).mark_text(dy=-15, dx=325, size=20).encode(
text='title:N'
).transform_filter(selectCountry)
title_linear = (alt.Chart(alt.Data(values=[{'title': 'Y axis is on a linear scale'}]))
.mark_text(dy=-150, size=15)
.encode(text='title:N')
)
title_log = (alt.Chart(alt.Data(values=[{'title': 'Y axis is on a log scale'}]))
.mark_text(dy=-150, size=15)
.encode(text='title:N')
)
###### Legend (Hacked)
source = pd.DataFrame.from_records([{"img": "https://covid19dashboards.com/images/covid-bayes-growth-legend.png"}])
legend = (alt.Chart(source)
.mark_image(dy=-150,
width=200,
height=150)
.encode(url='img')
)
##### Actual Cases (Red) #####
actual = (alt.Chart(df).mark_line(color="red")
.encode(x='days_since_100:Q',
y='confirmed:Q')
.transform_filter(selectCountry)
)
annotations = last_point + last_point_label + guide + actual
linear_chart = pred.add_selection(selectCountry) + annotations + title_linear
log_chart = predlog + annotations + title_log + title_log
##### Layer All Charts Together And Configure Formatting #####
(
((title_main + legend) & ( linear_chart | log_chart ))
.configure_title(fontSize=20)
.configure_axis(labelFontSize=15,titleFontSize=18, grid=False)
)