import pandas as pd
import numpy as np
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from sklearn.linear_model import LinearRegression
from scipy import stats
import statsmodels.api as sm
import pylab
# from google.colab import files
# from io import StringIO
# uploaded = files.upload()
from io import BytesIO
from zipfile import ZipFile
import pandas
import requests
# US totals CSV published with the ProbProg-COVID-19 project.
url = 'https://raw.githubusercontent.com/assemzh/ProbProg-COVID-19/master/us_total.csv'
data = pd.read_csv(url)
# Parse the "Date" column into pandas Timestamps for date comparisons/plots.
data["Date"] = pd.to_datetime(data["Date"])
# for fancy python printing
from IPython.display import Markdown, display
def printmd(string):
    """Render ``string`` as Markdown in the notebook output."""
    rendered = Markdown(string)
    display(rendered)
import warnings
# Hide every warning from here on (issued before matplotlib is imported, as before).
warnings.filterwarnings(action='ignore')
import matplotlib as mpl
# High-resolution figure rendering.
mpl.rcParams.update({'figure.dpi': 250})
/usr/local/lib/python3.7/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead. import pandas.util.testing as tm
# Every row of this CSV is US data; label it so rows can be selected by country.
data["Country/Region"] = "US"
data
Date | Confirmed | Deaths | Recovered | Active | New cases | New deaths | New recovered | Country/Region | |
---|---|---|---|---|---|---|---|---|---|
0 | 2020-03-22 | 33918.0 | 435.0 | 0.0 | 0.0 | 9836.0 | 142.0 | 0.0 | US |
1 | 2020-03-23 | 43754.0 | 577.0 | 0.0 | 0.0 | 10122.0 | 140.0 | 348.0 | US |
2 | 2020-03-24 | 53876.0 | 717.0 | 348.0 | 0.0 | 12085.0 | 239.0 | 13.0 | US |
3 | 2020-03-25 | 65961.0 | 956.0 | 361.0 | 0.0 | 18012.0 | 269.0 | 320.0 | US |
4 | 2020-03-26 | 83973.0 | 1225.0 | 681.0 | 0.0 | 17894.0 | 378.0 | 188.0 | US |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
349 | 2021-03-06 | 28952970.0 | 524362.0 | 0.0 | 28429067.0 | 40903.0 | 671.0 | 0.0 | US |
350 | 2021-03-07 | 28993873.0 | 525033.0 | 0.0 | 0.0 | 44758.0 | 719.0 | 0.0 | US |
351 | 2021-03-08 | 29038631.0 | 525752.0 | 0.0 | 0.0 | 57417.0 | 1947.0 | 0.0 | US |
352 | 2021-03-09 | 29096048.0 | 527699.0 | 0.0 | 0.0 | 57667.0 | 1494.0 | 0.0 | US |
353 | 2021-03-10 | 29153715.0 | 529193.0 | 0.0 | 0.0 | NaN | NaN | NaN | US |
354 rows × 9 columns
def create_country(country, end_date, state=False,
                   save_path='/content/sample_data/us_daily.pdf'):
    """Build and plot the confirmed / daily-confirmed series for one region.

    Rows of the module-level ``data`` DataFrame are selected by
    "Country/Region" (or "Province/State" when ``state`` is true), summed per
    date, sorted by date, truncated at ``end_date`` and given a
    ``daily_confirmed`` column. A line plot of the daily counts with a dashed
    red marker at 2020-12-14 ("Start of Vaccination") is drawn, ``df.tail()``
    is printed and the figure is saved to ``save_path``.

    Args:
        country: value matched against the selected region column.
        end_date: last date kept (inclusive), e.g. "2021-03-10".
        state: match on "Province/State" instead of "Country/Region".
            NOTE(review): the US CSV loaded in this notebook has no
            "Province/State" column, so ``state=True`` would raise a KeyError
            on that data.
        save_path: file the figure is written to; ``None`` skips saving.

    Returns:
        DataFrame with columns country, date, confirmed, deaths, recovered,
        daily_confirmed, sorted by date.
    """
    region_col = "Province/State" if state else "Country/Region"
    df = data.loc[data[region_col] == country,
                  [region_col, "Date", "Confirmed", "Deaths", "Recovered"]]
    df.columns = ["country", "date", "confirmed", "deaths", "recovered"]
    # Sum per (country, date): a country can have several rows (e.g. cities)
    # per day. The columns must be selected with a list -- selecting with a
    # bare tuple was deprecated and raises in pandas >= 2.0.
    df = (df.groupby(["country", "date"])[["confirmed", "deaths", "recovered"]]
            .sum()
            .reset_index())
    df = df.sort_values(by="date")
    # .copy(): the column added below must go into an independent frame rather
    # than into a filtered slice of the grouped result.
    df = df[df.date <= end_date].copy()
    # New cases per day. The first row has no previous day, so it keeps its
    # cumulative count (difference against 0).
    df["daily_confirmed"] = df.confirmed - df.confirmed.shift(1, fill_value=0)

    _, axis = plt.subplots(nrows=1, ncols=1, figsize=(7, 6))
    sns.lineplot(x=df.date, y=df.daily_confirmed, ax=axis)
    axis.axvline(pd.to_datetime('2020-12-14'),
                 linestyle='--', linewidth=1.5,
                 label="Start of Vaccination: Dec 14, 2020",
                 color="red")
    axis.xaxis.get_label().set_fontsize(16)
    axis.yaxis.get_label().set_fontsize(16)
    axis.title.set_fontsize(20)
    axis.tick_params(labelsize=16)
    # '%-d' (day without zero padding) is not supported by every platform's
    # strftime (e.g. Windows).
    axis.xaxis.set_major_formatter(mdates.DateFormatter('%b %-d'))
    axis.set(ylabel='Daily Confirmed Cases', xlabel='')
    # "bottom right" is not a legend location matplotlib accepts; the
    # intended position is "lower right".
    axis.legend(loc="lower right", fontsize=12.8)
    sns.set_style("ticks")
    plt.tight_layout()
    sns.despine()
    if save_path is not None:
        plt.savefig(save_path)
    print(df.tail())
    return df
