Useful links
import pandas as pd
import numpy as np
import requests as rq
import datetime as dt
import traceback as tb
import torch
tnn = torch.nn
top = torch.optim
from torch.utils import data as tdt
# Get the national daily time series from covid19india and build `df`
# (columns: date, confirmed, deceased, recovered), sorted by date.
# NOTE(review): this public endpoint may no longer be served — confirm the URL
# (or point it at an archived copy of data.json) before re-running.
resp = rq.get("https://api.covid19india.org/data.json", timeout=30)
resp.raise_for_status()  # surface HTTP errors here, not as a confusing JSON/KeyError below
ts = resp.json()['cases_time_series']

r = {
    "date": [],
    "confirmed": [],
    "deceased": [],
    "recovered": []
}
for d in ts:
    try:
        # Parse every field before appending anything, so a malformed row is
        # skipped as a whole instead of leaving the columns with different
        # lengths (which would make pd.DataFrame(r) below raise).
        # d['date'] is expected to look like "15 March " (day, month name,
        # trailing space, no year); the year is assumed to be 2020.
        # NOTE(review): rows from any later year would be mis-dated.
        row_date = dt.datetime.strptime(d['date'] + "2020", '%d %B %Y')
        row_confirmed = np.int64(d['dailyconfirmed'])
        row_deceased = np.int64(d['dailydeceased'])
        row_recovered = np.int64(d['dailyrecovered'])
    except Exception:
        # Best-effort: report the offending row and keep going.
        print(d.get('date'))
        tb.print_exc()
    else:
        r['date'].append(row_date)
        r['confirmed'].append(row_confirmed)
        r['deceased'].append(row_deceased)
        r['recovered'].append(row_recovered)

df = pd.DataFrame(r)
df.sort_values('date', inplace=True)
df.sample()
|    | date       | confirmed | deceased | recovered |
|----|------------|-----------|----------|-----------|
| 45 | 2020-03-15 | 10        | 0        | 3         |
def _load_and_peek(path):
    """Read one country CSV and print a random row as a quick sanity check."""
    frame = pd.read_csv(path)
    print(frame.sample())
    return frame


# Italy first, then Spain (same order of reads and printed samples).
it_df = _load_and_peek('csv/italy.csv')
es_df = _load_and_peek('csv/spain.csv')
         date  deceased  recovered  confirmed  active  daily
5  2020-02-26        12        3.0        445     430    123
          date   deaths  recovered  confirmed  active  daily
