This note describes how to implement the proposed model with Keras.
import numpy as np
## for neural network modeling
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
Code to generate the other synthetic sequences used in the paper is given in "code.ipynb".
## simulate a stationary Poisson process
T_train = np.random.exponential(size=80000).cumsum() # training data
T_test = np.random.exponential(size=20000).cumsum() # test data
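As an illustrative sketch only (this example is written here for convenience; it is not necessarily one of the sequences used in the paper, which are generated in "code.ipynb"), another synthetic sequence can be produced in the same way, e.g. a stationary renewal process with log-normally distributed inter-event intervals.
## simulate a stationary renewal process (hypothetical example)
dT_renewal = np.random.lognormal(mean=0.0, sigma=1.0, size=100000) # inter-event intervals
T_renewal = dT_renewal.cumsum() # timestamps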
The following Keras model receives the event history and the elapsed time from the most recent event, and outputs the hazard function and the cumulative hazard function. The hazard function is obtained by differentiating the cumulative hazard function with respect to the elapsed time.
## hyperparameters
time_step = 20 # truncation depth of RNN
size_rnn = 64 # the number of units in the RNN
size_nn = 64 # the number of units in each hidden layer of the cumulative hazard function network
size_layer_chfn = 2 # the number of hidden layers in the cumulative hazard function network
## mean and std of the log of the inter-event interval, which will be used for the data standardization
mu = np.log(np.ediff1d(T_train)).mean()
sigma = np.log(np.ediff1d(T_train)).std()
## kernel initializer for positive weights
def abs_glorot_uniform(shape, dtype=None, partition_info=None):
    return K.abs(keras.initializers.glorot_uniform(seed=None)(shape,dtype=dtype))
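## (note) the non-negative weight constraints used below keep the cumulative hazard function monotonically non-decreasing in the elapsed time, so that the hazard function obtained by differentiation is non-negative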
## Inputs
event_history = layers.Input(shape=(time_step,1)) # input to RNN (event history)
elapsed_time = layers.Input(shape=(1,)) # input to cumulative hazard function network (the elapsed time from the most recent event)
## log-transformation and standardization
event_history_nmlz = layers.Lambda(lambda x: (K.log(x)-mu)/sigma )(event_history)
elapsed_time_nmlz = layers.Lambda(lambda x: (K.log(x)-mu)/sigma )(elapsed_time)
## RNN
output_rnn = layers.SimpleRNN(size_rnn,input_shape=(time_step,1),activation='tanh')(event_history_nmlz)
## the first hidden layer in the cumulative hazard function network
hidden_tau = layers.Dense(size_nn,kernel_initializer=abs_glorot_uniform,kernel_constraint=keras.constraints.NonNeg(),use_bias=False)(elapsed_time_nmlz) # elapsed time -> the 1st hidden layer, positive weights
hidden_rnn = layers.Dense(size_nn)(output_rnn) # rnn output -> the 1st hidden layer
hidden = layers.Lambda(lambda inputs: K.tanh(inputs[0]+inputs[1]) )([hidden_tau,hidden_rnn])
## the second and higher hidden layers
for i in range(size_layer_chfn-1):
    hidden = layers.Dense(size_nn,activation='tanh',kernel_initializer=abs_glorot_uniform,kernel_constraint=keras.constraints.NonNeg())(hidden) # positive weights
## Outputs
Int_l = layers.Dense(1, activation='softplus',kernel_initializer=abs_glorot_uniform, kernel_constraint=keras.constraints.NonNeg() )(hidden) # cumulative hazard function, positive weights
l = layers.Lambda( lambda inputs: K.gradients(inputs[0],inputs[1])[0] )([Int_l,elapsed_time]) # hazard function
## define model
model = Model(inputs=[event_history,elapsed_time],outputs=[l,Int_l])
model.add_loss( -K.mean( K.log( 1e-10 + l ) - Int_l ) ) # set loss function to be the negative log-likelihood function
## format the input data
dT_train = np.ediff1d(T_train) # transform a series of timestamps to a series of interevent intervals: T_train -> dT_train
n = dT_train.shape[0]
input_RNN = np.array( [ dT_train[i:i+time_step] for i in range(n-time_step) ]).reshape(n-time_step,time_step,1)
input_CHFN = dT_train[-n+time_step:].reshape(n-time_step,1)
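A quick sanity check of the alignment (not in the original notebook): the i-th row of input_RNN holds the time_step intervals preceding the target interval input_CHFN[i].
## check the input alignment (illustrative)
assert input_RNN.shape == (n-time_step, time_step, 1)
assert input_CHFN.shape == (n-time_step, 1)
assert input_RNN[1,-1,0] == input_CHFN[0,0] # the last history interval of row i+1 equals the target interval of row i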
## training
model.compile(keras.optimizers.Adam(lr=0.001))
model.fit([input_RNN,input_CHFN],epochs=10,batch_size=256,validation_split=0.2) # In our study, we set epochs = 100 and employed early stopping (a sketch is given below the training log); please see code.ipynb for more details.
Train on 63983 samples, validate on 15996 samples
Epoch 1/10
63983/63983 [==============================] - 6s 86us/sample - loss: 1.3159 - val_loss: 1.0593
Epoch 2/10
63983/63983 [==============================] - 3s 54us/sample - loss: 1.0645 - val_loss: 1.0571
Epoch 3/10
63983/63983 [==============================] - 3s 53us/sample - loss: 1.0507 - val_loss: 1.0373
Epoch 4/10
63983/63983 [==============================] - 3s 54us/sample - loss: 1.0360 - val_loss: 1.0328
Epoch 5/10
63983/63983 [==============================] - 3s 52us/sample - loss: 1.0263 - val_loss: 1.0180
Epoch 6/10
63983/63983 [==============================] - 3s 52us/sample - loss: 1.0216 - val_loss: 1.0168
Epoch 7/10
63983/63983 [==============================] - 3s 53us/sample - loss: 1.0192 - val_loss: 1.0120
Epoch 8/10
63983/63983 [==============================] - 3s 52us/sample - loss: 1.0152 - val_loss: 1.0149
Epoch 9/10
63983/63983 [==============================] - 3s 53us/sample - loss: 1.0139 - val_loss: 1.0081
Epoch 10/10
63983/63983 [==============================] - 3s 54us/sample - loss: 1.0127 - val_loss: 1.0126
<tensorflow.python.keras.callbacks.History at 0x7f1ff043e5c0>
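For reference, the early stopping mentioned above can be added with a standard Keras callback. This is a minimal sketch and not necessarily the exact configuration used in code.ipynb; the patience value is an assumption.
## early stopping (sketch)
es = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True) # patience is an assumed value
model.fit([input_RNN,input_CHFN], epochs=100, batch_size=256, validation_split=0.2, callbacks=[es])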
Here we evaluate the performance of the trained model on the test data, using the mean negative log-likelihood (MNLL) per event.
## format the input data
dT_test = np.ediff1d(T_test) # transform a series of timestamps to a series of interevent intervals: T_test -> dT_test
n = dT_test.shape[0]
input_RNN_test = np.array( [ dT_test[i:i+time_step] for i in range(n-time_step) ]).reshape(n-time_step,time_step,1)
input_CHFN_test = dT_test[-n+time_step:].reshape(n-time_step,1)
## testing
[l_test,Int_l_test] = model.predict([input_RNN_test,input_CHFN_test],batch_size=input_RNN_test.shape[0])
LL = np.log(l_test+1e-10) - Int_l_test # log-likelihood
print("Mean negative log-likelihood per event: ",-LL.mean())
Mean negative log-likelihood per event: 1.0063448
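As a sanity check (using the fact that the data are simulated from a unit-rate Poisson process), the MNLL of the true model can be computed directly: the true hazard is 1 and the true cumulative hazard equals the elapsed time, so the optimal MNLL is the mean inter-event interval, which is close to 1 and consistent with the value above. This snippet is not in the original notebook.
## MNLL of the true (unit-rate Poisson) model, for comparison
LL_true = np.log(1.0) - input_CHFN_test # true hazard = 1, true cumulative hazard = elapsed time
print("MNLL of the true model: ", -LL_true.mean())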
Next, we predict the timing of the next event at each time step using the median of the predictive distribution, and evaluate the prediction in terms of the mean absolute error (MAE). Because the probability that no event has occurred by an elapsed time tau is exp(-(cumulative hazard at tau)), the median is the elapsed time at which the cumulative hazard equals log 2; we locate it with the bisection method.
# The median of the predictive distribution is determined using the bisection method.
x_left = 1e-4 * np.mean(dT_train) * np.ones_like(input_CHFN_test)
x_right = 100 * np.mean(dT_train) * np.ones_like(input_CHFN_test)
for i in range(13):
    x_center = (x_left+x_right)/2
    v = model.predict([input_RNN_test,x_center],batch_size=x_center.shape[0])[1]
    x_left = np.where(v<np.log(2),x_center,x_left)
    x_right = np.where(v>=np.log(2),x_center,x_right)
tau_pred = (x_left+x_right)/2 # predicted interevent interval
AE = np.abs(input_CHFN_test-tau_pred) # absolute error
print("Mean absolute error: ", AE.mean() )
Mean absolute error: 0.6963987826906138
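For comparison (again using the fact that the data come from a unit-rate Poisson process), the point prediction minimizing the expected absolute error under the true model is the true median inter-event interval, log 2 (approximately 0.693). The corresponding baseline MAE can be computed as follows; this snippet is not in the original notebook.
## MAE of the true median predictor, for comparison
AE_true = np.abs(input_CHFN_test - np.log(2))
print("MAE of the true median predictor: ", AE_true.mean())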