def summary(samples):
    """Summarise each tensor of posterior draws along its first dimension.

    Args:
        samples: mapping from site name to a tensor whose dim 0 indexes draws.

    Returns:
        dict mapping each site name to {"mean", "std", "5%", "95%"} tensors,
        each reduced over dim 0. The percentiles are order statistics
        (k-th smallest draw), not interpolated quantiles.
    """
    site_stats = {}
    for k, v in samples.items():
        n_draws = len(v)
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            # kthvalue's k is 1-based: clamp so that fewer than 20 draws
            # (int(n * 0.05) == 0) does not request the invalid k=0.
            "5%": v.kthvalue(max(1, int(n_draws * 0.05)), dim=0)[0],
            "95%": v.kthvalue(max(1, int(n_draws * 0.95)), dim=0)[0],
        }
    return site_stats
# US series up to and including 2021-03-10 (also draws and saves the daily-cases plot).
cad = create_country(country="US", end_date="2021-03-10")
country date confirmed deaths recovered daily_confirmed 349 US 2021-03-06 28952970.0 524362.0 0.0 58062.0 350 US 2021-03-07 28993873.0 525033.0 0.0 40903.0 351 US 2021-03-08 29038631.0 525752.0 0.0 44758.0 352 US 2021-03-09 29096048.0 527699.0 0.0 57417.0 353 US 2021-03-10 29153715.0 529193.0 0.0 57667.0
# Start of the modelling window: ~9.29 million cumulative confirmed US cases
# on this date.
cad_start = "2020-11-01"
cad = cad[cad.date >= cad_start].reset_index(drop=True)
# 1-based day counter used as the regression input (day 1 == cad_start).
cad["days_since_start"] = np.arange(1, len(cad) + 1)
cad.head()
country | date | confirmed | deaths | recovered | daily_confirmed | days_since_start | |
---|---|---|---|---|---|---|---|
0 | US | 2020-11-01 | 9291068.0 | 231776.0 | 3630579.0 | 104899.0 | 1 |
1 | US | 2020-11-02 | 9376816.0 | 232312.0 | 3674981.0 | 85748.0 | 2 |
2 | US | 2020-11-03 | 9504162.0 | 233905.0 | 3705130.0 | 127346.0 | 3 |
3 | US | 2020-11-04 | 9607933.0 | 235055.0 | 3743527.0 | 103771.0 | 4 |
4 | US | 2020-11-05 | 9737135.0 | 236217.0 | 3781751.0 | 129202.0 | 5 |
cad.shape  # not the cell's last expression, so Jupyter does not display it
# NOTE: plain alias, not a copy -- cad_tmp and cad are the same DataFrame.
cad_tmp = cad
cad_tmp.shape
(130, 7)
# Independent copy of the prepared series used for the regression, and the
# country label used in titles/messages (change both to model another region).
reg_data = cad_tmp.copy()
country_ = "US"
reg_data.head()
country | date | confirmed | deaths | recovered | daily_confirmed | days_since_start | |
---|---|---|---|---|---|---|---|
0 | US | 2020-11-01 | 9291068.0 | 231776.0 | 3630579.0 | 104899.0 | 1 |
1 | US | 2020-11-02 | 9376816.0 | 232312.0 | 3674981.0 | 85748.0 | 2 |
2 | US | 2020-11-03 | 9504162.0 | 233905.0 | 3705130.0 | 127346.0 | 3 |
3 | US | 2020-11-04 | 9607933.0 | 235055.0 | 3743527.0 | 103771.0 | 4 |
4 | US | 2020-11-05 | 9737135.0 | 236217.0 | 3781751.0 | 129202.0 | 5 |
reg_data.shape  # -> (rows, columns) of the regression data
(130, 7)
!pip install pyro-ppl
!pip install numpyro
Collecting pyro-ppl Downloading https://files.pythonhosted.org/packages/aa/7a/fbab572fd385154a0c07b0fa138683aa52e14603bb83d37b198e5f9269b1/pyro_ppl-1.6.0-py3-none-any.whl (634kB) |████████████████████████████████| 634kB 4.3MB/s Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (3.3.0) Requirement already satisfied: torch>=1.8.0 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.8.1+cu101) Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.19.5) Collecting pyro-api>=0.1.1 Downloading https://files.pythonhosted.org/packages/fc/81/957ae78e6398460a7230b0eb9b8f1cb954c5e913e868e48d89324c68cec7/pyro_api-0.1.2-py3-none-any.whl Requirement already satisfied: tqdm>=4.36 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (4.41.1) Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.8.0->pyro-ppl) (3.7.4.3) Installing collected packages: pyro-api, pyro-ppl Successfully installed pyro-api-0.1.2 pyro-ppl-1.6.0 Collecting numpyro Downloading https://files.pythonhosted.org/packages/00/a6/064eedcec968207259acf06cf156c0ea9a6534328bdf7da0e768cfdb3239/numpyro-0.6.0-py3-none-any.whl (218kB) |████████████████████████████████| 225kB 5.4MB/s Collecting jax==0.2.10 Downloading https://files.pythonhosted.org/packages/88/9d/2862825b5eddd0df64c78b22cc0b897f0128b1c6494bf39e4849e9e0fade/jax-0.2.10.tar.gz (589kB) |████████████████████████████████| 593kB 7.1MB/s Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro) (4.41.1) Collecting jaxlib==0.1.62 Downloading https://files.pythonhosted.org/packages/7e/75/30f1c643b7edb1309b6d748809042241737fe43127cb41754266eca79250/jaxlib-0.1.62-cp37-none-manylinux2010_x86_64.whl (35.7MB) |████████████████████████████████| 35.7MB 98kB/s Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from jax==0.2.10->numpyro) 
(1.19.5) Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax==0.2.10->numpyro) (0.12.0) Requirement already satisfied: opt_einsum in /usr/local/lib/python3.7/dist-packages (from jax==0.2.10->numpyro) (3.3.0) Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.62->numpyro) (1.12) Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.62->numpyro) (1.4.1) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax==0.2.10->numpyro) (1.15.0) Building wheels for collected packages: jax Building wheel for jax (setup.py) ... done Created wheel for jax: filename=jax-0.2.10-cp37-none-any.whl size=679776 sha256=122fccd699209761ec50b95c9b10d0797db69830a250ca111d80d085f5371673 Stored in directory: /root/.cache/pip/wheels/44/ea/ac/3be3bc19ee3b62f6fe1561eb6df1199284bb6bab819c1befa4 Successfully built jax Installing collected packages: jax, jaxlib, numpyro Found existing installation: jax 0.2.12 Uninstalling jax-0.2.12: Successfully uninstalled jax-0.2.12 Found existing installation: jaxlib 0.1.65+cuda110 Uninstalling jaxlib-0.1.65+cuda110: Successfully uninstalled jaxlib-0.1.65+cuda110 Successfully installed jax-0.2.10 jaxlib-0.1.62 numpyro-0.6.0
import torch
import pyro
import pyro.distributions as dist
from torch import nn
from pyro.nn import PyroModule, PyroSample
from pyro.infer import MCMC, NUTS, HMC
from pyro.infer.autoguide import AutoGuide, AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.infer import Predictive
# Modelling notes:
# - we should be able to use an empirical estimate for the prior mean of the
#   2nd regression's bias term, something like b = log(max(daily_confirmed));
# - it might be possible to use a single regression model instead, with extra
#   (interaction) terms that are only active for tau < t.
class COVID_change(PyroModule):
    """Bayesian change-point regression: two linear models switched at ``tau``.

    ``forward`` samples ``tau`` ~ Beta(4, 3) and ``sigma`` ~ Uniform(0, 3),
    splits the rows of ``x`` at ``ceil(tau * len(x))``, predicts the first
    part with ``linear1`` and the rest with ``linear2``, and scores ``y`` with
    a Student-t likelihood (2 degrees of freedom, scale ``sigma``).

    Args:
        in_features: input size of both linear layers.
        out_features: output size of both linear layers.
        b1_mu: prior mean of the first line's bias (prior sd 1).
        b2_mu: prior mean of the second line's bias (prior sd ``b2_mu / 4``,
            so it must be positive).
    """

    def __init__(self, in_features, out_features, b1_mu, b2_mu):
        super().__init__()
        # Weight priors take nn.Linear's weight layout (out_features,
        # in_features) so the module is consistent with its constructor args.
        weight_shape = [out_features, in_features]
        # bias=False: the bias is replaced by a sampled prior right below.
        self.linear1 = PyroModule[nn.Linear](in_features, out_features, bias=False)
        self.linear1.weight = PyroSample(
            dist.Normal(0.5, 0.25).expand(weight_shape).to_event(1))
        self.linear1.bias = PyroSample(dist.Normal(b1_mu, 1.))
        # Could possibly use stronger priors for the 2nd line, because fewer
        # observations fall after the change point.
        self.linear2 = PyroModule[nn.Linear](in_features, out_features, bias=False)
        # NOTE: unlike linear1 there is no .to_event(1) here; this changes how
        # the log-density is batched, not its total.
        self.linear2.weight = PyroSample(
            dist.Normal(0., 0.25).expand(weight_shape))
        self.linear2.bias = PyroSample(dist.Normal(b2_mu, b2_mu / 4))

    def forward(self, x, y=None):
        """Sample the change point, score ``y`` and return the piecewise mean.

        Args:
            x: regression inputs, one row per time step; rows before the
                change point go to ``linear1``. Assumed shape
                (N, in_features) -- TODO confirm for other callers.
            y: observed targets (one per row of ``x``), or ``None`` to sample.

        Returns:
            Tensor of predicted means for both segments, concatenated.
        """
        tau = pyro.sample("tau", dist.Beta(4, 3))
        sigma = pyro.sample("sigma", dist.Uniform(0., 3.))
        # Rows [0, sep) follow line 1, rows [sep, N) follow line 2.
        # NOTE(review): sep is computed from a detached value, so the
        # likelihood is piecewise constant in tau and NUTS only gets the Beta
        # prior's gradient for tau; expect slow mixing of tau.
        # .cpu() keeps the conversion valid if the model runs on a GPU.
        sep = int(np.ceil(tau.detach().cpu().numpy() * len(x)))
        mean1 = self.linear1(x[:sep]).squeeze(-1)
        mean2 = self.linear2(x[sep:]).squeeze(-1)
        mean = torch.cat((mean1, mean2))
        pyro.sample("obs", dist.StudentT(2, mean, sigma), obs=y)
        return mean
