import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
################################### for generating synthetic data
from scipy.stats import lognorm,gamma
from scipy.optimize import brentq
################################### for neural network modeling
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
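Each function below returns a sequence of event times and the average negative log-likelihood per event of the true (generating) model, computed on the last 20% of the events, which serve as test data (for the stationary Poisson process, the expected value, 1, is returned instead). Please see the main text for details of each model.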
######################################################
### stationary poisson process
######################################################
def generate_stationary_poisson(n=100000):
    """Simulate a stationary (homogeneous) Poisson process with unit rate.

    Parameters
    ----------
    n : int, optional
        Number of events to generate (default 100000).

    Returns
    -------
    list
        ``[T, score]``. ``T`` is the array of ``n`` event times. ``score`` is
        the expected negative log-likelihood per event of the true model. For
        a unit-rate Poisson process the log-likelihood of an interval ``tau``
        is ``log(1) - tau``, and ``E[tau] = 1``, so ``score`` is exactly 1.
    """
    # Inter-event intervals of a unit-rate Poisson process are Exp(1).
    intervals = np.random.exponential(size=n)
    T = intervals.cumsum()
    score = 1
    return [T, score]
######################################################
### non-stationary poisson process
######################################################
def generate_nonstationary_poisson(n=100000):
    """Simulate a non-stationary Poisson process with a sinusoidal intensity.

    The intensity is ``l(t) = 1 + 0.99 * sin(2*pi*t / 20000)``. Events are
    produced by thinning a homogeneous Poisson process of rate 2, which bounds
    ``l(t)`` from above.

    Parameters
    ----------
    n : int, optional
        Number of events to return (default 100000). Must be at least 2.

    Returns
    -------
    list
        ``[T, score]``. ``T`` holds the first ``n`` accepted event times.
        ``score`` is the true model's negative log-likelihood per event on the
        last ``n - int(0.8 * n)`` events.

    Raises
    ------
    ValueError
        If ``n < 2``: the training/test split would leave one side empty.
    """
    if n < 2:
        raise ValueError("n must be at least 2, got %r" % (n,))
    period = 20000  # period of the sinusoidal intensity
    amp = 0.99  # relative amplitude; l(t) ranges over [0.01, 1.99]

    def l_t(t):
        # Intensity at time t.
        return np.sin(2 * np.pi * t / period) * amp + 1

    def l_int(t1, t2):
        # Closed-form integral of l_t over [t1, t2].
        return -period / (2 * np.pi) * (np.cos(2 * np.pi * t2 / period) - np.cos(2 * np.pi * t1 / period)) * amp + (t2 - t1)

    n_train = int(n * 0.8)
    # l_t averages 1 over a period, so about half of the rate-2 candidates
    # are accepted. 2.1*n candidates usually yield more than n events;
    # otherwise the loop retries.
    n_cand = (21 * n) // 10
    while True:
        # Candidates from a rate-2 homogeneous process (Exp(1) intervals / 2).
        T = np.random.exponential(size=n_cand).cumsum() * 0.5
        r = np.random.rand(n_cand)
        # Thinning: keep a candidate at time t with probability l_t(t) / 2.
        index = r < l_t(T) / 2.0
        if index.sum() > n:
            break
    T = T[index][:n]
    # Test log-likelihood: the sum of log-intensities at the test events,
    # minus the integrated intensity from the last training event to the
    # last event.
    score = -(np.log(l_t(T[n_train:])).sum() - l_int(T[n_train - 1], T[-1])) / (n - n_train)
    return [T, score]
######################################################
### stationary renewal process
######################################################
def generate_stationary_renewal(n=100000):
    """Simulate a stationary renewal process with log-normal intervals.

    Intervals are log-normal with shape ``s = sqrt(log(37))`` and scale
    ``exp(-s**2 / 2)``. This gives mean 1 and coefficient of variation 6.

    Parameters
    ----------
    n : int, optional
        Number of events to generate (default 100000).

    Returns
    -------
    list
        ``[T, score]``. ``T`` is the array of ``n`` event times. ``score`` is
        the true model's negative log-likelihood per event on the last
        ``n - int(0.8 * n)`` intervals.
    """
    s = np.sqrt(np.log(6 * 6 + 1))  # CV^2 = exp(s^2) - 1 = 36
    scale = np.exp(-s * s / 2)  # mean = scale * exp(s^2 / 2) = 1
    tau = lognorm.rvs(s=s, scale=scale, size=n)
    lpdf = lognorm.logpdf(tau, s=s, scale=scale)
    T = tau.cumsum()
    n_train = int(n * 0.8)
    score = -np.mean(lpdf[n_train:])
    return [T, score]
######################################################
### non-stationary renewal process
######################################################
def generate_nonstationary_renewal(n=100000):
    """Simulate a non-stationary renewal process by time rescaling.

    Rescaled intervals ``r`` follow Gamma(shape=4, scale=1/4), which has
    mean 1. Each next event time solves
    ``integral_{x}^{x_next} l(t) dt = r``, where
    ``l(t) = 1 + 0.99 * sin(2*pi*t / 20000)``.

    Parameters
    ----------
    n : int, optional
        Number of events to generate (default 100000).

    Returns
    -------
    list
        ``[T, score]``. ``T`` is the array of ``n`` event times. ``score`` is
        the true model's negative log-likelihood per event on the last
        ``n - int(0.8 * n)`` events.

    Raises
    ------
    ValueError
        From ``brentq``, if the integrated intensity over
        ``[x, x + 1000]`` is smaller than the required rescaled interval.
    """
    period = 20000  # period of the sinusoidal intensity
    amp = 0.99  # relative amplitude; l(t) ranges over [0.01, 1.99]

    def l_t(t):
        # Intensity at time t.
        return np.sin(2 * np.pi * t / period) * amp + 1

    def l_int(t1, t2):
        # Closed-form integral of l_t over [t1, t2].
        return -period / (2 * np.pi) * (np.cos(2 * np.pi * t2 / period) - np.cos(2 * np.pi * t1 / period)) * amp + (t2 - t1)

    def residual(t, t_prev, target):
        # Zero when the integrated intensity over [t_prev, t] equals target.
        return l_int(t_prev, t) - target

    k = 4
    rs = gamma.rvs(k, size=n)
    lpdfs = gamma.logpdf(rs, k)
    # Rescale to Gamma(k, scale=1/k); the log-density gains log(k) from the
    # change of variables.
    rs = rs / k
    lpdfs = lpdfs + np.log(k)

    T = np.empty(n)
    x = 0.0
    for i in range(n):
        x = brentq(residual, x, x + 1000, args=(x, rs[i]))
        T[i] = x
    # Density of an event time = density of its rescaled interval * l(T).
    lpdf = lpdfs + np.log(l_t(T))
    n_train = int(n * 0.8)
    score = -lpdf[n_train:].mean()
    return [T, score]
######################################################
### self-correcting process
######################################################
def generate_self_correcting(n=100000):
    """Simulate a self-correcting process with mu = alpha = 1.

    The log-intensity starts at 0. Between events it grows linearly with
    slope ``mu``, and each event lowers it by ``alpha``.

    Parameters
    ----------
    n : int, optional
        Number of events to generate (default 100000).

    Returns
    -------
    list
        ``[T, score]``. ``T`` is the array of ``n`` event times. ``score`` is
        the true model's negative log-likelihood per event on the last
        ``n - int(0.8 * n)`` events.
    """
    def self_correcting_process(mu, alpha, n):
        # Return the event times, the log-intensity at each event, and the
        # integrated intensity over each preceding interval.
        t = 0.0
        x = 0.0  # current log-intensity
        T = np.empty(n)
        log_l = np.empty(n)
        Int_l = np.empty(n)
        for i in range(n):
            e = np.random.exponential()
            # Invert the integrated intensity of the next interval:
            # e = (exp(mu*tau) - 1) * exp(x) / mu.
            tau = np.log(e * mu / np.exp(x) + 1) / mu
            t = t + tau
            x = x + mu * tau
            T[i] = t
            log_l[i] = x
            Int_l[i] = e
            x = x - alpha
        return [T, log_l, Int_l]

    [T, log_l, Int_l] = self_correcting_process(1, 1, n)
    n_train = int(n * 0.8)
    score = -(log_l[n_train:] - Int_l[n_train:]).sum() / (n - n_train)
    return [T, score]
######################################################
### Hawkes process
######################################################
def generate_hawkes1(n=100000):
    """Simulate a Hawkes process with one active exponential kernel.

    The parameters are mu = 0.2, alpha = [0.8, 0.0] and beta = [1.0, 20.0],
    so only the beta = 1 kernel contributes.

    Parameters
    ----------
    n : int, optional
        Number of events to generate (default 100000).

    Returns
    -------
    list
        ``[T, score]``. ``T`` is the array of ``n`` event times. ``score`` is
        the true model's negative log-likelihood per event on the last
        ``n - int(0.8 * n)`` events.
    """
    [T, LL] = simulate_hawkes(n, 0.2, [0.8, 0.0], [1.0, 20.0])
    n_train = int(n * 0.8)
    score = -LL[n_train:].mean()
    return [T, score]
def generate_hawkes2(n=100000):
    """Simulate a Hawkes process with two exponential kernels.

    The parameters are mu = 0.2, alpha = [0.4, 0.4] and beta = [1.0, 20.0]:
    a slow kernel and a fast kernel of equal weight.

    Parameters
    ----------
    n : int, optional
        Number of events to generate (default 100000).

    Returns
    -------
    list
        ``[T, score]``. ``T`` is the array of ``n`` event times. ``score`` is
        the true model's negative log-likelihood per event on the last
        ``n - int(0.8 * n)`` events.
    """
    [T, LL] = simulate_hawkes(n, 0.2, [0.4, 0.4], [1.0, 20.0])
    n_train = int(n * 0.8)
    score = -LL[n_train:].mean()
    return [T, score]
def simulate_hawkes(n, mu, alpha, beta):
    """Simulate a Hawkes process with a sum of exponential kernels (Ogata thinning).

    The intensity is
        lambda(t) = mu + sum_k sum_{t_i < t} alpha[k] * beta[k] * exp(-beta[k] * (t - t_i)),
    so kernel ``k`` integrates to ``alpha[k]``.

    Parameters
    ----------
    n : int
        Number of events to generate. ``n <= 0`` returns empty arrays.
    mu : float
        Background rate; must be positive.
    alpha, beta : sequence of float
        Kernel weights and decay rates, one entry per kernel (same length).
        Every ``beta[k]`` must be positive.

    Returns
    -------
    list
        ``[T, LL]``. ``T`` holds the event times. ``LL[i]`` is the
        log-likelihood contribution of event ``i``: its log-intensity minus
        the intensity integrated since the previous event (or since 0).

    Raises
    ------
    ValueError
        If ``alpha`` and ``beta`` have different shapes.
    """
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    if alpha.shape != beta.shape:
        raise ValueError("alpha and beta must have the same shape, got %s and %s" % (alpha.shape, beta.shape))
    T = []
    LL = []
    x = 0.0
    l_trg = np.zeros_like(beta)  # current excitation of each kernel
    l_trg_Int = np.zeros_like(beta)  # excitation integrated since the last event
    mu_Int = 0.0  # background rate integrated since the last event
    while len(T) < n:
        # The intensity only decays between events, so its current value
        # bounds it until the next accepted event.
        l = mu + l_trg.sum()
        step = np.random.exponential() / l
        x = x + step
        decay = np.exp(-beta * step)
        l_trg_Int += l_trg * (1 - decay) / beta
        mu_Int += mu * step
        l_trg *= decay
        l_next = mu + l_trg.sum()
        if np.random.rand() < l_next / l:  # accept the candidate
            T.append(x)
            LL.append(np.log(l_next) - l_trg_Int.sum() - mu_Int)
            l_trg += alpha * beta
            l_trg_Int[:] = 0
            mu_Int = 0.0
    return [np.array(T), np.array(LL)]
######################################################
### constant hazard function
######################################################
class HAZARD_const():
    """Constant hazard: the intensity does not depend on the elapsed time.

    With ``p = Dense(1)(rnn)`` and elapsed time ``x`` since the last event:
        log lambda(x) = p
        Lambda(x)     = exp(p) * x      (cumulative hazard)
    Calling the object on ``[x, rnn]`` returns ``[LL, log_l, Int_l]``, where
    ``LL = log_l - Int_l``.
    """

    class Layer_LL(layers.Layer):
        """Map inputs ``[x, p]`` to ``[log_l, Int_l]``."""

        def __init__(self, **kwargs):
            super(HAZARD_const.Layer_LL, self).__init__(**kwargs)

        def build(self, input_shape):
            # No weights to create. The base class sets self.built = True.
            super(HAZARD_const.Layer_LL, self).build(input_shape)

        def call(self, inputs):
            x, p = inputs[0], inputs[1]
            log_l = p
            Int_l = K.exp(p) * x
            return [log_l, Int_l]

        def compute_output_shape(self, input_shape):
            # Both outputs have the shape of x.
            return [input_shape[0], input_shape[0]]

    def __call__(self, inputs):
        x, rnn = inputs[0], inputs[1]
        p = layers.Dense(1)(rnn)
        [log_l, Int_l] = HAZARD_const.Layer_LL()([x, p])
        LL = layers.Subtract()([log_l, Int_l])
        return [LL, log_l, Int_l]
######################################################
### exponential hazard function
######################################################
class HAZARD_exp():
    """Exponential hazard: the log-intensity changes linearly with the elapsed time.

    With ``p = Dense(1)(rnn)``, a trainable scalar ``a`` (initialized to 1.0)
    and elapsed time ``x`` since the last event:
        log lambda(x) = p - a * x
        Lambda(x)     = exp(p) * (1 - exp(-a * x)) / a      (cumulative hazard)
    Calling the object on ``[x, rnn]`` returns ``[LL, log_l, Int_l]``, where
    ``LL = log_l - Int_l``.
    """

    class Layer_LL(layers.Layer):
        """Map inputs ``[x, p]`` to ``[log_l, Int_l]`` using the weight ``a``."""

        def __init__(self, **kwargs):
            super(HAZARD_exp.Layer_LL, self).__init__(**kwargs)

        def build(self, input_shape):
            # One trainable scalar rate shared by all samples.
            self.a = self.add_weight(name='a', initializer=keras.initializers.Constant(value=1.0), shape=(), trainable=True)
            # The base class sets self.built = True.
            super(HAZARD_exp.Layer_LL, self).build(input_shape)

        def call(self, inputs):
            x, p = inputs[0], inputs[1]
            a = self.a
            log_l = p - a * x
            Int_l = K.exp(p) * (1 - K.exp(-a * x)) / a
            return [log_l, Int_l]

        def compute_output_shape(self, input_shape):
            # Both outputs have the shape of x.
            return [input_shape[0], input_shape[0]]

    def __call__(self, inputs):
        x, rnn = inputs[0], inputs[1]
        p = layers.Dense(1)(rnn)
        [log_l, Int_l] = HAZARD_exp.Layer_LL()([x, p])
        LL = layers.Subtract()([log_l, Int_l])
        return [LL, log_l, Int_l]
######################################################
### piecewise constant hazard function
######################################################
class HAZARD_pc():
    """Piecewise-constant hazard on ``[0, t_max)`` split into ``size_div`` equal bins.

    ``p = Dense(size_div, activation='softplus')(rnn)`` gives a positive hazard
    value for each bin. For elapsed time ``x`` since the last event:
        log lambda(x) = log(p of the bin containing x)
        Lambda(x)     = sum of p * width over bins fully left of x
                        + p_bin * (x - left edge of the bin containing x)
    NOTE: for ``x >= t_max`` no bin contains ``x``. Then ``log_l`` is
    ``log(0) = -inf`` and ``Int_l`` stops growing.
    Calling the object on ``[x, rnn]`` returns ``[LL, log_l, Int_l]``, where
    ``LL = log_l - Int_l``.
    """

    class Layer_LL(layers.Layer):
        """Map inputs ``[x, p]`` (p of shape (batch, size_div)) to ``[log_l, Int_l]``."""

        def __init__(self, size_div, t_max, **kwargs):
            super(HAZARD_pc.Layer_LL, self).__init__(**kwargs)
            self.size_div = size_div
            self.t_max = t_max
            edges = np.linspace(0, t_max, size_div + 1)
            self.bin_l = K.constant(edges[:-1].reshape(1, -1))  # left edges, shape (1, size_div)
            self.bin_r = K.constant(edges[1:].reshape(1, -1))  # right edges, shape (1, size_div)
            self.width = K.constant(t_max / size_div)
            self.ones = K.constant(np.ones([size_div, 1]))  # sums over bins via K.dot

        def build(self, input_shape):
            # No weights to create. The base class sets self.built = True.
            super(HAZARD_pc.Layer_LL, self).build(input_shape)

        def call(self, inputs):
            x, p = inputs[0], inputs[1]
            r_le = K.cast(K.greater_equal(x, self.bin_l), dtype=K.floatx())  # x >= left edge
            r_er = K.cast(K.less(x, self.bin_r), dtype=K.floatx())  # x < right edge
            r_e = r_er * r_le  # indicator of the bin containing x
            r_l = 1 - r_er  # bins entirely to the left of x
            log_l = K.log(K.dot(p * r_e, self.ones))
            Int_l = K.dot(p * r_l * self.width, self.ones) + K.dot(p * (x - self.bin_l) * r_e, self.ones)
            return [log_l, Int_l]

        def compute_output_shape(self, input_shape):
            # call returns two outputs, each shaped like x.
            return [input_shape[0], input_shape[0]]

    def __init__(self, size_div, t_max):
        self.size_div = size_div  # number of bins
        self.t_max = t_max  # upper end of the binned range

    def __call__(self, inputs):
        x, rnn = inputs[0], inputs[1]
        # softplus keeps every per-bin hazard value positive.
        p = layers.Dense(self.size_div, activation='softplus')(rnn)
        [log_l, Int_l] = HAZARD_pc.Layer_LL(self.size_div, self.t_max)([x, p])
        LL = layers.Subtract()([log_l, Int_l])
        return [LL, log_l, Int_l]
######################################################
### neural network based hazard function
######################################################
class HAZARD_NN():
    """Neural-network cumulative hazard (the 'NN' hazard type).

    A feed-forward network maps the standardized elapsed time and the RNN
    state to the cumulative hazard ``Int_l``. Every kernel on the path from
    the elapsed time is constrained non-negative, and the activations (tanh,
    softplus) are increasing, so ``Int_l`` is non-decreasing in the elapsed
    time. ``log_l`` is the log of the gradient of ``Int_l`` with respect to
    the raw elapsed time.

    ``normalize_input`` must be called before the object is applied to
    tensors, because the standardization reads ``mu_x`` and ``sigma_x``.
    """

    def __init__(self, size_layer, size_nn, log_mode=True):
        self.size_layer = size_layer  # number of hidden layers
        self.size_nn = size_nn  # units per hidden layer
        self.log_mode = log_mode  # standardize log(x) instead of x

    def __call__(self, inputs):
        elapsed, rnn_state = inputs[0], inputs[1]

        # Standardize the elapsed time, in log space when log_mode is set.
        transform = K.log if self.log_mode else (lambda v: v)
        elapsed_std = layers.Lambda(lambda v: (transform(v) - self.mu_x) / self.sigma_x)(elapsed)

        def abs_glorot_uniform(shape, dtype=None, partition_info=None):
            # Glorot-uniform draws folded to be non-negative, so the initial
            # kernels already satisfy the NonNeg constraint.
            return K.abs(keras.initializers.glorot_uniform(seed=None)(shape, dtype=dtype))

        def monotone_kwargs():
            # A fresh NonNeg constraint for every layer that receives one.
            return dict(kernel_initializer=abs_glorot_uniform, kernel_constraint=keras.constraints.NonNeg())

        # First hidden layer: non-negative weights (no bias) on the elapsed
        # time, unconstrained weights on the RNN state.
        from_elapsed = layers.Dense(self.size_nn, use_bias=False, **monotone_kwargs())(elapsed_std)
        from_state = layers.Dense(self.size_nn)(rnn_state)
        summed = layers.Add()([from_elapsed, from_state])
        hidden = layers.Activation('tanh')(summed)

        for _ in range(self.size_layer - 1):
            hidden = layers.Dense(self.size_nn, activation='tanh', **monotone_kwargs())(hidden)

        Int_l = layers.Dense(1, activation='softplus', **monotone_kwargs())(hidden)
        # Intensity = dInt_l / d(elapsed); the 1e-10 keeps the log finite.
        log_l = layers.Lambda(lambda pair: K.log(1e-10 + K.gradients(pair[0], pair[1])[0]))([Int_l, elapsed])
        LL = layers.Subtract()([log_l, Int_l])
        return [LL, log_l, Int_l]

    def normalize_input(self, x):
        """Store the mean and std of ``x`` (or ``log(x)`` in log_mode); return self."""
        values = np.log(x) if self.log_mode else x
        self.mu_x = values.mean()
        self.sigma_x = values.std()
        return self
######################################################
### RNN
######################################################
class RNN_PP():
    """SimpleRNN encoder of the most recent inter-event intervals.

    The input intervals are standardized with ``mu_x`` and ``sigma_x``, in
    log space when ``log_mode`` is set. They are then fed to a tanh
    SimpleRNN with ``size_rnn`` units, and its final state is returned.
    ``normalize_input`` must be called before the object is applied to
    tensors.
    """

    def __init__(self, size_rnn, time_step, log_mode=True):
        self.size_rnn = size_rnn  # number of RNN units
        self.time_step = time_step  # length of the input sequence
        self.log_mode = log_mode  # standardize log(x) instead of x

    def __call__(self, inputs):
        transform = K.log if self.log_mode else (lambda v: v)
        standardized = layers.Lambda(lambda v: (transform(v) - self.mu_x) / self.sigma_x)(inputs)
        encoder = layers.SimpleRNN(self.size_rnn, input_shape=(self.time_step, 1), activation='tanh')
        return encoder(standardized)

    def normalize_input(self, x):
        """Store the mean and std of ``x`` (or ``log(x)`` in log_mode); return self."""
        values = np.log(x) if self.log_mode else x
        self.mu_x = values.mean()
        self.sigma_x = values.std()
        return self
######################################################################################
### a class for a recurrent neural network based model for temporal point processes
######################################################################################
class NPP():
    """Recurrent neural network based model for temporal point processes.

    An ``RNN_PP`` encodes the ``time_step`` most recent inter-event intervals.
    A hazard layer selected by ``type_hazard`` ('const', 'exp', 'pc' or 'NN')
    maps the RNN output and the elapsed time to the log-intensity ``log_l``
    and the cumulative hazard ``Int_l`` of the next interval.

    The methods are meant to be chained; each returns ``self``::

        NPP(...).set_data(T).set_model().compile(lr).fit_eval()

    Afterwards the following attributes are available:

    - ``mnll``: test mean negative log-likelihood.
    - ``mae``: test mean absolute error of the median prediction.
    - ``val_loss``: best validation loss.
    - ``history``: per-epoch training and validation losses.
    """

    def __init__(self, time_step, size_rnn, type_hazard, size_layer=2, size_nn=64, size_div=128, log_mode=True):
        self.time_step = time_step  # number of past intervals fed to the RNN
        self.size_rnn = size_rnn  # number of RNN units
        self.type_hazard = type_hazard  # 'const', 'exp', 'pc' or 'NN'
        self.size_layer = size_layer  # hidden layers of the NN hazard ('NN' only)
        self.size_nn = size_nn  # units per hidden layer of the NN hazard ('NN' only)
        self.log_mode = log_mode  # standardize intervals in log space
        self.size_div = size_div  # number of bins of the piecewise-constant hazard ('pc' only)

    def set_data(self, T):
        """Build training and test arrays from the 1-D array of event times ``T``.

        The first ``int(0.8 * len(T))`` events are training data; the rest are
        test data. Each sample is a window of ``time_step + 1`` consecutive
        intervals. The first ``time_step`` intervals are the input, shaped
        ``(time_step, 1)``, and the last one is the target.

        Training windows are shuffled with a fixed seed. Test windows keep
        their order, and their histories start inside the training data, so
        there is one test sample per test event.
        """
        def rolling_matrix(x, window_size):
            # All contiguous windows of length window_size, one per row.
            # Copied so the result does not alias x.
            x = x.flatten()
            n = x.shape[0]
            stride = x.strides[0]
            return np.lib.stride_tricks.as_strided(
                x, shape=(n - window_size + 1, window_size), strides=(stride, stride)).copy()

        def transform_data(T, n_train, n_test, time_step):
            # A private fixed-seed generator gives the same permutation as
            # np.random.seed(0) + np.random.permutation, without resetting
            # NumPy's global random state as a side effect.
            index_shuffle = np.random.RandomState(0).permutation(n_train - time_step - 1)
            dT_train = np.ediff1d(T[:n_train])
            r_dT_train = rolling_matrix(dT_train, time_step + 1)[index_shuffle]
            dT_test = np.ediff1d(T[n_train - time_step - 1:n_train + n_test])
            r_dT_test = rolling_matrix(dT_test, time_step + 1)
            dT_train_input = r_dT_train[:, :-1].reshape(-1, time_step, 1)
            dT_train_target = r_dT_train[:, [-1]]
            dT_test_input = r_dT_test[:, :-1].reshape(-1, time_step, 1)
            dT_test_target = r_dT_test[:, [-1]]
            return [dT_train_input, dT_train_target, dT_test_input, dT_test_target]

        n = T.shape[0]
        n_train = int(n * 0.8)
        [dT_train_input, dT_train_target, dT_test_input, dT_test_target] = transform_data(
            T, n_train, n - n_train, self.time_step)
        self.dT_train_input = dT_train_input
        self.dT_train_target = dT_train_target
        self.dT_test_input = dT_test_input
        self.dT_test_target = dT_test_target
        self.n_train = dT_train_target.shape[0]
        self.n_test = dT_test_target.shape[0]
        # Range covered by the piecewise-constant hazard: the largest observed
        # interval (training or test), enlarged by 0.1%.
        self.t_max = np.max(np.vstack([self.dT_train_target, self.dT_test_target])) * 1.001
        return self

    def set_model(self):
        """Build the Keras model. Requires ``set_data`` to have been called.

        The model takes ``[interval history, elapsed time]`` and outputs
        ``[LL, log_l, Int_l]``. The training loss ``-mean(LL)`` is attached
        with ``add_loss``.

        Raises
        ------
        ValueError
            If ``type_hazard`` is not one of 'const', 'exp', 'pc', 'NN'.
        """
        if self.type_hazard == 'const':
            self.layer_hazard = HAZARD_const()
        elif self.type_hazard == 'exp':
            self.layer_hazard = HAZARD_exp()
        elif self.type_hazard == 'pc':
            self.layer_hazard = HAZARD_pc(size_div=self.size_div, t_max=self.t_max)
        elif self.type_hazard == 'NN':
            self.layer_hazard = HAZARD_NN(size_layer=self.size_layer, size_nn=self.size_nn, log_mode=self.log_mode).normalize_input(self.dT_train_target)
        else:
            raise ValueError("type_hazard must be 'const', 'exp', 'pc' or 'NN', got %r" % (self.type_hazard,))
        self.rnn = RNN_PP(size_rnn=self.size_rnn, time_step=self.time_step, log_mode=self.log_mode).normalize_input(self.dT_train_target)
        input_dT = layers.Input(shape=(self.time_step, 1))
        input_x = layers.Input(shape=(1,))
        output_rnn = self.rnn(input_dT)
        [LL, log_l, Int_l] = self.layer_hazard([input_x, output_rnn])
        self.model = Model(inputs=[input_dT, input_x], outputs=[LL, log_l, Int_l])
        self.model.add_loss(-K.mean(LL))
        return self

    def compile(self, lr=1e-3):
        """Compile with the Adam optimizer (learning rate ``lr``). The loss comes from add_loss."""
        self.model.compile(keras.optimizers.Adam(lr=lr))
        return self

    def scores(self):
        """Return the negative log-likelihood of each test event as a 1-D array."""
        scores = - self.model.predict([self.dT_test_input, self.dT_test_target], batch_size=self.n_test)[0].flatten()
        return scores

    class CustomEarlyStopping(keras.callbacks.Callback):
        """Track the best weights by validation loss and stop training early.

        Training stops in either of two cases:

        - The validation loss exceeds the best value by more than 0.05.
        - Every 5 epochs, from epoch 15 on: the best value is not more than
          0.001 below the minimum reached before the last 5 epochs.

        When training ends, the best weights are restored.
        """

        def __init__(self):
            super(NPP.CustomEarlyStopping, self).__init__()
            self.best_val_loss = np.inf
            self.history_val_loss = []
            self.best_weights = None

        def on_epoch_end(self, epoch, logs=None):
            val_loss = logs['val_loss']
            self.history_val_loss = np.append(self.history_val_loss, val_loss)
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_weights = self.model.get_weights()
            if self.best_val_loss + 0.05 < val_loss:
                self.model.stop_training = True
            if (epoch + 1) % 5 == 0 and (epoch + 1) >= 15:
                if self.best_val_loss > self.history_val_loss[:-5].min() - 0.001:
                    self.model.stop_training = True

        def on_train_end(self, logs=None):
            # best_weights stays None when no epoch reported a validation loss
            # below the initial bound (e.g. NaN losses); keep the current
            # weights instead of crashing in set_weights.
            if self.best_weights is not None:
                self.model.set_weights(self.best_weights)

    def fit_eval(self, epochs=100, batch_size=256):
        """Train with early stopping, then evaluate on the test data.

        The last 20% of the (already shuffled) training windows are used for
        validation. Sets ``history``, ``val_loss``, ``mnll`` and ``mae``.
        """
        es = NPP.CustomEarlyStopping()
        # No targets are passed: the loss is attached with add_loss.
        history = self.model.fit([self.dT_train_input, self.dT_train_target], epochs=epochs, batch_size=batch_size,
                                 validation_split=0.2, verbose=0, callbacks=[es])
        scores = self.scores()
        self.history = {'loss': np.array(history.history['loss']), 'val_loss': np.array(history.history['val_loss'])}
        self.val_loss = es.best_val_loss
        self.mnll = scores.mean()
        self.mae = self.mean_absolute_error()
        return self

    def bisect_target(self, taus):
        """Return Int_l(tau) - log(2) for each test history, with taus of shape (n_test, 1)."""
        return self.model.predict([self.dT_test_input, taus], batch_size=taus.shape[0])[2] - np.log(2)

    def median_prediction(self, l, r, n_iter=13):
        """Bisect for the median of the next interval within brackets [l, r].

        The survival function is exp(-Int_l), so the median solves
        ``Int_l(tau) = log(2)``. This relies on ``Int_l`` increasing in tau.
        Each of the ``n_iter`` steps halves the bracket.
        """
        for _ in range(n_iter):
            c = (l + r) / 2
            v = self.bisect_target(c)
            l = np.where(v < 0, c, l)
            r = np.where(v >= 0, c, r)
        return (l + r) / 2

    def mean_absolute_error(self):
        """Return the mean absolute error of the median prediction on the test data.

        The search bracket is [1e-4, 100] times the mean training interval.
        """
        l = np.mean(self.dT_train_target) * 0.0001 * np.ones_like(self.dT_test_target)
        r = np.mean(self.dT_train_target) * 100.0 * np.ones_like(self.dT_test_target)
        tau_pred = self.median_prediction(l, r)
        return np.mean(np.abs(tau_pred - self.dT_test_target))
class training():
    """Choose the RNN truncation depth (time_step) by validation loss."""

    def training(self, type_hazard, T, time_steps=(5, 10, 20, 40)):
        """Train one NPP per candidate time_step on the event times ``T``.

        Stores on ``self`` the test metrics (``mnll``, ``mae``) and the
        ``best_time_step`` of the model with the lowest validation loss. All
        three stay ``None`` if no model reports a finite validation loss.
        Returns self.
        """
        min_val_loss = np.inf
        self.mnll = None
        self.mae = None
        self.best_time_step = None
        for time_step in time_steps:
            npp = NPP(time_step=time_step, type_hazard=type_hazard, size_rnn=64, size_layer=2, size_nn=64,
                      size_div=128, log_mode=True).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
            if npp.val_loss < min_val_loss:
                min_val_loss = npp.val_loss
                self.mnll = npp.mnll
                self.mae = npp.mae
                self.best_time_step = time_step
        return self
The Python class `NPP` splits the data into training data (the first 80% of the events) and test data (the remaining 20%). It trains the model on the training data and evaluates it on the test data.
For simplicity, the hyper-parameters are fixed in this section.
########## Hyper-parameters
time_step = 20  # truncation depth of the RNN
size_rnn = 64  # number of units in the RNN
size_div = 128  # number of sub-intervals of the piecewise constant function (piecewise constant model)
size_nn = 64  # number of units in each hidden layer of the cumulative hazard function network (neural network based model)
size_layer = 2  # number of hidden layers of the cumulative hazard function network (neural network based model)
########## Generate data
[T, score_ref] = generate_stationary_poisson()  # synthetic data: stationary Poisson process
print('#######################################')
print('## Stationary Poisson process')
print('#######################################')
print()
########## Train and evaluate the models
# TensorFlow warns that some outputs are "missing from loss dictionary".
# This is expected: the loss is attached with model.add_loss, not per output.
# constant model
npp1 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='const').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# exponential model
npp2 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='exp').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# piecewise constant model
npp3 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='pc', size_div=size_div).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# neural network based model
npp4 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='NN', size_layer=size_layer, size_nn=size_nn).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL minus the true model's reference score.
print('## constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp1.mnll, npp1.mnll - score_ref, npp1.mae))
print('## exponential model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp2.mnll, npp2.mnll - score_ref, npp2.mae))
print('## piecewise constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp3.mnll, npp3.mnll - score_ref, npp3.mae))
print('## neural network based model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp4.mnll, npp4.mnll - score_ref, npp4.mae))
####################################### ## Stationary Poisson process ####################################### WARNING:tensorflow:Output subtract_20 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_20. WARNING:tensorflow:Output layer_ll_15 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_15. WARNING:tensorflow:Output layer_ll_15_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_15_1. WARNING:tensorflow:Output subtract_21 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_21. WARNING:tensorflow:Output layer_ll_16 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_16. WARNING:tensorflow:Output layer_ll_16_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_16_1. WARNING:tensorflow:Output subtract_22 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_22. WARNING:tensorflow:Output layer_ll_17 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_17. WARNING:tensorflow:Output layer_ll_17_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_17_1. WARNING:tensorflow:Output subtract_23 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_23. WARNING:tensorflow:Output lambda_35 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_35. WARNING:tensorflow:Output dense_41 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_41. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.981 (standardized score: -0.019), MAE: 0.679 ## exponential model MNLL: 0.981 (standardized score: -0.019), MAE: 0.678 ## piecewise constant model MNLL: 0.984 (standardized score: -0.016), MAE: 0.679 ## neural network based model MNLL: 0.982 (standardized score: -0.018), MAE: 0.679
########## Hyper-parameters
time_step = 20  # truncation depth of the RNN
size_rnn = 64  # number of units in the RNN
size_div = 128  # number of sub-intervals of the piecewise constant function (piecewise constant model)
size_nn = 64  # number of units in each hidden layer of the cumulative hazard function network (neural network based model)
size_layer = 2  # number of hidden layers of the cumulative hazard function network (neural network based model)
########## Generate data
[T, score_ref] = generate_nonstationary_poisson()  # synthetic data: non-stationary Poisson process
print('#######################################')
print('## Non-stationary Poisson process')
print('#######################################')
print()
########## Train and evaluate the models
# TensorFlow warns that some outputs are "missing from loss dictionary".
# This is expected: the loss is attached with model.add_loss, not per output.
# constant model
npp1 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='const').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# exponential model
npp2 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='exp').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# piecewise constant model
npp3 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='pc', size_div=size_div).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# neural network based model
npp4 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='NN', size_layer=size_layer, size_nn=size_nn).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL minus the true model's reference score.
print('## constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp1.mnll, npp1.mnll - score_ref, npp1.mae))
print('## exponential model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp2.mnll, npp2.mnll - score_ref, npp2.mae))
print('## piecewise constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp3.mnll, npp3.mnll - score_ref, npp3.mae))
print('## neural network based model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp4.mnll, npp4.mnll - score_ref, npp4.mae))
####################################### ## Non-stationary Poisson process ####################################### WARNING:tensorflow:Output subtract_4 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_4. WARNING:tensorflow:Output layer_ll_3 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_3. WARNING:tensorflow:Output layer_ll_3_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_3_1. WARNING:tensorflow:Output subtract_5 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_5. WARNING:tensorflow:Output layer_ll_4 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_4. WARNING:tensorflow:Output layer_ll_4_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_4_1. WARNING:tensorflow:Output subtract_6 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_6. WARNING:tensorflow:Output layer_ll_5 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_5. WARNING:tensorflow:Output layer_ll_5_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_5_1. WARNING:tensorflow:Output subtract_7 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_7. WARNING:tensorflow:Output lambda_11 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_11. WARNING:tensorflow:Output dense_13 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_13. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.724 (standardized score: 0.021), MAE: 0.701 ## exponential model MNLL: 0.723 (standardized score: 0.021), MAE: 0.710 ## piecewise constant model MNLL: 0.726 (standardized score: 0.023), MAE: 0.715 ## neural network based model MNLL: 0.722 (standardized score: 0.019), MAE: 0.702
########## Hyper-parameters
time_step = 20  # truncation depth of the RNN
size_rnn = 64  # number of units in the RNN
size_div = 128  # number of sub-intervals of the piecewise constant function (piecewise constant model)
size_nn = 64  # number of units in each hidden layer of the cumulative hazard function network (neural network based model)
size_layer = 2  # number of hidden layers of the cumulative hazard function network (neural network based model)
########## Generate data
[T, score_ref] = generate_stationary_renewal()  # synthetic data: stationary renewal process
print('#######################################')
print('## Stationary Renewal process')
print('#######################################')
print()
########## Train and evaluate the models
# TensorFlow warns that some outputs are "missing from loss dictionary".
# This is expected: the loss is attached with model.add_loss, not per output.
# constant model
npp1 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='const').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# exponential model
npp2 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='exp').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# piecewise constant model
npp3 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='pc', size_div=size_div).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# neural network based model
npp4 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='NN', size_layer=size_layer, size_nn=size_nn).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL minus the true model's reference score.
print('## constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp1.mnll, npp1.mnll - score_ref, npp1.mae))
print('## exponential model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp2.mnll, npp2.mnll - score_ref, npp2.mae))
print('## piecewise constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp3.mnll, npp3.mnll - score_ref, npp3.mae))
print('## neural network based model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp4.mnll, npp4.mnll - score_ref, npp4.mae))
####################################### ## Stationary Renewal process ####################################### WARNING:tensorflow:Output subtract_8 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_8. WARNING:tensorflow:Output layer_ll_6 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_6. WARNING:tensorflow:Output layer_ll_6_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_6_1. WARNING:tensorflow:Output subtract_9 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_9. WARNING:tensorflow:Output layer_ll_7 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_7. WARNING:tensorflow:Output layer_ll_7_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_7_1. WARNING:tensorflow:Output subtract_10 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_10. WARNING:tensorflow:Output layer_ll_8 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_8. WARNING:tensorflow:Output layer_ll_8_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_8_1. WARNING:tensorflow:Output subtract_11 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_11. WARNING:tensorflow:Output lambda_17 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_17. WARNING:tensorflow:Output dense_20 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_20. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 1.033 (standardized score: 0.766), MAE: 1.164 ## exponential model MNLL: 0.565 (standardized score: 0.298), MAE: 0.996 ## piecewise constant model MNLL: 0.840 (standardized score: 0.574), MAE: 1.070 ## neural network based model MNLL: 0.268 (standardized score: 0.002), MAE: 0.972
########## Hyper-parameters
time_step = 20  # truncation depth of the RNN
size_rnn = 64  # number of units in the RNN
size_div = 128  # number of sub-intervals of the piecewise constant function (piecewise constant model)
size_nn = 64  # number of units in each hidden layer of the cumulative hazard function network (neural network based model)
size_layer = 2  # number of hidden layers of the cumulative hazard function network (neural network based model)
########## Generate data
[T, score_ref] = generate_nonstationary_renewal()  # synthetic data: non-stationary renewal process
print('#######################################')
print('## Non-stationary Renewal process')
print('#######################################')
print()
########## Train and evaluate the models
# TensorFlow warns that some outputs are "missing from loss dictionary".
# This is expected: the loss is attached with model.add_loss, not per output.
# constant model
npp1 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='const').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# exponential model
npp2 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='exp').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# piecewise constant model
npp3 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='pc', size_div=size_div).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# neural network based model
npp4 = NPP(time_step=time_step, size_rnn=size_rnn, type_hazard='NN', size_layer=size_layer, size_nn=size_nn).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL minus the true model's reference score.
print('## constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp1.mnll, npp1.mnll - score_ref, npp1.mae))
print('## exponential model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp2.mnll, npp2.mnll - score_ref, npp2.mae))
print('## piecewise constant model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp3.mnll, npp3.mnll - score_ref, npp3.mae))
print('## neural network based model \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (npp4.mnll, npp4.mnll - score_ref, npp4.mae))
####################################### ## Non-tationary Renewal process ####################################### WARNING:tensorflow:Output subtract_12 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_12. WARNING:tensorflow:Output layer_ll_9 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_9. WARNING:tensorflow:Output layer_ll_9_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_9_1. WARNING:tensorflow:Output subtract_13 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_13. WARNING:tensorflow:Output layer_ll_10 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_10. WARNING:tensorflow:Output layer_ll_10_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_10_1. WARNING:tensorflow:Output subtract_14 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_14. WARNING:tensorflow:Output layer_ll_11 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_11. WARNING:tensorflow:Output layer_ll_11_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_11_1. WARNING:tensorflow:Output subtract_15 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_15. WARNING:tensorflow:Output lambda_23 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_23. WARNING:tensorflow:Output dense_27 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_27. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.708 (standardized score: 0.368), MAE: 0.441 ## exponential model MNLL: 0.695 (standardized score: 0.355), MAE: 0.436 ## piecewise constant model MNLL: 0.676 (standardized score: 0.336), MAE: 0.436 ## neural network based model MNLL: 0.367 (standardized score: 0.027), MAE: 0.403
########## Experiment: self-correcting process
# Generate synthetic event times, fit the four hazard-function models
# (constant, exponential, piecewise constant, neural network based), and
# report test-set MNLL / MAE. The "standardized score" is the model's MNLL
# minus score_ref, the average negative log-likelihood of the true
# generating model returned by the data generator (smaller is better).
########## Hyper-parameters
time_step = 20 # truncation depth of a RNN
size_rnn = 64 # the number of units in a RNN
size_div = 128 # the number of sub-intervals of the piecewise constant function (piecewise constant model)
size_nn = 64 # the number of units in each hidden layer of the cumulative hazard function network (neural network based model)
size_layer = 2 # the number of hidden layers of the cumulative hazard function network (neural network based model)
########## Generate data
[T,score_ref] = generate_self_correcting() # generate synthetic data: self-correcting process.
print('#######################################')
print('## Self-correcting process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# constant model
npp1 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='const').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# exponential model
npp2 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='exp').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# piecewise constant model
npp3 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='pc',size_div=size_div).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# neural network based model
npp4 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='NN',size_layer=size_layer,size_nn=size_nn).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# One report line per model; the format matches the original per-model prints.
for _label, _model in (
    ('constant model', npp1),
    ('exponential model', npp2),
    ('piecewise constant model', npp3),
    ('neural network based model', npp4),
):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f'
          % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Self-correcting process ####################################### WARNING:tensorflow:Output subtract_16 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_16. WARNING:tensorflow:Output layer_ll_12 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_12. WARNING:tensorflow:Output layer_ll_12_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_12_1. WARNING:tensorflow:Output subtract_17 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_17. WARNING:tensorflow:Output layer_ll_13 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_13. WARNING:tensorflow:Output layer_ll_13_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_13_1. WARNING:tensorflow:Output subtract_18 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_18. WARNING:tensorflow:Output layer_ll_14 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_14. WARNING:tensorflow:Output layer_ll_14_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_14_1. WARNING:tensorflow:Output subtract_19 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_19. WARNING:tensorflow:Output lambda_29 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_29. WARNING:tensorflow:Output dense_34 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_34. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.940 (standardized score: 0.180), MAE: 0.532 ## exponential model MNLL: 0.788 (standardized score: 0.029), MAE: 0.498 ## piecewise constant model MNLL: 0.798 (standardized score: 0.038), MAE: 0.498 ## neural network based model MNLL: 0.791 (standardized score: 0.032), MAE: 0.499
########## Experiment: Hawkes1 process
# Generate synthetic event times, fit the four hazard-function models
# (constant, exponential, piecewise constant, neural network based), and
# report test-set MNLL / MAE. The "standardized score" is the model's MNLL
# minus score_ref, the average negative log-likelihood of the true
# generating model returned by the data generator (smaller is better).
########## Hyper-parameters
time_step = 20 # truncation depth of a RNN
size_rnn = 64 # the number of units in a RNN
size_div = 128 # the number of sub-intervals of the piecewise constant function (piecewise constant model)
size_nn = 64 # the number of units in each hidden layer of the cumulative hazard function network (neural network based model)
size_layer = 2 # the number of hidden layers of the cumulative hazard function network (neural network based model)
########## Generate data
[T,score_ref] = generate_hawkes1() # generate synthetic data: hawkes1 process.
print('#######################################')
print('## Hawkes1 process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# constant model
npp1 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='const').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# exponential model
npp2 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='exp').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# piecewise constant model
npp3 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='pc',size_div=size_div).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# neural network based model
npp4 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='NN',size_layer=size_layer,size_nn=size_nn).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# One report line per model; the format matches the original per-model prints.
for _label, _model in (
    ('constant model', npp1),
    ('exponential model', npp2),
    ('piecewise constant model', npp3),
    ('neural network based model', npp4),
):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f'
          % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Hawkes1 process ####################################### WARNING:tensorflow:Output subtract_20 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_20. WARNING:tensorflow:Output layer_ll_15 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_15. WARNING:tensorflow:Output layer_ll_15_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_15_1. WARNING:tensorflow:Output subtract_21 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_21. WARNING:tensorflow:Output layer_ll_16 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_16. WARNING:tensorflow:Output layer_ll_16_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_16_1. WARNING:tensorflow:Output subtract_22 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_22. WARNING:tensorflow:Output layer_ll_17 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_17. WARNING:tensorflow:Output layer_ll_17_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_17_1. WARNING:tensorflow:Output subtract_23 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_23. WARNING:tensorflow:Output lambda_35 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_35. WARNING:tensorflow:Output dense_41 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_41. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.673 (standardized score: 0.213), MAE: 0.922 ## exponential model MNLL: 0.533 (standardized score: 0.073), MAE: 0.842 ## piecewise constant model MNLL: 0.473 (standardized score: 0.013), MAE: 0.837 ## neural network based model MNLL: 0.480 (standardized score: 0.020), MAE: 0.837
########## Experiment: Hawkes2 process
# Generate synthetic event times, fit the four hazard-function models
# (constant, exponential, piecewise constant, neural network based), and
# report test-set MNLL / MAE. The "standardized score" is the model's MNLL
# minus score_ref, the average negative log-likelihood of the true
# generating model returned by the data generator (smaller is better).
########## Hyper-parameters
time_step = 20 # truncation depth of a RNN
size_rnn = 64 # the number of units in a RNN
size_div = 128 # the number of sub-intervals of the piecewise constant function (piecewise constant model)
size_nn = 64 # the number of units in each hidden layer of the cumulative hazard function network (neural network based model)
size_layer = 2 # the number of hidden layers of the cumulative hazard function network (neural network based model)
########## Generate data
[T,score_ref] = generate_hawkes2() # generate synthetic data: hawkes2 process.
print('#######################################')
print('## Hawkes2 process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# constant model
npp1 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='const').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# exponential model
npp2 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='exp').set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# piecewise constant model
npp3 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='pc',size_div=size_div).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
# neural network based model
npp4 = NPP(time_step=time_step,size_rnn=size_rnn,type_hazard='NN',size_layer=size_layer,size_nn=size_nn).set_data(T).set_model().compile(lr=0.001).fit_eval(batch_size=256)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# One report line per model; the format matches the original per-model prints.
for _label, _model in (
    ('constant model', npp1),
    ('exponential model', npp2),
    ('piecewise constant model', npp3),
    ('neural network based model', npp4),
):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f'
          % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Hawkes2 process ####################################### WARNING:tensorflow:Output subtract_24 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_24. WARNING:tensorflow:Output layer_ll_18 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_18. WARNING:tensorflow:Output layer_ll_18_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_18_1. WARNING:tensorflow:Output subtract_25 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_25. WARNING:tensorflow:Output layer_ll_19 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_19. WARNING:tensorflow:Output layer_ll_19_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_19_1. WARNING:tensorflow:Output subtract_26 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_26. WARNING:tensorflow:Output layer_ll_20 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_20. WARNING:tensorflow:Output layer_ll_20_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_20_1. WARNING:tensorflow:Output subtract_27 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_27. WARNING:tensorflow:Output lambda_41 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_41. WARNING:tensorflow:Output dense_48 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_48. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.672 (standardized score: 0.671), MAE: 1.132 ## exponential model MNLL: 0.414 (standardized score: 0.413), MAE: 1.017 ## piecewise constant model MNLL: 0.206 (standardized score: 0.205), MAE: 1.003 ## neural network based model MNLL: 0.019 (standardized score: 0.018), MAE: 0.999
In the following experiments, the Python class "training" selects the RNN truncation depth (`time_step`) by cross-validation instead of fixing it by hand.
########## Experiment: stationary Poisson process (time_step chosen by cross-validation)
# Generate synthetic event times, fit the four hazard-function models via
# training().training(type_hazard, T), and report test-set MNLL / MAE.
# The "standardized score" is the model's MNLL minus score_ref, the average
# negative log-likelihood of the true generating model (smaller is better).
########## Generate data
[T,score_ref] = generate_stationary_poisson() # generate synthetic data: stationary Poisson process.
print('#######################################')
print('## Stationary Poisson process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# constant model
npp1 = training().training('const',T)
# exponential model
npp2 = training().training('exp',T)
# piecewise constant model
npp3 = training().training('pc',T)
# neural network based model
npp4 = training().training('NN',T)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# One report line per model; the format matches the original per-model prints.
for _label, _model in (
    ('constant model', npp1),
    ('exponential model', npp2),
    ('piecewise constant model', npp3),
    ('neural network based model', npp4),
):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f'
          % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Stationary Poisson process ####################################### WARNING:tensorflow:Output subtract_24 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_24. WARNING:tensorflow:Output layer_ll_18 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_18. WARNING:tensorflow:Output layer_ll_18_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_18_1. WARNING:tensorflow:Output subtract_25 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_25. WARNING:tensorflow:Output layer_ll_19 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_19. WARNING:tensorflow:Output layer_ll_19_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_19_1. WARNING:tensorflow:Output subtract_26 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_26. WARNING:tensorflow:Output layer_ll_20 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_20. WARNING:tensorflow:Output layer_ll_20_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_20_1. WARNING:tensorflow:Output subtract_27 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_27. WARNING:tensorflow:Output layer_ll_21 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_21. WARNING:tensorflow:Output layer_ll_21_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_21_1. WARNING:tensorflow:Output subtract_28 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_28. WARNING:tensorflow:Output layer_ll_22 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_22. WARNING:tensorflow:Output layer_ll_22_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_22_1. WARNING:tensorflow:Output subtract_29 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_29. WARNING:tensorflow:Output layer_ll_23 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_23. WARNING:tensorflow:Output layer_ll_23_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_23_1. WARNING:tensorflow:Output subtract_30 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_30. WARNING:tensorflow:Output layer_ll_24 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_24. 
WARNING:tensorflow:Output layer_ll_24_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_24_1. WARNING:tensorflow:Output subtract_31 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_31. WARNING:tensorflow:Output layer_ll_25 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_25. WARNING:tensorflow:Output layer_ll_25_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_25_1. WARNING:tensorflow:Output subtract_32 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_32. WARNING:tensorflow:Output layer_ll_26 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_26. WARNING:tensorflow:Output layer_ll_26_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_26_1. WARNING:tensorflow:Output subtract_33 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_33. WARNING:tensorflow:Output layer_ll_27 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_27. WARNING:tensorflow:Output layer_ll_27_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_27_1. WARNING:tensorflow:Output subtract_34 missing from loss dictionary. 
We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_34. WARNING:tensorflow:Output layer_ll_28 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_28. WARNING:tensorflow:Output layer_ll_28_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_28_1. WARNING:tensorflow:Output subtract_35 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_35. WARNING:tensorflow:Output layer_ll_29 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_29. WARNING:tensorflow:Output layer_ll_29_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_29_1. WARNING:tensorflow:Output subtract_36 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_36. WARNING:tensorflow:Output lambda_50 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_50. WARNING:tensorflow:Output dense_57 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_57. WARNING:tensorflow:Output subtract_37 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_37. WARNING:tensorflow:Output lambda_53 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to lambda_53. WARNING:tensorflow:Output dense_61 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_61. WARNING:tensorflow:Output subtract_38 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_38. WARNING:tensorflow:Output lambda_56 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_56. WARNING:tensorflow:Output dense_65 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_65. WARNING:tensorflow:Output subtract_39 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_39. WARNING:tensorflow:Output lambda_59 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_59. WARNING:tensorflow:Output dense_69 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_69. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 1.009 (standardized score: 0.009), MAE: 0.703 ## exponential model MNLL: 1.009 (standardized score: 0.009), MAE: 0.702 ## piecewise constant model MNLL: 1.012 (standardized score: 0.012), MAE: 0.703 ## neural network based model MNLL: 1.011 (standardized score: 0.011), MAE: 0.703
########## Experiment: non-stationary Poisson process (time_step chosen by cross-validation)
# Generate synthetic event times, fit the four hazard-function models via
# training().training(type_hazard, T), and report test-set MNLL / MAE.
# The "standardized score" is the model's MNLL minus score_ref, the average
# negative log-likelihood of the true generating model (smaller is better).
########## Generate data
[T,score_ref] = generate_nonstationary_poisson() # generate synthetic data: nonstationary Poisson process.
print('#######################################')
print('## Non-stationary Poisson process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# constant model
npp1 = training().training('const',T)
# exponential model
npp2 = training().training('exp',T)
# piecewise constant model
npp3 = training().training('pc',T)
# neural network based model
npp4 = training().training('NN',T)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# One report line per model; the format matches the original per-model prints.
for _label, _model in (
    ('constant model', npp1),
    ('exponential model', npp2),
    ('piecewise constant model', npp3),
    ('neural network based model', npp4),
):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f'
          % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Non-stationary Poisson process ####################################### WARNING:tensorflow:Output subtract_40 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_40. WARNING:tensorflow:Output layer_ll_30 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_30. WARNING:tensorflow:Output layer_ll_30_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_30_1. WARNING:tensorflow:Output subtract_41 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_41. WARNING:tensorflow:Output layer_ll_31 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_31. WARNING:tensorflow:Output layer_ll_31_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_31_1. WARNING:tensorflow:Output subtract_42 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_42. WARNING:tensorflow:Output layer_ll_32 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_32. WARNING:tensorflow:Output layer_ll_32_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_32_1. WARNING:tensorflow:Output subtract_43 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_43. WARNING:tensorflow:Output layer_ll_33 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_33. WARNING:tensorflow:Output layer_ll_33_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_33_1. WARNING:tensorflow:Output subtract_44 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_44. WARNING:tensorflow:Output layer_ll_34 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_34. WARNING:tensorflow:Output layer_ll_34_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_34_1. WARNING:tensorflow:Output subtract_45 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_45. WARNING:tensorflow:Output layer_ll_35 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_35. WARNING:tensorflow:Output layer_ll_35_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_35_1. WARNING:tensorflow:Output subtract_46 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_46. WARNING:tensorflow:Output layer_ll_36 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_36. 
WARNING:tensorflow:Output layer_ll_36_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_36_1. WARNING:tensorflow:Output subtract_47 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_47. WARNING:tensorflow:Output layer_ll_37 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_37. WARNING:tensorflow:Output layer_ll_37_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_37_1. WARNING:tensorflow:Output subtract_48 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_48. WARNING:tensorflow:Output layer_ll_38 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_38. WARNING:tensorflow:Output layer_ll_38_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_38_1. WARNING:tensorflow:Output subtract_49 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_49. WARNING:tensorflow:Output layer_ll_39 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_39. WARNING:tensorflow:Output layer_ll_39_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_39_1. WARNING:tensorflow:Output subtract_50 missing from loss dictionary. 
We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_50. WARNING:tensorflow:Output layer_ll_40 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_40. WARNING:tensorflow:Output layer_ll_40_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_40_1. WARNING:tensorflow:Output subtract_51 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_51. WARNING:tensorflow:Output layer_ll_41 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_41. WARNING:tensorflow:Output layer_ll_41_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_41_1. WARNING:tensorflow:Output subtract_52 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_52. WARNING:tensorflow:Output lambda_74 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_74. WARNING:tensorflow:Output dense_85 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_85. WARNING:tensorflow:Output subtract_53 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_53. WARNING:tensorflow:Output lambda_77 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to lambda_77. WARNING:tensorflow:Output dense_89 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_89. WARNING:tensorflow:Output subtract_54 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_54. WARNING:tensorflow:Output lambda_80 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_80. WARNING:tensorflow:Output dense_93 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_93. WARNING:tensorflow:Output subtract_55 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_55. WARNING:tensorflow:Output lambda_83 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_83. WARNING:tensorflow:Output dense_97 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_97. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.719 (standardized score: 0.016), MAE: 0.710 ## exponential model MNLL: 0.719 (standardized score: 0.016), MAE: 0.713 ## piecewise constant model MNLL: 0.718 (standardized score: 0.015), MAE: 0.711 ## neural network based model MNLL: 0.732 (standardized score: 0.030), MAE: 0.710
########## Generate data
[T,score_ref] = generate_stationary_renewal() # generate synthetic data: stationary renewal process.
# score_ref: average negative log-likelihood of the true generating model, used as the baseline below.
print('#######################################')
print('## Stationary Renewal process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# The warnings are Keras "Output ... missing from loss dictionary" messages for the model's auxiliary outputs.
# npp1..npp4 are kept as separate names so later cells can still refer to the fitted models.
npp1 = training().training('const',T)  # constant model
npp2 = training().training('exp',T)    # exponential model
npp3 = training().training('pc',T)     # piecewise constant model
npp4 = training().training('NN',T)     # neural network based model
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL - score_ref, i.e. the excess NLL over the true model
# (lower is better; values near 0 mean the fit is about as good as the true model).
for _label, _model in (('constant model', npp1),
                       ('exponential model', npp2),
                       ('piecewise constant model', npp3),
                       ('neural network based model', npp4)):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Stationary Renewal process ####################################### WARNING:tensorflow:Output subtract_56 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_56. WARNING:tensorflow:Output layer_ll_42 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_42. WARNING:tensorflow:Output layer_ll_42_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_42_1. WARNING:tensorflow:Output subtract_57 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_57. WARNING:tensorflow:Output layer_ll_43 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_43. WARNING:tensorflow:Output layer_ll_43_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_43_1. WARNING:tensorflow:Output subtract_58 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_58. WARNING:tensorflow:Output layer_ll_44 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_44. WARNING:tensorflow:Output layer_ll_44_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_44_1. WARNING:tensorflow:Output subtract_59 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_59. WARNING:tensorflow:Output layer_ll_45 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_45. WARNING:tensorflow:Output layer_ll_45_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_45_1. WARNING:tensorflow:Output subtract_60 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_60. WARNING:tensorflow:Output layer_ll_46 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_46. WARNING:tensorflow:Output layer_ll_46_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_46_1. WARNING:tensorflow:Output subtract_61 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_61. WARNING:tensorflow:Output layer_ll_47 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_47. WARNING:tensorflow:Output layer_ll_47_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_47_1. WARNING:tensorflow:Output subtract_62 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_62. WARNING:tensorflow:Output layer_ll_48 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_48. 
WARNING:tensorflow:Output layer_ll_48_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_48_1. WARNING:tensorflow:Output subtract_63 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_63. WARNING:tensorflow:Output layer_ll_49 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_49. WARNING:tensorflow:Output layer_ll_49_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_49_1. WARNING:tensorflow:Output subtract_64 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_64. WARNING:tensorflow:Output layer_ll_50 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_50. WARNING:tensorflow:Output layer_ll_50_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_50_1. WARNING:tensorflow:Output subtract_65 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_65. WARNING:tensorflow:Output layer_ll_51 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_51. WARNING:tensorflow:Output layer_ll_51_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_51_1. WARNING:tensorflow:Output subtract_66 missing from loss dictionary. 
We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_66. WARNING:tensorflow:Output layer_ll_52 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_52. WARNING:tensorflow:Output layer_ll_52_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_52_1. WARNING:tensorflow:Output subtract_67 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_67. WARNING:tensorflow:Output layer_ll_53 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_53. WARNING:tensorflow:Output layer_ll_53_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_53_1. WARNING:tensorflow:Output subtract_68 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_68. WARNING:tensorflow:Output lambda_98 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_98. WARNING:tensorflow:Output dense_113 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_113. WARNING:tensorflow:Output subtract_69 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_69. WARNING:tensorflow:Output lambda_101 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to lambda_101. WARNING:tensorflow:Output dense_117 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_117. WARNING:tensorflow:Output subtract_70 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_70. WARNING:tensorflow:Output lambda_104 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_104. WARNING:tensorflow:Output dense_121 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_121. WARNING:tensorflow:Output subtract_71 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_71. WARNING:tensorflow:Output lambda_107 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_107. WARNING:tensorflow:Output dense_125 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_125. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.992 (standardized score: 0.737), MAE: 1.125 ## exponential model MNLL: 0.544 (standardized score: 0.289), MAE: 0.958 ## piecewise constant model MNLL: 0.903 (standardized score: 0.648), MAE: 1.077 ## neural network based model MNLL: 0.256 (standardized score: 0.001), MAE: 0.932
########## Generate data
[T,score_ref] = generate_nonstationary_renewal() # generate synthetic data: nonstationary renewal process.
# score_ref: average negative log-likelihood of the true generating model, used as the baseline below.
print('#######################################')
print('## Non-stationary Renewal process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# The warnings are Keras "Output ... missing from loss dictionary" messages for the model's auxiliary outputs.
# npp1..npp4 are kept as separate names so later cells can still refer to the fitted models.
npp1 = training().training('const',T)  # constant model
npp2 = training().training('exp',T)    # exponential model
npp3 = training().training('pc',T)     # piecewise constant model
npp4 = training().training('NN',T)     # neural network based model
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL - score_ref, i.e. the excess NLL over the true model
# (lower is better; values near 0 mean the fit is about as good as the true model).
for _label, _model in (('constant model', npp1),
                       ('exponential model', npp2),
                       ('piecewise constant model', npp3),
                       ('neural network based model', npp4)):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Non-stationary Renewal process ####################################### WARNING:tensorflow:Output subtract_72 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_72. WARNING:tensorflow:Output layer_ll_54 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_54. WARNING:tensorflow:Output layer_ll_54_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_54_1. WARNING:tensorflow:Output subtract_73 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_73. WARNING:tensorflow:Output layer_ll_55 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_55. WARNING:tensorflow:Output layer_ll_55_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_55_1. WARNING:tensorflow:Output subtract_74 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_74. WARNING:tensorflow:Output layer_ll_56 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_56. WARNING:tensorflow:Output layer_ll_56_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_56_1. WARNING:tensorflow:Output subtract_75 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_75. WARNING:tensorflow:Output layer_ll_57 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_57. WARNING:tensorflow:Output layer_ll_57_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_57_1. WARNING:tensorflow:Output subtract_76 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_76. WARNING:tensorflow:Output layer_ll_58 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_58. WARNING:tensorflow:Output layer_ll_58_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_58_1. WARNING:tensorflow:Output subtract_77 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_77. WARNING:tensorflow:Output layer_ll_59 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_59. WARNING:tensorflow:Output layer_ll_59_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_59_1. WARNING:tensorflow:Output subtract_78 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_78. WARNING:tensorflow:Output layer_ll_60 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_60. 
WARNING:tensorflow:Output layer_ll_60_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_60_1. WARNING:tensorflow:Output subtract_79 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_79. WARNING:tensorflow:Output layer_ll_61 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_61. WARNING:tensorflow:Output layer_ll_61_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_61_1. WARNING:tensorflow:Output subtract_80 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_80. WARNING:tensorflow:Output layer_ll_62 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_62. WARNING:tensorflow:Output layer_ll_62_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_62_1. WARNING:tensorflow:Output subtract_81 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_81. WARNING:tensorflow:Output layer_ll_63 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_63. WARNING:tensorflow:Output layer_ll_63_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_63_1. WARNING:tensorflow:Output subtract_82 missing from loss dictionary. 
We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_82. WARNING:tensorflow:Output layer_ll_64 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_64. WARNING:tensorflow:Output layer_ll_64_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_64_1. WARNING:tensorflow:Output subtract_83 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_83. WARNING:tensorflow:Output layer_ll_65 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_65. WARNING:tensorflow:Output layer_ll_65_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_65_1. WARNING:tensorflow:Output subtract_84 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_84. WARNING:tensorflow:Output lambda_122 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_122. WARNING:tensorflow:Output dense_141 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_141. WARNING:tensorflow:Output subtract_85 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_85. WARNING:tensorflow:Output lambda_125 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to lambda_125. WARNING:tensorflow:Output dense_145 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_145. WARNING:tensorflow:Output subtract_86 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_86. WARNING:tensorflow:Output lambda_128 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_128. WARNING:tensorflow:Output dense_149 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_149. WARNING:tensorflow:Output subtract_87 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_87. WARNING:tensorflow:Output lambda_131 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_131. WARNING:tensorflow:Output dense_153 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_153. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.705 (standardized score: 0.375), MAE: 0.431 ## exponential model MNLL: 0.699 (standardized score: 0.369), MAE: 0.436 ## piecewise constant model MNLL: 0.683 (standardized score: 0.353), MAE: 0.422 ## neural network based model MNLL: 0.352 (standardized score: 0.021), MAE: 0.392
########## Generate data
[T,score_ref] = generate_self_correcting() # generate synthetic data: self-correcting process.
# score_ref: average negative log-likelihood of the true generating model, used as the baseline below.
print('#######################################')
print('## Self-correcting process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# The warnings are Keras "Output ... missing from loss dictionary" messages for the model's auxiliary outputs.
# npp1..npp4 are kept as separate names so later cells can still refer to the fitted models.
npp1 = training().training('const',T)  # constant model
npp2 = training().training('exp',T)    # exponential model
npp3 = training().training('pc',T)     # piecewise constant model
npp4 = training().training('NN',T)     # neural network based model
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL - score_ref, i.e. the excess NLL over the true model
# (lower is better; values near 0 mean the fit is about as good as the true model).
for _label, _model in (('constant model', npp1),
                       ('exponential model', npp2),
                       ('piecewise constant model', npp3),
                       ('neural network based model', npp4)):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f' % (_label, _model.mnll, _model.mnll - score_ref, _model.mae))
####################################### ## Self-correcting process ####################################### WARNING:tensorflow:Output subtract_88 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_88. WARNING:tensorflow:Output layer_ll_66 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_66. WARNING:tensorflow:Output layer_ll_66_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_66_1. WARNING:tensorflow:Output subtract_89 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_89. WARNING:tensorflow:Output layer_ll_67 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_67. WARNING:tensorflow:Output layer_ll_67_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_67_1. WARNING:tensorflow:Output subtract_90 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_90. WARNING:tensorflow:Output layer_ll_68 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_68. WARNING:tensorflow:Output layer_ll_68_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_68_1. WARNING:tensorflow:Output subtract_91 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_91. WARNING:tensorflow:Output layer_ll_69 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_69. WARNING:tensorflow:Output layer_ll_69_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_69_1. WARNING:tensorflow:Output subtract_92 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_92. WARNING:tensorflow:Output layer_ll_70 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_70. WARNING:tensorflow:Output layer_ll_70_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_70_1. WARNING:tensorflow:Output subtract_93 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_93. WARNING:tensorflow:Output layer_ll_71 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_71. WARNING:tensorflow:Output layer_ll_71_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_71_1. WARNING:tensorflow:Output subtract_94 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_94. WARNING:tensorflow:Output layer_ll_72 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_72. 
WARNING:tensorflow:Output layer_ll_72_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_72_1. WARNING:tensorflow:Output subtract_95 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_95. WARNING:tensorflow:Output layer_ll_73 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_73. WARNING:tensorflow:Output layer_ll_73_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_73_1. WARNING:tensorflow:Output subtract_96 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_96. WARNING:tensorflow:Output layer_ll_74 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_74. WARNING:tensorflow:Output layer_ll_74_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_74_1. WARNING:tensorflow:Output subtract_97 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_97. WARNING:tensorflow:Output layer_ll_75 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_75. WARNING:tensorflow:Output layer_ll_75_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_75_1. WARNING:tensorflow:Output subtract_98 missing from loss dictionary. 
We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_98. WARNING:tensorflow:Output layer_ll_76 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_76. WARNING:tensorflow:Output layer_ll_76_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_76_1. WARNING:tensorflow:Output subtract_99 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_99. WARNING:tensorflow:Output layer_ll_77 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_77. WARNING:tensorflow:Output layer_ll_77_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_77_1. WARNING:tensorflow:Output subtract_100 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_100. WARNING:tensorflow:Output lambda_146 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_146. WARNING:tensorflow:Output dense_169 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_169. WARNING:tensorflow:Output subtract_101 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_101. WARNING:tensorflow:Output lambda_149 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to lambda_149. WARNING:tensorflow:Output dense_173 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_173. WARNING:tensorflow:Output subtract_102 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_102. WARNING:tensorflow:Output lambda_152 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_152. WARNING:tensorflow:Output dense_177 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_177. WARNING:tensorflow:Output subtract_103 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_103. WARNING:tensorflow:Output lambda_155 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_155. WARNING:tensorflow:Output dense_181 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_181. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.936 (standardized score: 0.175), MAE: 0.538 ## exponential model MNLL: 0.789 (standardized score: 0.028), MAE: 0.499 ## piecewise constant model MNLL: 0.803 (standardized score: 0.042), MAE: 0.502 ## neural network based model MNLL: 0.796 (standardized score: 0.035), MAE: 0.502
########## Generate data
# Like the other generate_* functions in this file, generate_hawkes1() returns
# the event times and the average negative log-likelihood of the TRUE model;
# that value is the reference used to standardize each fitted model's MNLL.
T, score_ref = generate_hawkes1()  # generate synthetic data: hawkes1 process.
print('#######################################')
print('## Hawkes1 process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# training() is defined earlier in this file; each call presumably fits the named
# model variant on T and exposes test metrics as .mnll and .mae (read below).
# npp1..npp4 are kept as module-level names so later cells can still use them.
# constant model
npp1 = training().training('const', T)
# exponential model
npp2 = training().training('exp', T)
# piecewise constant model
npp3 = training().training('pc', T)
# neural network based model
npp4 = training().training('NN', T)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL - MNLL of the true process (score_ref);
# values close to 0 mean the model fits about as well as the ground truth.
for _label, _npp in (
    ('constant model', npp1),
    ('exponential model', npp2),
    ('piecewise constant model', npp3),
    ('neural network based model', npp4),
):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f'
          % (_label, _npp.mnll, _npp.mnll - score_ref, _npp.mae))
####################################### ## Hawkes1 process ####################################### WARNING:tensorflow:Output subtract_104 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_104. WARNING:tensorflow:Output layer_ll_78 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_78. WARNING:tensorflow:Output layer_ll_78_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_78_1. WARNING:tensorflow:Output subtract_105 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_105. WARNING:tensorflow:Output layer_ll_79 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_79. WARNING:tensorflow:Output layer_ll_79_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_79_1. WARNING:tensorflow:Output subtract_106 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_106. WARNING:tensorflow:Output layer_ll_80 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_80. WARNING:tensorflow:Output layer_ll_80_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_80_1. WARNING:tensorflow:Output subtract_107 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_107. WARNING:tensorflow:Output layer_ll_81 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_81. WARNING:tensorflow:Output layer_ll_81_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_81_1. WARNING:tensorflow:Output subtract_108 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_108. WARNING:tensorflow:Output layer_ll_82 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_82. WARNING:tensorflow:Output layer_ll_82_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_82_1. WARNING:tensorflow:Output subtract_109 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_109. WARNING:tensorflow:Output layer_ll_83 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_83. WARNING:tensorflow:Output layer_ll_83_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_83_1. WARNING:tensorflow:Output subtract_110 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_110. WARNING:tensorflow:Output layer_ll_84 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_84. WARNING:tensorflow:Output layer_ll_84_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_84_1. WARNING:tensorflow:Output subtract_111 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_111. WARNING:tensorflow:Output layer_ll_85 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_85. WARNING:tensorflow:Output layer_ll_85_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_85_1. WARNING:tensorflow:Output subtract_112 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_112. WARNING:tensorflow:Output layer_ll_86 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_86. WARNING:tensorflow:Output layer_ll_86_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_86_1. WARNING:tensorflow:Output subtract_113 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_113. WARNING:tensorflow:Output layer_ll_87 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_87. WARNING:tensorflow:Output layer_ll_87_1 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_87_1. WARNING:tensorflow:Output subtract_114 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_114. WARNING:tensorflow:Output layer_ll_88 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_88. WARNING:tensorflow:Output layer_ll_88_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_88_1. WARNING:tensorflow:Output subtract_115 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_115. WARNING:tensorflow:Output layer_ll_89 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_89. WARNING:tensorflow:Output layer_ll_89_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_89_1. WARNING:tensorflow:Output subtract_116 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_116. WARNING:tensorflow:Output lambda_170 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_170. WARNING:tensorflow:Output dense_197 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_197. WARNING:tensorflow:Output subtract_117 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_117. 
WARNING:tensorflow:Output lambda_173 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_173. WARNING:tensorflow:Output dense_201 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_201. WARNING:tensorflow:Output subtract_118 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_118. WARNING:tensorflow:Output lambda_176 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_176. WARNING:tensorflow:Output dense_205 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_205. WARNING:tensorflow:Output subtract_119 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_119. WARNING:tensorflow:Output lambda_179 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_179. WARNING:tensorflow:Output dense_209 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_209. ####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.762 (standardized score: 0.234), MAE: 0.989 ## exponential model MNLL: 0.610 (standardized score: 0.082), MAE: 0.918 ## piecewise constant model MNLL: 0.544 (standardized score: 0.016), MAE: 0.915 ## neural network based model MNLL: 0.537 (standardized score: 0.009), MAE: 0.913
########## Generate data
# Like the other generate_* functions in this file, generate_hawkes2() returns
# the event times and the average negative log-likelihood of the TRUE model;
# that value is the reference used to standardize each fitted model's MNLL.
T, score_ref = generate_hawkes2()  # generate synthetic data: hawkes2 process.
print('#######################################')
print('## Hawkes2 process')
print('#######################################')
print()
########## Train and evaluate the model (The following code will raise warnings, but please ignore them.)
# training() is defined earlier in this file; each call presumably fits the named
# model variant on T and exposes test metrics as .mnll and .mae (read below).
# npp1..npp4 are kept as module-level names so later cells can still use them.
# constant model
npp1 = training().training('const', T)
# exponential model
npp2 = training().training('exp', T)
# piecewise constant model
npp3 = training().training('pc', T)
# neural network based model
npp4 = training().training('NN', T)
print()
print('#######################################')
print('## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error)')
print('#######################################')
# Standardized score = model MNLL - MNLL of the true process (score_ref);
# values close to 0 mean the model fits about as well as the ground truth.
for _label, _npp in (
    ('constant model', npp1),
    ('exponential model', npp2),
    ('piecewise constant model', npp3),
    ('neural network based model', npp4),
):
    print('## %s \n MNLL: %.3f (standardized score: %.3f), MAE: %.3f'
          % (_label, _npp.mnll, _npp.mnll - score_ref, _npp.mae))
####################################### ## Hawkes2 process ####################################### WARNING:tensorflow:Output subtract_120 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_120. WARNING:tensorflow:Output layer_ll_90 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_90. WARNING:tensorflow:Output layer_ll_90_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_90_1. WARNING:tensorflow:Output subtract_121 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_121. WARNING:tensorflow:Output layer_ll_91 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_91. WARNING:tensorflow:Output layer_ll_91_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_91_1. WARNING:tensorflow:Output subtract_122 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_122. WARNING:tensorflow:Output layer_ll_92 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_92. WARNING:tensorflow:Output layer_ll_92_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_92_1. WARNING:tensorflow:Output subtract_123 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_123. WARNING:tensorflow:Output layer_ll_93 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_93. WARNING:tensorflow:Output layer_ll_93_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_93_1. WARNING:tensorflow:Output subtract_124 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_124. WARNING:tensorflow:Output layer_ll_94 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_94. WARNING:tensorflow:Output layer_ll_94_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_94_1. WARNING:tensorflow:Output subtract_125 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_125. WARNING:tensorflow:Output layer_ll_95 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_95. WARNING:tensorflow:Output layer_ll_95_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_95_1. WARNING:tensorflow:Output subtract_126 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_126. WARNING:tensorflow:Output layer_ll_96 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_96. WARNING:tensorflow:Output layer_ll_96_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_96_1. WARNING:tensorflow:Output subtract_127 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_127. WARNING:tensorflow:Output layer_ll_97 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_97. WARNING:tensorflow:Output layer_ll_97_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_97_1. WARNING:tensorflow:Output subtract_128 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_128. WARNING:tensorflow:Output layer_ll_98 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_98. WARNING:tensorflow:Output layer_ll_98_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_98_1. WARNING:tensorflow:Output subtract_129 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_129. WARNING:tensorflow:Output layer_ll_99 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_99. WARNING:tensorflow:Output layer_ll_99_1 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_99_1. WARNING:tensorflow:Output subtract_130 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_130. WARNING:tensorflow:Output layer_ll_100 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_100. WARNING:tensorflow:Output layer_ll_100_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_100_1. WARNING:tensorflow:Output subtract_131 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_131. WARNING:tensorflow:Output layer_ll_101 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_101. WARNING:tensorflow:Output layer_ll_101_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to layer_ll_101_1. WARNING:tensorflow:Output subtract_132 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_132. WARNING:tensorflow:Output lambda_194 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_194. WARNING:tensorflow:Output dense_225 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_225. WARNING:tensorflow:Output subtract_133 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to subtract_133. WARNING:tensorflow:Output lambda_197 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_197. WARNING:tensorflow:Output dense_229 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_229. WARNING:tensorflow:Output subtract_134 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_134. WARNING:tensorflow:Output lambda_200 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_200. WARNING:tensorflow:Output dense_233 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_233. WARNING:tensorflow:Output subtract_135 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to subtract_135. WARNING:tensorflow:Output lambda_203 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to lambda_203. WARNING:tensorflow:Output dense_237 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense_237. 
####################################### ## Performance for test data (MNLL: mean negative log-likelihood, MAE: mean absolute error) ####################################### ## constand model MNLL: 0.701 (standardized score: 0.654), MAE: 1.135 ## exponential model MNLL: 0.438 (standardized score: 0.392), MAE: 1.010 ## piecewise constant model MNLL: 0.237 (standardized score: 0.190), MAE: 0.997 ## neural network based model MNLL: 0.061 (standardized score: 0.015), MAE: 0.993