This notebook is based on "COVID-19 growth prediction using multivariate long short-term memory (LSTM)" by Novanto Yudistira:
https://arxiv.org/pdf/2005.04809.pdf
https://github.com/VICS-CORE/lstmcorona/blob/master/lstm.py
import pandas as pd
import numpy as np
import requests as rq
import datetime as dt
import torch
import json
tnn = torch.nn
top = torch.optim
from torch.utils import data as tdt
from matplotlib.ticker import MultipleLocator
from matplotlib.dates import DayLocator, AutoDateLocator, ConciseDateFormatter
%matplotlib inline
CUDA="cuda:0"
CPU="cpu"
if torch.cuda.is_available():
device = torch.device(CUDA)
cd = torch.cuda.current_device()
print("Num devices:", torch.cuda.device_count())
print("Current device:", cd)
print("Device name:", torch.cuda.get_device_name(cd))
print("Device props:", torch.cuda.get_device_properties(cd))
print(torch.cuda.memory_summary(cd))
else:
device = torch.device(CPU)
print(device)
# define paths
DATA_DIR = 'data'
MODELS_DIR = 'models'
from google.colab import drive
drive.mount('/content/drive')
%cd 'drive/My Drive/CS/colab/'
!cat /proc/cpuinfo
!cat /proc/meminfo
!mkdir -p $DATA_DIR  # make sure the data directory exists (no-op if it already does)
!curl https://covid.ourworldindata.org/data/owid-covid-data.csv --output data/owid-covid-data.csv
!head -n1 data/owid-covid-data.csv
cols = ['location', 'date', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths', 'population']
dates = ['date']
df = pd.read_csv(DATA_DIR + "/owid-covid-data.csv",
usecols=cols,
parse_dates=dates)
df.sample()
class YudistirNet(tnn.Module):
def __init__(self, ip_seq_len=1, op_seq_len=1, hidden_size=1, num_layers=1):
super(YudistirNet, self).__init__()
self.ip_seq_len = ip_seq_len
self.op_seq_len = op_seq_len
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = tnn.LSTM(input_size=1, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True)
self.linear = tnn.Linear(self.hidden_size * self.ip_seq_len, self.op_seq_len)
self.sigmoid = tnn.Sigmoid()
def forward(self, ip):
lstm_out, _ = self.lstm(ip)
linear_out = self.linear(lstm_out.reshape(-1, self.hidden_size * self.ip_seq_len))
sigmoid_out = self.sigmoid(linear_out.view(-1, self.op_seq_len))
return sigmoid_out
def predict(self, ip):
with torch.no_grad():
preds = self.forward(ip)
return preds
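# Quick shape check (illustrative addition, not part of the original notebook):
# the model maps a (batch, ip_seq_len, 1) input to a (batch, op_seq_len) output
# squashed into (0, 1) by the final sigmoid. Sizes below are arbitrary.
_m = YudistirNet(ip_seq_len=40, op_seq_len=20, hidden_size=20, num_layers=4)
_x = torch.randn(2, 40, 1)      # batch of 2 sequences, 40 days, 1 feature
assert _m(_x).shape == (2, 20)  # 20 predicted days per sequence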
def save_checkpoint(epoch, model, optimizer, trn_losses, val_losses, min_val_loss, path=""):
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'trn_losses': trn_losses,
'val_losses': val_losses,
'min_val_loss': min_val_loss
}, path or MODELS_DIR + "/latest.pt")
print("Checkpoint saved")
def load_checkpoint(path="", device="cpu"):
cp = torch.load(path or MODELS_DIR + "/latest.pt", map_location=device)
print("Checkpoint loaded")
    return cp['epoch'], cp['model_state_dict'], cp['optimizer_state_dict'], cp['trn_losses'], cp['val_losses'], cp.get('min_val_loss', np.inf)
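# Usage sketch: together these helpers let a run survive a crash or Colab
# disconnect, e.g.
#   save_checkpoint(e, model, optimizer, trn_loss_vals, val_loss_vals, min_val_loss)
#   e, md, od, trn_loss_vals, val_loss_vals, min_val_loss = load_checkpoint(device=device)
# (the RESUME branch below does exactly this)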
# config
IP_SEQ_LEN = 40
OP_SEQ_LEN = 20
BATCH_SIZE = 1
VAL_RATIO = 0.3
HIDDEN_SIZE = 20
NUM_LAYERS = 4
LEARNING_RATE = 0.001
NUM_EPOCHS = 3001
# to resume training from the latest checkpoint, set RESUME to True
RESUME = False
model = YudistirNet(ip_seq_len=IP_SEQ_LEN, op_seq_len=OP_SEQ_LEN, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
model = model.to(device)
loss_fn = tnn.MSELoss()
optimizer = top.Adam(model.parameters(), lr=LEARNING_RATE)
sum(p.numel() for p in model.parameters() if p.requires_grad)  # total trainable parameters
def gen_dataset():
ip_trn = []
op_trn = []
countries = df['location'].unique()
pop_countries = ['China', 'United States', 'Indonesia', 'Pakistan', 'Brazil', 'Bangladesh', 'Russia', 'Mexico']
c = 0
for country in countries:
        if country in ['World', 'International', 'India']:  # skip aggregates and the held-out test country (India)
continue
country_df = df.loc[df.location == country]
tot_cases_gt_100 = (country_df['total_cases'] >= 100)
country_df = country_df.loc[tot_cases_gt_100]
if len(country_df) >= IP_SEQ_LEN + OP_SEQ_LEN:
c += 1
pop = country_df['population'].iloc[0]
print(c, country, len(country_df), pop)
daily_cases = np.array(country_df['new_cases'].rolling(7, center=True, min_periods=1).mean() * 1000 / pop, dtype=np.float32)
for i in range(len(country_df) - IP_SEQ_LEN - OP_SEQ_LEN + 1):
ip_trn.append(daily_cases[i : i+IP_SEQ_LEN])
op_trn.append(daily_cases[i+IP_SEQ_LEN : i+IP_SEQ_LEN+OP_SEQ_LEN])
ip_trn = torch.from_numpy(np.array(ip_trn, dtype=np.float32))
op_trn = torch.from_numpy(np.array(op_trn, dtype=np.float32))
dataset = tdt.TensorDataset(ip_trn, op_trn)
val_len = int(VAL_RATIO * len(dataset))
trn_len = len(dataset) - val_len
trn_set, val_set = tdt.random_split(dataset, (trn_len, val_len))
return trn_set, val_set
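# Illustration of the sliding window above (toy numbers): a per-capita series
# of length L yields L - IP_SEQ_LEN - OP_SEQ_LEN + 1 overlapping (input, output)
# pairs. With IP_SEQ_LEN=3, OP_SEQ_LEN=2 and the series [0,1,2,3,4,5,6]:
#   ip=[0,1,2] op=[3,4]; ip=[1,2,3] op=[4,5]; ip=[2,3,4] op=[5,6]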
try:
ds = torch.load(DATA_DIR + '/ds.pt')
trn_set, val_set = ds['trn'], ds['val']
print("Loaded dataset from ds.pt")
except FileNotFoundError:
trn_set, val_set = gen_dataset()
torch.save({'trn': trn_set, 'val': val_set}, DATA_DIR + '/ds.pt')
print("Saved dataset to ds.pt")
finally:
print("Training data:", len(trn_set), "Validation data:", len(val_set))
trn_loader = tdt.DataLoader(trn_set, shuffle=True, batch_size=BATCH_SIZE)
val_loader = tdt.DataLoader(val_set, shuffle=True, batch_size=BATCH_SIZE)
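# Batch shape check (illustrative): the loaders yield (BATCH_SIZE, IP_SEQ_LEN)
# inputs and (BATCH_SIZE, OP_SEQ_LEN) targets; the training loop below adds the
# trailing feature dimension the LSTM expects via ip.view(-1, IP_SEQ_LEN, 1).
_ip, _op = next(iter(trn_loader))
print(_ip.shape, _op.shape)  # torch.Size([1, 40]) torch.Size([1, 20])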
trn_loss_vals = []
val_loss_vals = []
e = 0
min_val_loss = np.inf
if RESUME:
e, model_dict, optimizer_dict, trn_loss_vals, val_loss_vals, min_val_loss = load_checkpoint(device=device)
e+=1
model.load_state_dict(model_dict)
optimizer.load_state_dict(optimizer_dict)
# TRAIN
print("BEGIN: [", dt.datetime.now(), "]")
while e < NUM_EPOCHS:
model.train()
trn_losses = []
for data in trn_loader:
ip, op = data
ip = ip.to(device)
op = op.to(device)
optimizer.zero_grad() # set grads to 0
preds = model(ip.view(-1, IP_SEQ_LEN, 1)) # predict
loss = loss_fn(preds, op.view(-1, OP_SEQ_LEN)) # calc loss
loss.backward() # calc and assign grads
optimizer.step() # update weights
        trn_losses.append(loss.detach()) # logging; detach so stored losses don't retain the autograd graph
    avg_trn_loss = torch.stack(trn_losses).mean().item() * 10000 # x10000 only to make small losses readable
trn_loss_vals.append(avg_trn_loss)
model.eval()
with torch.no_grad():
val_losses = []
for data in val_loader:
ip, op = data
ip = ip.to(device)
op = op.to(device)
preds = model(ip.view(-1, IP_SEQ_LEN, 1))
loss = loss_fn(preds, op.view(-1, OP_SEQ_LEN))
val_losses.append(loss)
avg_val_loss = torch.stack(val_losses).mean().item() * 10000
val_loss_vals.append(avg_val_loss)
if e%10==0:
print("[", dt.datetime.now(), "] epoch:", f"{e:3}", "avg_val_loss:", f"{avg_val_loss: .5f}", "avg_trn_loss:", f"{avg_trn_loss: .5f}")
if e%100==0:
save_checkpoint(e, model, optimizer, trn_loss_vals, val_loss_vals, min_val_loss, MODELS_DIR + "/latest-e" + str(e) + ".pt")
if avg_val_loss <= min_val_loss:
min_val_loss = avg_val_loss
save_checkpoint(e, model, optimizer, trn_loss_vals, val_loss_vals, min_val_loss, MODELS_DIR + "/best-e" + str(e) + ".pt")
e+=1
print("END: [", dt.datetime.now(), "]")
# model_path = MODELS_DIR + "/IP20_OP10_H10_L4_E2001_LR001.pt"
model_path = "/home/mayank/Downloads/ds4020-e17xx.pt"#ds4020-0612-e50x.pt"
e, md, _, trn_loss_vals, val_loss_vals, _ = load_checkpoint(model_path, device=device)
print(e)
model.load_state_dict(md)
model.eval()
df_loss = pd.DataFrame({
'trn_loss': trn_loss_vals,
'val_loss': val_loss_vals
})
df_loss['trn_loss'] = df_loss['trn_loss'].rolling(10).mean()
df_loss['val_loss'] = df_loss['val_loss'].rolling(10).mean()
_ = df_loss.plot(
y=['trn_loss', 'val_loss'],
title='Loss per epoch',
subplots=True,
figsize=(5,6),
sharex=False,
logy=True
)
c = "India"
pop_fct = df.loc[df.location==c, 'population'].iloc[0] / 1000
all_preds = []
pred_vals = []
out_vals = []
test_data = np.array(df.loc[(df.location==c) & (df.total_cases>=100), 'new_cases'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)
for i in range(len(test_data) - IP_SEQ_LEN - OP_SEQ_LEN + 1):
ip = torch.tensor(test_data[i : i+IP_SEQ_LEN])
op = torch.tensor(test_data[i+IP_SEQ_LEN : i+IP_SEQ_LEN+OP_SEQ_LEN])
ip = ip.to(device)
op = op.to(device)
pred = model.predict(ip.view(1, IP_SEQ_LEN, 1))
if i==0: # prepend first input
out_vals.extend(ip.view(IP_SEQ_LEN).cpu().numpy() * pop_fct)
        pred_vals.extend([np.nan] * IP_SEQ_LEN)
all_preds.append(pred.view(OP_SEQ_LEN).cpu().numpy() * pop_fct)
pred_vals.append(pred.view(OP_SEQ_LEN).cpu().numpy()[0] * pop_fct)
out_vals.append(op.view(OP_SEQ_LEN).cpu().numpy()[0] * pop_fct)
# last N-1 values
out_vals.extend(op.view(OP_SEQ_LEN).cpu().numpy()[1:] * pop_fct)
pred_vals.extend(([np.nan] * OP_SEQ_LEN)[1:]) # pad with NaN
cmp_df = pd.DataFrame({
'actual': out_vals,
'predicted0': pred_vals
})
# set date
start_date = df.loc[(df.location==c) & (df.total_cases>=100)]['date'].iloc[0]
end_date = start_date + dt.timedelta(days=cmp_df.index[-1])
cmp_df['Date'] = pd.Series([start_date + dt.timedelta(days=i) for i in range(len(cmp_df))])
# plot noodles
ax=None
i=IP_SEQ_LEN
mape=[]
for pred in all_preds:
    cmp_df['predicted_cases'] = np.nan
cmp_df.loc[i:i+OP_SEQ_LEN-1, 'predicted_cases'] = pred
ax = cmp_df.plot(x='Date', y='predicted_cases', ax=ax, legend=False)
ape = np.array(100 * ((cmp_df['actual'] - cmp_df['predicted_cases']).abs() / cmp_df['actual']))
# mape.append(ape.mean())
mape.append(ape[~np.isnan(ape)])
i+=1
total = np.zeros(OP_SEQ_LEN)
for m in mape:
total += m
elwise_mape = total / len(mape)
print("Day wise accuracy:", 100 - elwise_mape)
acc = f"{100 - sum(elwise_mape)/len(elwise_mape):0.2f}%"
# acc = f"{100 - sum(mape)/len(mape):0.2f}%"
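# Worked example (made-up numbers): with actual=[100, 200] and predicted=[90, 220]
# the absolute percentage errors are [10%, 10%], so day-wise accuracy is
# [90%, 90%] and the headline accuracy above is their mean, 90%.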
# plot primary lines
ax = cmp_df.plot(
x='Date',
y=['actual', 'predicted0'],
figsize=(20,8),
lw=5,
title=c + ' | Daily predictions | ' + acc,
ax=ax
)
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)
c = "India"
n_days_prediction = 200
pop_fct = df.loc[df.location==c, 'population'].iloc[0] / 1000
test_data = np.array(df.loc[(df.location==c) & (df.total_cases>=100), 'new_cases'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)
in_data = test_data[-IP_SEQ_LEN:]
out_data = np.array([], dtype=np.float32)
for i in range(n_days_prediction // OP_SEQ_LEN):
ip = torch.tensor(
in_data,
dtype=torch.float32
)
ip = ip.to(device)
pred = model.predict(ip.view(1, IP_SEQ_LEN, 1))
in_data = np.append(in_data[-IP_SEQ_LEN+OP_SEQ_LEN:], pred.cpu().numpy())
out_data = np.append(out_data, pred.cpu().numpy())
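# Note on the loop above: each pass feeds the model's own OP_SEQ_LEN-day
# prediction back into the input window (an autoregressive rollout), so errors
# compound as the n_days_prediction horizon grows.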
orig_df = pd.DataFrame({
'actual': test_data * pop_fct
})
fut_df = pd.DataFrame({
'predicted': out_data * pop_fct
})
# print(fut_df['predicted'].astype('int').to_csv(sep='|', index=False))
orig_df = pd.concat([orig_df, fut_df], ignore_index=True, sort=False)  # DataFrame.append was removed in pandas 2.x
orig_df['total'] = (orig_df['actual'].fillna(0) + orig_df['predicted'].fillna(0)).cumsum()
start_date = df.loc[(df.location==c) & (df.total_cases>=100)]['date'].iloc[0]
orig_df['Date'] = pd.Series([start_date + dt.timedelta(days=i) for i in range(len(orig_df))])
ax = orig_df.plot(
x='Date',
y=['actual', 'predicted'],
title=c + ' daily cases',
figsize=(10,6),
grid=True
)
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)
# orig_df['total'] = orig_df['total'].astype('int')
# orig_df['predicted'] = orig_df['predicted'].fillna(0).astype('int')
# print(orig_df.tail(n_days_prediction))
# arrow
# peakx = 172
# peak = orig_df.iloc[peakx]
# peak_desc = peak['Date'].strftime("%d-%b") + "\n" + str(int(peak['predicted']))
# _ = ax.annotate(
# peak_desc,
# xy=(peak['Date'] - dt.timedelta(days=1), peak['predicted']),
# xytext=(peak['Date'] - dt.timedelta(days=45), peak['predicted'] * .9),
# arrowprops={},
# bbox={'facecolor':'white'}
# )
# _ = ax.axvline(x=peak['Date'], linewidth=1, color='r')
r = rq.get('https://api.covid19india.org/v3/min/timeseries.min.json')
ts = r.json()
data = []
for state in ts:
for date in ts[state]:
data.append((state, date, ts[state][date]['total'].get('confirmed', 0)))
states_df = pd.DataFrame(data, columns=['state', 'date', 'total'])
states_df['date'] = pd.to_datetime(states_df['date'])
first_case_date = states_df['date'].min()
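# Per the parsing loop above, the covid19india timeseries JSON is shaped as
# {state: {date: {"total": {"confirmed": ..., ...}}}}; it is flattened here
# into one (state, date, cumulative confirmed) row per state-day.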
# http://www.populationu.com/india-population
STT_INFO = {
'AN' : {"name": "Andaman & Nicobar Islands", "popn": 450000},
'AP' : {"name": "Andhra Pradesh", "popn": 54000000},
    'AR' : {"name": "Arunachal Pradesh", "popn": 1600000},
    'AS' : {"name": "Assam", "popn": 35000000},
'BR' : {"name": "Bihar", "popn": 123000000},
'CH' : {"name": "Chandigarh", "popn": 1200000},
'CT' : {"name": "Chhattisgarh", "popn": 29000000},
'DL' : {"name": "Delhi", "popn": 19500000},
'DN' : {"name": "Dadra & Nagar Haveli and Daman & Diu", "popn": 700000},
'GA' : {"name": "Goa", "popn": 1580000},
'GJ' : {"name": "Gujarat", "popn": 65000000},
'HP' : {"name": "Himachal Pradesh", "popn": 7400000},
'HR' : {"name": "Haryana", "popn": 28000000},
'JH' : {"name": "Jharkhand", "popn": 38000000},
'JK' : {"name": "Jammu & Kashmir", "popn": 13600000},
'KA' : {"name": "Karnataka", "popn": 67000000},
'KL' : {"name": "Kerala", "popn": 36000000},
'LA' : {"name": "Ladakh", "popn": 325000},
'MH' : {"name": "Maharashtra", "popn": 122000000},
'ML' : {"name": "Meghalaya", "popn": 3400000},
'MN' : {"name": "Manipur", "popn": 3000000},
'MP' : {"name": "Madhya Pradesh", "popn": 84000000},
'MZ' : {"name": "Mizoram", "popn": 1200000},
'NL' : {"name": "Nagaland", "popn": 2200000},
'OR' : {"name": "Odisha", "popn": 46000000},
'PB' : {"name": "Punjab", "popn": 30000000},
'PY' : {"name": "Puducherry", "popn": 1500000},
'RJ' : {"name": "Rajasthan", "popn": 80000000},
'TG' : {"name": "Telangana", "popn": 39000000},
'TN' : {"name": "Tamil Nadu", "popn": 77000000},
'TR' : {"name": "Tripura", "popn": 4100000},
'UP' : {"name": "Uttar Pradesh", "popn": 235000000},
'UT' : {"name": "Uttarakhand", "popn": 11000000},
'WB' : {"name": "West Bengal", "popn": 98000000},
# 'SK' : {"name": "Sikkim", "popn": 681000},
# 'UN' : {"name": "Unassigned", "popn": 40000000}, #avg pop
# 'LD' : {"name": "Lakshadweep", "popn": 75000}
}
# uncomment for India
# STT_INFO = {
# 'TT' : {"name": "India", "popn": 1387155000}
# }
# dummy data for testing
# SET 1 - 10 states
# STT_INFO = {
# 'A': {"name": "Apple", "popn": 10000000},
# 'B': {"name": "Berry", "popn": 10000000},
# 'C': {"name": "Cherry", "popn": 10000000},
# 'D': {"name": "Dates", "popn": 10000000},
# 'E': {"name": "Elderberry", "popn": 10000000},
# 'F': {"name": "Fig", "popn": 10000000},
# 'G': {"name": "Grape", "popn": 10000000},
# 'H': {"name": "Honeysuckle", "popn": 10000000},
# 'I': {"name": "Icaco", "popn": 10000000},
# 'J': {"name": "Jujube", "popn": 10000000},
# }
# total = 100
# SET 2 - 1 agg state (left commented out like SET 1; uncommenting this block
# replaces the real covid19india data loaded above with synthetic data growing
# 3% per day, which is only useful for testing the pipeline end to end)
# STT_INFO = {
#     'Z': {"name": "FruitCountry1000x", "popn": 10000000},
# }
# total = 1000
# r = {
#     'state': [],
#     'date': [],
#     'total': []
# }
# start_date = dt.datetime(day=1, month=3, year=2020)
# end_date = dt.datetime.now()
# while start_date <= end_date:
#     for s in STT_INFO:
#         r['state'].append(s)
#         r['date'].append(start_date)
#         r['total'].append(total)
#         total *= 1.03
#     start_date += dt.timedelta(days=1)
# states_df = pd.DataFrame(r)
# states_df['date'] = pd.to_datetime(states_df['date'])
states_df.tail()
def expand(df):
'''Fill missing dates in an irregular timeline'''
min_date = df['date'].min()
max_date = df['date'].max()
idx = pd.date_range(min_date, max_date)
df.index = pd.DatetimeIndex(df.date)
df = df.drop(columns=['date'])
return df.reindex(idx, method='pad').reset_index().rename(columns={'index':'date'})
def prefill(df, min_date):
    '''Pad with zero totals from min_date up to the first recorded date'''
assert(len(df.state.unique()) == 1)
s = df.state.unique().item()
max_date = df['date'].max()
idx = pd.date_range(min_date, max_date)
df.index = pd.DatetimeIndex(df.date)
df = df.drop(columns=['date'])
return df.reindex(idx).reset_index().rename(columns={'index':'date'}).fillna({'state':s, 'total':0})
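# Example (toy frame with hypothetical dates): expand() pads gaps inside the
# timeline by carrying the last cumulative total forward, and prefill() then
# extends the series back to a common start date with zero totals.
_toy = pd.DataFrame({'state': ['XX', 'XX'],
                     'date': pd.to_datetime(['2020-03-05', '2020-03-08']),
                     'total': [1, 4]})
print(prefill(expand(_toy), dt.datetime(2020, 3, 1)))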
prediction_offset = 1 # how many days of data to skip
n_days_prediction = 200 # number of days for prediction
n_days_data = len(expand(states_df.loc[states_df['state']=='TT']))
assert(n_days_prediction%OP_SEQ_LEN == 0)
agg_days = n_days_data - prediction_offset + n_days_prediction # number of days for plotting agg curve i.e. prediction + actual data
states_agg = np.zeros(agg_days)
ax = None
api = {}
for state in STT_INFO:
pop_fct = STT_INFO[state]["popn"] / 1000
    state_df = states_df.loc[states_df['state']==state][:-prediction_offset]  # skip today's data; covid19india returns incomplete counts for the current day
state_df = prefill(expand(state_df), first_case_date)
state_df['daily'] = state_df['total'] - state_df['total'].shift(1).fillna(0)
test_data = np.array(state_df['daily'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)
in_data = test_data[-IP_SEQ_LEN:]
out_data = np.array([], dtype=np.float32)
    for i in range(n_days_prediction // OP_SEQ_LEN):
ip = torch.tensor(
in_data,
dtype=torch.float32
).to(device)
        try:
            pred = model.predict(ip.view(-1, IP_SEQ_LEN, 1))
        except Exception as err:
            print(state, err)
            raise  # pred is needed below; failing silently here would mask the real error
        in_data = np.append(in_data[-IP_SEQ_LEN+OP_SEQ_LEN:], pred.cpu().numpy())
        out_data = np.append(out_data, pred.cpu().numpy())
sn = STT_INFO[state]['name']
orig_df = pd.DataFrame({
        'actual': np.array(test_data * pop_fct, dtype=int)  # np.int was removed from NumPy
})
fut_df = pd.DataFrame({
        'predicted': np.array(out_data * pop_fct, dtype=int)
})
# print(fut_df.to_csv(sep='|'))
    orig_df = pd.concat([orig_df, fut_df], ignore_index=True, sort=False)
orig_df[sn] = orig_df['actual'].fillna(0) + orig_df['predicted'].fillna(0)
orig_df['total'] = orig_df[sn].cumsum()
states_agg += np.array(orig_df[sn][-agg_days:].fillna(0))
# generate date col for orig_df from state_df
start_date = state_df['date'].iloc[0]
orig_df['Date'] = pd.to_datetime([(start_date + dt.timedelta(days=i)).strftime("%Y-%m-%d") for i in range(len(orig_df))])
# if orig_df[sn].max() < 10000: # or orig_df[sn].max() < 5000:
# continue
# print state, peak date, peak daily cases, cumulative since beginning
peak = orig_df.loc[orig_df[sn].idxmax()]
print(sn, "|", peak['Date'].strftime("%b %d"), "|", int(peak[sn]), "|", int(orig_df['total'].iloc[-1]))
# export data for API
orig_df['deceased_daily'] = orig_df[sn] * 0.028
orig_df['recovered_daily'] = orig_df[sn].shift(14, fill_value=0) - orig_df['deceased_daily'].shift(7, fill_value=0)
orig_df['active_daily'] = orig_df[sn] - orig_df['recovered_daily'] - orig_df['deceased_daily']
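    # Assumptions behind the derived columns above (rough heuristics from this
    # notebook, not epidemiological estimates): a flat ~2.8% case fatality rate,
    # recoveries lagging confirmations by 14 days (minus deaths lagged 7 days),
    # and active = confirmed - recovered - deceased.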
api[state] = {}
for idx, row in orig_df[-agg_days:].iterrows():
row_date = row['Date'].strftime("%Y-%m-%d")
api[state][row_date] = {
"delta": {
"confirmed": int(row[sn]),
"deceased": int(row['deceased_daily']),
"recovered": int(row['recovered_daily']),
"active": int(row['active_daily'])
}
}
# plot state chart
ax = orig_df.plot(
x='Date',
y=[sn],
title='Daily Cases',
figsize=(15,10),
grid=True,
ax=ax,
lw=3
)
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)
# plot aggregate chart
cum_df = pd.DataFrame({
'states_agg': states_agg
})
last_date = orig_df['Date'].iloc[-1].to_pydatetime()
start_date = last_date - dt.timedelta(days=agg_days - 1)  # so the generated dates end on last_date
cum_df['Date'] = pd.to_datetime([(start_date + dt.timedelta(days=i)).strftime("%Y-%m-%d") for i in range(len(cum_df))])
ax = cum_df.plot(
x='Date',
y=['states_agg'],
title='Aggregate daily cases',
figsize=(15,10),
grid=True,
lw=3
)
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)
# plot peak in agg
# peakx = 171
# peak = cum_df.iloc[peakx]
# peak_desc = peak['Date'].strftime("%d-%b") + "\n" + str(int(peak['states_agg']))
# _ = ax.annotate(
# peak_desc,
# xy=(peak['Date'] + dt.timedelta(days=1), peak['states_agg']),
# xytext=(peak['Date'] + dt.timedelta(days=45), peak['states_agg'] * .9),
# arrowprops={},
# bbox={'facecolor':'white'}
# )
# _ = ax.axvline(x=peak['Date'], linewidth=1, color='r')
# aggregate predictions
api['TT'] = {}
for state in api:
if state == 'TT':
continue
for date in api[state]:
api['TT'][date] = api['TT'].get(date, {'delta':{}, 'total':{}})
for k in ['delta']: #'total'
api['TT'][date][k]['confirmed'] = api['TT'][date][k].get('confirmed', 0) + api[state][date][k]['confirmed']
api['TT'][date][k]['deceased'] = api['TT'][date][k].get('deceased', 0) + api[state][date][k]['deceased']
api['TT'][date][k]['recovered'] = api['TT'][date][k].get('recovered', 0) + api[state][date][k]['recovered']
api['TT'][date][k]['active'] = api['TT'][date][k].get('active', 0) + api[state][date][k]['active']
# export
with open("predictions.json", "w") as f:
f.write(json.dumps(api, sort_keys=True))
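# Resulting predictions.json shape (from the export loop above):
# {state: {date: {"delta": {"confirmed": n, "deceased": n, "recovered": n, "active": n}}}}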
# aggregate predictions again, in a compact per-day schema for vp.json
api['TT'] = {}
for state in api:
if state == 'TT':
continue
for date in api[state]:
api['TT'][date] = api['TT'].get(date, {})
api['TT'][date]['c'] = api['TT'][date].get('c', 0) + api[state][date]['delta']['confirmed']
api['TT'][date]['d'] = api['TT'][date].get('d', 0) + api[state][date]['delta']['deceased']
api['TT'][date]['r'] = api['TT'][date].get('r', 0) + api[state][date]['delta']['recovered']
api['TT'][date]['a'] = api['TT'][date].get('a', 0) + api[state][date]['delta']['active']
# cumulative
# t = {'c':0, 'd':0, 'r':0, 'a':0}
# for date in sorted(api['TT'].keys()):
# for k in ['c', 'd', 'r', 'a']:
# api['TT'][date][k] += t[k] # add cum to today
#         t[k] = api['TT'][date][k] # update cumulative
# read previous and export
k = (states_df.date.max().to_pydatetime() - dt.timedelta(days=prediction_offset)).strftime("%Y-%m-%d")
try:
with open("vp.json", "r") as f:
out = json.loads(f.read())
except Exception as e:
out = {}
with open("vp.json", "w") as f:
out[k] = {'TT': api['TT']}
f.write(json.dumps(out, sort_keys=True))
df_csv = pd.DataFrame(out[k]['TT'])
df_csv = df_csv.transpose()
df_csv['c'].to_csv('vp_' + k + '.csv')