# Columns: 0 = cumulative confirmed, 1 = days since start, 2 = daily confirmed.
tensor_data = torch.tensor(
    reg_data[["confirmed", "days_since_start", "daily_confirmed"]].values,
    dtype=torch.float)
x_data = tensor_data[:, 1].unsqueeze(1)      # (N, 1) regression input
# torch.log keeps these as tensors explicitly (np.log on a tensor depends on
# NumPy/torch interop to return a tensor, and .numpy() is called below).
y_data = torch.log(tensor_data[:, 0])        # log cumulative confirmed (target)
y_data_daily = torch.log(tensor_data[:, 2])  # log daily confirmed
y_np = y_data.numpy()

# Prior hyper-parameters.
# Bias of the 1st regression line: the average log-cumulative count over the
# lowest quartile of observations (values <= 25th percentile).
q1 = np.quantile(y_np, q=0.25)
bias_1_mean = np.mean(y_np[y_np <= q1])
print("Prior mean for Bias 1: ", bias_1_mean)
# Bias of the 2nd regression line: the average log-cumulative count over the
# highest quartile of observations (values >= 75th percentile).
q4 = np.quantile(y_np, q=0.75)
bias_2_mean = np.mean(y_np[y_np >= q4])
print("Prior mean for Bias 2: ", bias_2_mean)
Prior mean for Bias 1: 16.255508 Prior mean for Bias 2: 17.152462
# Posterior sampling with NUTS: one chain, 200 warm-up steps, 400 kept draws.
# More than 400 draws per chain would be needed with a flat prior on b_2 and w_2.
num_samples = 400
model = COVID_change(in_features=1, out_features=1,
                     b1_mu=bias_1_mean, b2_mu=bias_2_mean)