39  2020-04-03  11164.0      30173     133017   91680   6680
def get_rnn_dataset(series, seq_len):
    """Build a next-step-prediction dataset of sliding windows over `series`.

    Each sample is a pair (ip, op) of float32 tensors of length `seq_len`,
    where op is ip shifted one step ahead: for window start i,
    ip = series[i : i+seq_len] and op = series[i+1 : i+seq_len+1].

    Args:
        series: 1-D sequence of numbers (list, numpy array, pandas Series).
        seq_len: window length; must be >= 1.

    Returns:
        torch.utils.data.TensorDataset with max(len(series) - seq_len, 0)
        samples. When the series is too short the dataset is empty, but its
        tensors still have shape (0, seq_len).

    Raises:
        ValueError: if seq_len < 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    values = np.asarray(series, dtype=np.float32)
    ip_seq = values[:-1]  # inputs: every value except the last
    op_seq = values[1:]   # targets: every value except the first
    n_windows = max(len(ip_seq) - seq_len + 1, 0)
    # idx[i, j] = i + j, so row i selects the window starting at position i.
    idx = np.arange(n_windows)[:, None] + np.arange(seq_len)[None, :]
    # Fancy indexing returns fresh contiguous float32 arrays for torch.
    ip_t = torch.from_numpy(ip_seq[idx])
    op_t = torch.from_numpy(op_seq[idx])
    return tdt.TensorDataset(ip_t, op_t)
SEQ_LEN = 5
VAL_SPLIT = 0.5
IN_POP_FCT = 1300000  # 130Cr / 1000
IT_POP_FCT = 60000    # 6Cr / 1000
ES_POP_FCT = 47000    # 4.7Cr / 1000
SMA_WINDOW = 6        # width of the centred simple moving average
IN_START_ROW = 37     # hard-coded cut-off: India rows before this position are dropped


def _smooth_per_thousand(daily, pop_fct):
    """Centred SMA(SMA_WINDOW) of a daily-count Series, divided by `pop_fct`.

    With pop_fct = population / 1000 the result is daily cases per 1000 people.
    Edges use as many points as are available (min_periods=1).
    """
    sma = daily.rolling(SMA_WINDOW, center=True, min_periods=1).mean()
    return np.array(sma) / pop_fct


# Preprocess: SMA(6), then normalise by population / 1000.
# .iloc makes the positional cut explicit (df was re-sorted by date above).
in_cnf = _smooth_per_thousand(df['confirmed'].iloc[IN_START_ROW:], IN_POP_FCT)
it_cnf = _smooth_per_thousand(it_df['daily'], IT_POP_FCT)
es_cnf = _smooth_per_thousand(es_df['daily'], ES_POP_FCT)

# Choose among India or Italy or Spain here
cnf = es_cnf  # np.append(it_cnf, es_cnf)
dataset = get_rnn_dataset(cnf, SEQ_LEN)
val_len = int(VAL_SPLIT * len(dataset))
train_len = len(dataset) - val_len
# NOTE: random_split assigns overlapping sliding windows at random, so
# validation windows share most of their values with training windows and
# the validation loss is optimistic for true out-of-sample forecasting.
train_set, val_set = tdt.random_split(dataset, (train_len, val_len))
trn_loader = tdt.DataLoader(train_set, shuffle=True, batch_size=1)
val_loader = tdt.DataLoader(val_set, shuffle=True, batch_size=1)
all_loader = tdt.DataLoader(dataset, shuffle=False, batch_size=1)
class Forecaster(tnn.Module):
    """Single-feature RNN mapping a window of `seq_len` values to the window
    shifted one step ahead, each output squashed into (0, 1) by a sigmoid.

    Inputs use nn.RNN's default layout (seq_len, batch, 1). `forward` reshapes
    its output to `(seq_len,)`, so only batch size 1 is supported.
    """

    def __init__(self, seq_len=1, hidden_size=1, num_layers=1):
        super().__init__()
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = tnn.RNN(input_size=1, hidden_size=self.hidden_size, num_layers=self.num_layers)
        # Alternative recurrent cells tried during experimentation:
        # self.gru = tnn.GRU(input_size=1, hidden_size=self.hidden_size, num_layers=self.num_layers)
        # self.lstm = tnn.LSTM(input_size=1, hidden_size=self.hidden_size, num_layers=self.num_layers)
        self.linear = tnn.Linear(self.hidden_size, 1)
        self.sigmoid = tnn.Sigmoid()

    def forward(self, ip, h=None):
        """Run the RNN over `ip` and project every step to one value.

        Args:
            ip: tensor of shape (seq_len, 1, 1).
            h: optional initial hidden state (None means zeros).

        Returns:
            (out, h_n): `out` has shape (seq_len,) with values in (0, 1);
            `h_n` is the RNN's final hidden state.
        """
        rnn_out, rnn_h = self.rnn(ip, h)
        # Linear acts on the last (hidden) dimension, so all time steps are
        # projected at once instead of in a Python loop.
        final_out = self.sigmoid(self.linear(rnn_out[:self.seq_len]))
        return final_out.view(self.seq_len), rnn_h

    def predict(self, ip, num_predictions=1):
        """Forecast `num_predictions` future values autoregressively.

        Each step runs `forward` on the current window with a fresh hidden
        state (as in training) and takes the last output as the next value.
        It then slides the window: the oldest value is dropped and the new
        prediction appended.

        Args:
            ip: tensor holding `seq_len` values (any shape that reshapes to it).
            num_predictions: number of future steps; values <= 0 yield an
                empty tensor.

        Returns:
            1-D tensor of length max(num_predictions, 0).
        """
        preds = []
        with torch.no_grad():
            window = ip.reshape(self.seq_len)
            for _ in range(num_predictions):
                out, _ = self.forward(window.reshape(self.seq_len, 1, 1))
                nxt = out[-1]
                preds.append(nxt)
                window = torch.cat((window[1:], nxt.view(1)))
        if not preds:
            return ip.new_empty(0)
        return torch.stack(preds)
HIDDEN_SIZE = 2
NUM_LAYERS = 1
LEARNING_RATE = 0.001
NUM_EPOCHS = 5000
LOG_EVERY = 100  # print progress every LOG_EVERY epochs

model = Forecaster(seq_len=SEQ_LEN, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
loss_fn = tnn.MSELoss()
optimizer = top.Adam(model.parameters(), lr=LEARNING_RATE)


def _mean(values):
    """Return the arithmetic mean of a list of floats, or NaN if it is empty."""
    return sum(values) / len(values) if values else float('nan')


# TRAIN
trn_loss_vals = []  # mean training loss per epoch
val_loss_vals = []  # mean validation loss per epoch
for e in range(NUM_EPOCHS):
    model.train()
    trn_losses = []
    for ip, op in trn_loader:
        optimizer.zero_grad()                    # set grads to 0
        pred, _ = model(ip.view(SEQ_LEN, 1, 1))  # predict (one window per batch)
        loss = loss_fn(pred, op.view(SEQ_LEN))   # calc loss
        loss.backward()                          # calc and assign grads
        optimizer.step()                         # update weights
        # Log a plain float: appending `loss` itself would keep each step's
        # autograd history referenced until the end of the epoch.
        trn_losses.append(loss.item())
    avg_trn_loss = _mean(trn_losses)
    trn_loss_vals.append(avg_trn_loss)

    model.eval()
    val_losses = []
    with torch.no_grad():
        for ip, op in val_loader:
            pred, _ = model(ip.view(SEQ_LEN, 1, 1))
            val_losses.append(loss_fn(pred, op.view(SEQ_LEN)).item())
    avg_val_loss = _mean(val_losses)
    val_loss_vals.append(avg_val_loss)

    if e % LOG_EVERY == 0:
        print("epoch:", f"{e:3}", "avg_val_loss:", f"{avg_val_loss: .5f}", "avg_trn_loss:", f"{avg_trn_loss: .5f}")
# Per-epoch loss history: one log-scaled subplot per curve, independent x axes.
df_trn_loss = pd.DataFrame(dict(trn_loss=trn_loss_vals, val_loss=val_loss_vals))
plot_kwargs = dict(
    y=['trn_loss', 'val_loss'],
    title=['Training loss per epoch', 'Validation loss per epoch'],
    subplots=True,
    figsize=(5, 8),
    sharex=False,
    logy=True,
)
_ = df_trn_loss.plot(**plot_kwargs)
epoch: 0 avg_val_loss: 0.23072 avg_trn_loss: 0.25161 epoch: 20 avg_val_loss: 0.00911 avg_trn_loss: 0.01122 epoch: 40 avg_val_loss: 0.00259 avg_trn_loss: 0.00375 epoch: 60 avg_val_loss: 0.00199 avg_trn_loss: 0.00298 epoch: 80 avg_val_loss: 0.00193 avg_trn_loss: 0.00286 epoch: 100 avg_val_loss: 0.00193 avg_trn_loss: 0.00283 epoch: 120 avg_val_loss: 0.00191 avg_trn_loss: 0.00281 epoch: 140 avg_val_loss: 0.00189 avg_trn_loss: 0.00279 epoch: 160 avg_val_loss: 0.00186 avg_trn_loss: 0.00275 epoch: 180 avg_val_loss: 0.00178 avg_trn_loss: 0.00266 epoch: 200 avg_val_loss: 0.00160 avg_trn_loss: 0.00242 epoch: 220 avg_val_loss: 0.00133 avg_trn_loss: 0.00205 epoch: 240 avg_val_loss: 0.00098 avg_trn_loss: 0.00155 epoch: 260 avg_val_loss: 0.00057 avg_trn_loss: 0.00095 epoch: 280 avg_val_loss: 0.00033 avg_trn_loss: 0.00056 epoch: 300 avg_val_loss: 0.00024 avg_trn_loss: 0.00041 epoch: 320 avg_val_loss: 0.00022 avg_trn_loss: 0.00036 epoch: 340 avg_val_loss: 0.00021 avg_trn_loss: 0.00034 epoch: 360 avg_val_loss: 0.00021 avg_trn_loss: 0.00033 epoch: 380 avg_val_loss: 0.00021 avg_trn_loss: 0.00032 epoch: 400 avg_val_loss: 0.00020 avg_trn_loss: 0.00031 epoch: 420 avg_val_loss: 0.00021 avg_trn_loss: 0.00031 epoch: 440 avg_val_loss: 0.00020 avg_trn_loss: 0.00031 epoch: 460 avg_val_loss: 0.00020 avg_trn_loss: 0.00030 epoch: 480 avg_val_loss: 0.00020 avg_trn_loss: 0.00030 epoch: 500 avg_val_loss: 0.00020 avg_trn_loss: 0.00029 epoch: 520 avg_val_loss: 0.00019 avg_trn_loss: 0.00029 epoch: 540 avg_val_loss: 0.00020 avg_trn_loss: 0.00028 epoch: 560 avg_val_loss: 0.00019 avg_trn_loss: 0.00029 epoch: 580 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 600 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 620 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 640 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 660 avg_val_loss: 0.00020 avg_trn_loss: 0.00029 epoch: 680 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 700 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 720 avg_val_loss: 
0.00019 avg_trn_loss: 0.00028 epoch: 740 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 760 avg_val_loss: 0.00019 avg_trn_loss: 0.00027 epoch: 780 avg_val_loss: 0.00019 avg_trn_loss: 0.00028 epoch: 800 avg_val_loss: 0.00018 avg_trn_loss: 0.00028 epoch: 820 avg_val_loss: 0.00018 avg_trn_loss: 0.00027 epoch: 840 avg_val_loss: 0.00018 avg_trn_loss: 0.00026 epoch: 860 avg_val_loss: 0.00018 avg_trn_loss: 0.00026 epoch: 880 avg_val_loss: 0.00017 avg_trn_loss: 0.00026 epoch: 900 avg_val_loss: 0.00015 avg_trn_loss: 0.00024 epoch: 920 avg_val_loss: 0.00014 avg_trn_loss: 0.00022 epoch: 940 avg_val_loss: 0.00012 avg_trn_loss: 0.00019 epoch: 960 avg_val_loss: 0.00011 avg_trn_loss: 0.00017 epoch: 980 avg_val_loss: 0.00010 avg_trn_loss: 0.00015 epoch: 1000 avg_val_loss: 0.00010 avg_trn_loss: 0.00015 epoch: 1020 avg_val_loss: 0.00009 avg_trn_loss: 0.00014 epoch: 1040 avg_val_loss: 0.00010 avg_trn_loss: 0.00014 epoch: 1060 avg_val_loss: 0.00009 avg_trn_loss: 0.00014 epoch: 1080 avg_val_loss: 0.00009 avg_trn_loss: 0.00013 epoch: 1100 avg_val_loss: 0.00009 avg_trn_loss: 0.00013 epoch: 1120 avg_val_loss: 0.00009 avg_trn_loss: 0.00013 epoch: 1140 avg_val_loss: 0.00009 avg_trn_loss: 0.00013 epoch: 1160 avg_val_loss: 0.00009 avg_trn_loss: 0.00013 epoch: 1180 avg_val_loss: 0.00008 avg_trn_loss: 0.00013 epoch: 1200 avg_val_loss: 0.00008 avg_trn_loss: 0.00012 epoch: 1220 avg_val_loss: 0.00008 avg_trn_loss: 0.00012 epoch: 1240 avg_val_loss: 0.00008 avg_trn_loss: 0.00012 epoch: 1260 avg_val_loss: 0.00008 avg_trn_loss: 0.00012 epoch: 1280 avg_val_loss: 0.00008 avg_trn_loss: 0.00011 epoch: 1300 avg_val_loss: 0.00008 avg_trn_loss: 0.00012 epoch: 1320 avg_val_loss: 0.00008 avg_trn_loss: 0.00011 epoch: 1340 avg_val_loss: 0.00007 avg_trn_loss: 0.00011 epoch: 1360 avg_val_loss: 0.00008 avg_trn_loss: 0.00011 epoch: 1380 avg_val_loss: 0.00007 avg_trn_loss: 0.00011 epoch: 1400 avg_val_loss: 0.00007 avg_trn_loss: 0.00011 epoch: 1420 avg_val_loss: 0.00008 avg_trn_loss: 0.00010 epoch: 1440 
avg_val_loss: 0.00007 avg_trn_loss: 0.00010 epoch: 1460 avg_val_loss: 0.00007 avg_trn_loss: 0.00010 epoch: 1480 avg_val_loss: 0.00007 avg_trn_loss: 0.00010 epoch: 1500 avg_val_loss: 0.00007 avg_trn_loss: 0.00010 epoch: 1520 avg_val_loss: 0.00006 avg_trn_loss: 0.00010 epoch: 1540 avg_val_loss: 0.00006 avg_trn_loss: 0.00010 epoch: 1560 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1580 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1600 avg_val_loss: 0.00007 avg_trn_loss: 0.00009 epoch: 1620 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1640 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1660 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1680 avg_val_loss: 0.00007 avg_trn_loss: 0.00009 epoch: 1700 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1720 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1740 avg_val_loss: 0.00006 avg_trn_loss: 0.00009 epoch: 1760 avg_val_loss: 0.00006 avg_trn_loss: 0.00008 epoch: 1780 avg_val_loss: 0.00006 avg_trn_loss: 0.00008 epoch: 1800 avg_val_loss: 0.00006 avg_trn_loss: 0.00008 epoch: 1820 avg_val_loss: 0.00005 avg_trn_loss: 0.00008 epoch: 1840 avg_val_loss: 0.00005 avg_trn_loss: 0.00008 epoch: 1860 avg_val_loss: 0.00006 avg_trn_loss: 0.00008 epoch: 1880 avg_val_loss: 0.00005 avg_trn_loss: 0.00008 epoch: 1900 avg_val_loss: 0.00005 avg_trn_loss: 0.00008 epoch: 1920 avg_val_loss: 0.00005 avg_trn_loss: 0.00008 epoch: 1940 avg_val_loss: 0.00005 avg_trn_loss: 0.00008 epoch: 1960 avg_val_loss: 0.00005 avg_trn_loss: 0.00008 epoch: 1980 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2000 avg_val_loss: 0.00007 avg_trn_loss: 0.00008 epoch: 2020 avg_val_loss: 0.00006 avg_trn_loss: 0.00007 epoch: 2040 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2060 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2080 avg_val_loss: 0.00006 avg_trn_loss: 0.00007 epoch: 2100 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2120 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2140 avg_val_loss: 0.00005 avg_trn_loss: 
0.00007 epoch: 2160 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2180 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2200 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2220 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2240 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2260 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2280 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2300 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2320 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2340 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2360 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2380 avg_val_loss: 0.00005 avg_trn_loss: 0.00007 epoch: 2400 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2420 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2440 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2460 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2480 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2500 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2520 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2540 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2560 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2580 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2600 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2620 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2640 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2660 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2680 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2700 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2720 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2740 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2760 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2780 avg_val_loss: 0.00005 avg_trn_loss: 0.00006 epoch: 2800 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2820 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2840 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2860 avg_val_loss: 
0.00004 avg_trn_loss: 0.00006 epoch: 2880 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2900 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2920 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2940 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 2960 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 2980 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3000 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3020 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3040 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3060 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3080 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3100 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3120 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3140 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3160 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3180 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3200 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3220 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3240 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3260 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3280 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3300 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3320 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3340 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3360 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3380 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3400 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3420 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3440 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3460 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3480 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3500 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 3520 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3540 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3560 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 
3580 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3600 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3620 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3640 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3660 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3680 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3700 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3720 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3740 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3760 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3780 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3800 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3820 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3840 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3860 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3880 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3900 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3920 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3940 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3960 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 3980 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4000 avg_val_loss: 0.00005 avg_trn_loss: 0.00005 epoch: 4020 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4040 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4060 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4080 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4100 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4120 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4140 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4160 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4180 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4200 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4220 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4240 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 4260 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4280 avg_val_loss: 0.00004 
avg_trn_loss: 0.00005 epoch: 4300 avg_val_loss: 0.00004 avg_trn_loss: 0.00006 epoch: 4320 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4340 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4360 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4380 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4400 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4420 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4440 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4460 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4480 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4500 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4520 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4540 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4560 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4580 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4600 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4620 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4640 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4660 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4680 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4700 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4720 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4740 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4760 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4780 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4800 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4820 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4840 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4860 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4880 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4900 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4920 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4940 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4960 avg_val_loss: 0.00004 avg_trn_loss: 0.00005 epoch: 4980 avg_val_loss: 0.00004 avg_trn_loss: 0.00005
# CHOOSE FROM IN, IT OR ES HERE — keep the series and its scale factor in sync.
df_cnf = in_cnf
fc_pop_fct = IN_POP_FCT
# df_cnf = np.array([8000, 7600, 7300, 6900, 6650], dtype=np.float32) / 10000
#   (NOTE: scaled by 10000, not by a population factor — adjust fc_pop_fct if used)
FORECAST_DAYS = 30

# Seed the forecast with the last SEQ_LEN smoothed, normalised values.
test_in = df_cnf[-SEQ_LEN:]
t = torch.tensor(
    test_in.reshape(SEQ_LEN, 1, 1),
    dtype=torch.float32
)
print("IN:", t.view(SEQ_LEN) * fc_pop_fct)
out = model.predict(t, num_predictions=FORECAST_DAYS)
print("OUT:", out * fc_pop_fct)

# Plot history followed by the forecast, both rescaled back to daily cases.
orig_df = pd.DataFrame({
    'actual': df_cnf * fc_pop_fct
})
fut_df = pd.DataFrame({
    'predicted': (out.numpy() * fc_pop_fct)
})
# DataFrame.append was removed in pandas 2.0; concat is the direct equivalent.
orig_df = pd.concat([orig_df, fut_df], ignore_index=True, sort=False)
_ = orig_df.plot()
IN: tensor([5394.6670, 5685.0000, 5954.3335, 6219.6001, 6236.0000]) OUT: tensor([14522.6318, 18469.1309, 20658.3887, 21964.0020, 22775.1719, 23292.6777, 23629.5645, 23852.9922, 24004.1836, 24108.9355, 24183.5156, 24238.2090, 24279.5215, 24311.5332, 24336.8008, 24356.9375, 24372.9473, 24385.5215, 24395.1113, 24402.0801, 24406.7148, 24409.2734, 24410.0059, 24409.1367, 24406.8906, 24403.5020, 24399.1621, 24394.0645, 24388.3945, 24382.3008])
# Scale factor of the country the model was trained on (update together with `cnf`).
pop_fct = ES_POP_FCT

# One-step-ahead prediction for every window vs. the true next value,
# both rescaled back to daily cases.
pred_vals, out_vals = [], []
for ip, op in all_loader:
    next_pred = model.predict(ip.view(SEQ_LEN, 1, 1))
    pred_vals.append(next_pred.item() * pop_fct)
    out_vals.append(op.view(SEQ_LEN)[-1].item() * pop_fct)

cmp_df = pd.DataFrame({'predicted cases': pred_vals, 'actual cases': out_vals})
_ = cmp_df.plot()
# Print every learned parameter as "name = tensor" for inspection.
for param_name, param in model.state_dict().items():
    print(param_name, "=", param)
rnn.weight_ih_l0 = tensor([[-4.7741], [ 0.7617]]) rnn.weight_hh_l0 = tensor([[-0.1013, -0.1580], [ 0.1943, -0.5698]]) rnn.bias_ih_l0 = tensor([0.2483, 0.0914]) rnn.bias_hh_l0 = tensor([ 0.1336, -0.1624]) linear.weight = tensor([[-0.2352, -0.0221]]) linear.bias = tensor([0.0798])