nuts_kernel = NUTS(model)
mcmc = MCMC(
    nuts_kernel,
    num_chains=1,
    warmup_steps=200,
    num_samples=num_samples,
)
mcmc.run(x_data, y_data)  # x: days since start, y: log cumulative cases
samples = mcmc.get_samples()
Sample: 100%|██████████| 600/600 [20:17, 2.03s/it, step size=1.81e-04, acc. prob=0.898]
# Cache the fitted sampler in us.pkl: write it on the first run; on later runs
# load the stored result instead (replacing the `mcmc` just run, so figures
# are reproducible across sessions).
import os
import dill

mcmc_cache_path = 'us.pkl'
if os.path.exists(mcmc_cache_path):
    # NOTE(review): dill, like pickle, can execute arbitrary code on load --
    # only load a cache file produced by this notebook.
    with open(mcmc_cache_path, 'rb') as f:
        mcmc = dill.load(f)
else:
    with open(mcmc_cache_path, 'wb') as f:
        dill.dump(mcmc, f)
samples = mcmc.get_samples()

# extract individual posteriors
weight_1_post = samples["linear1.weight"].detach().numpy()
weight_2_post = samples["linear2.weight"].detach().numpy()
bias_1_post = samples["linear1.bias"].detach().numpy()
bias_2_post = samples["linear2.bias"].detach().numpy()
tau_post = samples["tau"].detach().numpy()
sigma_post = samples["sigma"].detach().numpy()

# Rebuild the regression mean for every posterior draw (same split rule as the
# model: rows [0, sep) use line 1) and draw one replicated observation per draw.
n_obs = len(x_data)
tau_days = list(map(int, np.ceil(tau_post * n_obs)))
mean_ = torch.zeros(len(tau_days), n_obs)
obs_ = torch.zeros(len(tau_days), n_obs)
for i, sep in enumerate(tau_days):
    mean_[i, :] = torch.cat((x_data[:sep] * weight_1_post[i] + bias_1_post[i],
                             x_data[sep:] * weight_2_post[i] + bias_2_post[i])).reshape(n_obs)
    # Use the model's own likelihood, Student-t with 2 degrees of freedom; a
    # Normal here would understate the predictive tails.
    obs_[i, :] = dist.StudentT(2, mean_[i, :], sigma_post[i]).sample()
samples["_RETURN"] = mean_
samples["obs"] = obs_
pred_summary = summary(samples)
mu = pred_summary["_RETURN"]  # posterior of the regression mean (log scale)
y = pred_summary["obs"]       # posterior predictive: mean + Student-t noise
# Implied daily new cases: day-to-day difference of the exponentiated fitted
# mean. Using mu (noise-free) instead of the average of noisy draws in y
# avoids amplifying Monte-Carlo noise when differencing large totals.
# The first entry is 0 (day 0 is differenced with itself).
mu_mean = mu["mean"]
y_shift = torch.exp(mu_mean) - torch.exp(torch.cat((mu_mean[:1], mu_mean[:-1])))
print(y_shift)

predictions = pd.DataFrame({
    "days_since_start": x_data[:, 0],
    "mu_mean": mu["mean"],  # posterior mean of the regression line
    "mu_perc_5": mu["5%"],
    "mu_perc_95": mu["95%"],
    "y_mean": y["mean"],    # mean of posterior-predictive draws (line + noise)
    "y_perc_5": y["5%"],
    "y_perc_95": y["95%"],
    "true_confirmed": y_data,
    "true_daily_confirmed": y_data_daily,
    "y_daily_mean": y_shift,
})
w1_ = pred_summary["linear1.weight"]
w2_ = pred_summary["linear2.weight"]
b1_ = pred_summary["linear1.bias"]
b2_ = pred_summary["linear2.bias"]
tau_ = pred_summary["tau"]
sigma_ = pred_summary["sigma"]
# Row index of the first observation assigned to the 2nd regression at the
# posterior-mean change point.
ind = int(np.ceil(tau_["mean"] * n_obs))
tensor([ 0., 123758., 107474., 143517., 139824., 111102., 117757., 137822., 119536., 136832., 141099., 125632., 175256., 113747., 119790., 164404., 142176., 151333., 157725., 127663., 160600., 153363., 152496., 178360., 128725., 172123., 178194., 148925., 201997., 154104., 196959., 157244., 172678., 197221., 160619., 200865., 142596., 213519., 195834., 187626., 185048., 212983., 186515., 217653., 201873., 221126., 184906., 212976., 253842., 221686., 211416., 246752., 218192., 241416., 244118., 237014., 196282., 281720., 252228., 231080., 246670., 251356., 250950., 276856., 223894., 295550., 289356., 285614., 251826., 293064., 261672., 343500., 267660., 278670., 377790., 217006., 275918., 182914., 169854., 108930., 89022., 93620., 105324., 105762., 90956., 95094., 80826., 53410., 121902., 12100., 101208., 46104., 102032., 127176., 65978., 119180., 67914., 65600., 108964., 58170., 135956., 101606., 106796., 103274., 36088., 135646., 13788., 100968., 68676., 110404., 152796., 79136., 36322., 170090., 70698., 113806., 81500., 60522., 133360., 127798., 83650., 146232., 28538., 130194., 33138., 123618., 120502., 108912., 106778., 96290.])
# Print the per-site posterior table (mean, std, quantiles, n_eff, r_hat).
mcmc.summary()
# Keep the raw convergence diagnostics around for further inspection.
diag = mcmc.diagnostics()
mean std median 5.0% 95.0% n_eff r_hat tau 0.60 0.03 0.59 0.55 0.64 5.89 1.31 sigma 0.02 0.00 0.02 0.01 0.03 2.75 2.24 linear1.weight[0,0] 0.01 0.00 0.01 0.01 0.01 5.15 1.57 linear1.bias 16.05 0.01 16.05 16.03 16.06 5.46 1.59 linear2.weight[0,0] 0.00 0.00 0.00 0.00 0.00 16.96 1.05 linear2.bias 16.76 0.03 16.76 16.72 16.81 15.75 1.06 Number of divergences: 0
print(ind)
print(reg_data.date[ind])
# Overlay the posterior densities of the two slopes. Start a fresh figure so
# the plot does not draw onto whatever axes were active before.
plt.figure()
# sns.distplot is deprecated (since seaborn 0.11) in favour of histplot with
# kde=True / stat="density". The weight posteriors hold one (1, 1) weight
# matrix per draw, so flatten them to 1-D for histplot.
sns.histplot(weight_1_post.ravel(),
             stat="density",
             kde=True,
             color="red",
             label="Weight posterior before CP")
plt.axvline(x=w1_["mean"].item(), linestyle='--', label="Mean weight before CP",
            color="red")
sns.histplot(weight_2_post.ravel(),
             stat="density",
             kde=True,
             color="teal",
             label="Weight posterior after CP")
plt.axvline(x=w2_["mean"].item(), linestyle='--', label="Mean weight after CP",
            color="teal")
legend = plt.legend(loc='upper right')
legend.get_frame().set_alpha(0.7)
sns.set_style("ticks")
plt.tight_layout()
sns.despine()
plt.savefig('/content/sample_data/us_weights.pdf')
78 2021-01-18 00:00:00
# Posterior-mean slope before and after the change point.
for slope_mean in (w1_["mean"], w2_["mean"]):
    print(slope_mean)
tensor([[0.0126]]) tensor([[0.0034]])
start_date_ = str(reg_data.date[0]).split(' ')[0]
change_date_ = str(reg_data.date[ind]).split(' ')[0]
print("Date of change for {}: {}".format(country_, change_date_))


def _day_number(date):
    """Return the days_since_start x-position of ``date`` (start date -> 1).

    Assumes one row per calendar day in reg_data (true here: 130 rows covering
    2020-11-01 .. 2021-03-10).
    """
    return (pd.Timestamp(date) - reg_data.date[0]).days + 1


def _short_date(date):
    """Format a date like 'Dec 14' without platform-specific strftime flags."""
    date = pd.Timestamp(date)
    return "{} {}".format(date.strftime("%b"), date.day)


# The x axis is days_since_start (1-based), so every marker is placed through
# that scale; bare 0-based row indices would put each marker one day early.
n_days = len(x_data)
vaccination_date = pd.Timestamp("2020-12-14")
change_date = pd.Timestamp(reg_data.date[ind])
change_day = reg_data.days_since_start[ind]

# plot data:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
ax = [ax]
# observed totals; red points belong to the 2nd regression (rows >= ind)
ax[0].scatter(y=np.exp(y_data[:ind]), x=x_data[:ind], s=15)
ax[0].scatter(y=np.exp(y_data[ind:]), x=x_data[ind:], s=15, color="red")
ax[0].plot(predictions["days_since_start"],
           np.exp(predictions["y_mean"]),
           color="green",
           label="Fitted line by MCMC-NUTS model")
ax[0].axvline(_day_number(vaccination_date),
              linestyle='--', linewidth=1.5,
              label="Date of Vaccination: {}, {}".format(
                  _short_date(vaccination_date), vaccination_date.year),
              color="red")
ax[0].axvline(change_day,
              linestyle='--', linewidth=1.5,
              label="Date of Change: {}, {}".format(
                  _short_date(change_date), change_date.year),
              color="black")
ax[0].fill_between(predictions["days_since_start"],
                   np.exp(predictions["y_perc_5"]),
                   np.exp(predictions["y_perc_95"]),
                   alpha=0.25,
                   label="90% CI of predictions",
                   color="teal")
# CP interval: the first 2nd-regression row ceil(tau * N) sits at x = that + 1.
ax[0].fill_betweenx([0, 1],
                    int(np.ceil(tau_["5%"] * n_days)) + 1,
                    int(np.ceil(tau_["95%"] * n_days)) + 1,
                    alpha=0.25,
                    label="90% CI of changing point",
                    color="lightcoral",
                    transform=ax[0].get_xaxis_transform())
ax[0].set(ylabel="Total Cases")
ax[0].legend(loc="lower right", fontsize=12.8)
ax[0].set_ylim([8000000, 40000000])
ax[0].set_xlim([35, 130])
ax[0].xaxis.get_label().set_fontsize(16)
ax[0].yaxis.get_label().set_fontsize(16)
ax[0].title.set_fontsize(20)
ax[0].tick_params(labelsize=16)
tick_dates = pd.to_datetime(["2020-12-14", "2020-12-31", "2021-01-18",
                             "2021-02-08", "2021-03-01"])
plt.xticks(ticks=[_day_number(d) for d in tick_dates],
           labels=[_short_date(d) for d in tick_dates],
           fontsize=15)
ax[0].set_yscale('log')
plt.setp(ax[0].get_xticklabels(), rotation=0, horizontalalignment='center')
print(reg_data.columns)
sns.set_style("ticks")
sns.despine()
plt.tight_layout()
ax[0].figure.savefig('/content/sample_data/us_cp.pdf')
Date of change for US: 2021-01-18 Index(['country', 'date', 'confirmed', 'deaths', 'recovered', 'daily_confirmed', 'days_since_start'], dtype='object')
fig, ax = plt.subplots(1, 3, figsize=(15, 6))
# Panels 2 and 3 use days_since_start (1-based) on the x axis, so the change
# marker must sit at that row's day number, not at the 0-based row index.
change_day = reg_data.days_since_start[ind]

# Panel 1: cumulative cases against calendar date.
sns.lineplot(x="date",
             y="confirmed",
             data=reg_data,
             ax=ax[0]
             ).set_title("Confirmed COVID-19 Cases in %s" % country_)
ax[0].axvline(reg_data.date[ind], color="red", linestyle="--")

# Panel 2: cumulative cases with the fitted mean (red = 2nd regression).
ax[1].scatter(y=reg_data.confirmed.iloc[:ind], x=x_data[:ind], s=15)
ax[1].scatter(y=reg_data.confirmed.iloc[ind:], x=x_data[ind:], s=15, color="red")
ax[1].plot(predictions["days_since_start"],
           np.exp(predictions["y_mean"]),
           color="green",
           label="Mean")
ax[1].axvline(change_day, linestyle='--',
              linewidth=1,
              label="Day of Change")
ax[1].legend(loc="upper left")
ax[1].set(ylabel="Confirmed Cases",
          xlabel="Days since %s" % start_date_,
          title="Confirmed Cases: %s" % country_)

# Panel 3: daily cases with the implied daily mean.
ax[2].scatter(y=reg_data.daily_confirmed.iloc[:ind], x=x_data[:ind], s=15)
ax[2].scatter(y=reg_data.daily_confirmed.iloc[ind:], x=x_data[ind:], s=15, color="red")
ax[2].plot(predictions["days_since_start"],
           predictions["y_daily_mean"],
           color="green",
           label="Mean")
ax[2].axvline(change_day, linestyle='--',
              linewidth=1,
              label="Day of Change")
ax[2].legend(loc="upper left")
ax[2].set(ylabel="Daily Confirmed Cases",
          xlabel="Days since %s" % start_date_,
          title="Daily Confirmed Cases: %s" % country_)

printmd("**Date of change for {}: {}**".format(country_, change_date_))
# Month-day tick labels on the calendar-date panel.
ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
plt.setp(ax[0].get_xticklabels(), rotation=30, horizontalalignment='right')
ax[0].set(ylabel='Confirmed Cases', xlabel='Date')
plt.tight_layout()
plt.savefig('/content/sample_data/us_mean.pdf')
Date of change for US: 2021-01-18