Numbers Summation (using RNN LSTM)

Experiment overview

In this experiment we will use a recurrent neural network (RNN) to sum up two numbers (each number is in the range [1, 100]). The summation expression (e.g. "1+45" or "37+68") that we feed to the RNN input is treated as a string (a sequence of characters), and the output of the RNN is also a string (e.g. "46" or "105"), a sequence of characters that represents the result of the summation. This is a "sequence-to-sequence" application of an RNN. We will use an LSTM (Long Short-Term Memory) recurrent neural network for this task.

For this experiment we will use TensorFlow 2 with its Keras API.

numbers_summation_rnn.png

Import dependencies

In [1]:
# Selecting Tensorflow version v2 (the command is relevant for Colab only).
%tensorflow_version 2.x
In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import platform
import random
import math
import datetime

print('Python version:', platform.python_version())
print('Tensorflow version:', tf.__version__)
print('Keras version:', tf.keras.__version__)
Python version: 3.7.6
Tensorflow version: 2.1.0
Keras version: 2.2.4-tf

Configuring Tensorboard

We will use TensorBoard to debug the model later.

In [3]:
# Load the TensorBoard notebook extension.
# %reload_ext tensorboard
%load_ext tensorboard

Generate a dataset

In order to train a neural network we need to generate a training dataset that consists of examples x (summation expressions) and labels y (the correct answer for each expression). We will start by generating the numbers themselves (not strings yet) and then we will convert them into strings.

In [4]:
dataset_size = 5000
sequence_length = 2
max_num = 100
In [5]:
# Generates summation sequences and summation results in the form of lists of numbers.
def generate_sums(dataset_size, sequence_length, max_num):
    # Initial dataset states.
    x, y = [], []
    
    # Generating sums.
    for i in range(dataset_size):
        sequence = [random.randint(1, max_num) for _ in range(sequence_length)]
        x.append(sequence)
        y.append(sum(sequence))
    
    return x, y
In [6]:
x_train, y_train = generate_sums(
    dataset_size=dataset_size,
    sequence_length=sequence_length,
    max_num=max_num
)

print('x_train:\n', x_train[:3])
print()
print('y_train:\n', y_train[:3])
x_train:
 [[8, 79], [36, 39], [47, 15]]

y_train:
 [87, 75, 62]
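Because `generate_sums` draws from the global `random` module without a seed, every call produces a different dataset, which makes runs hard to compare. Below is a minimal sketch of a reproducible variant, assuming you want repeatable experiments; the name `generate_sums_seeded` and its `seed` parameter are hypothetical additions, not part of the original notebook. Note that `randint` is inclusive on both ends, so with `max_num = 100` the value 100 itself can be drawn.

```python
import random


def generate_sums_seeded(dataset_size, sequence_length, max_num, seed=None):
    # Hypothetical helper: a private Random instance leaves the global `random` state untouched.
    rng = random.Random(seed)
    x, y = [], []
    for _ in range(dataset_size):
        # randint is inclusive on both ends: values fall in [1, max_num].
        sequence = [rng.randint(1, max_num) for _ in range(sequence_length)]
        x.append(sequence)
        y.append(sum(sequence))
    return x, y


x_a, y_a = generate_sums_seeded(5, 2, 100, seed=42)
x_b, y_b = generate_sums_seeded(5, 2, 100, seed=42)
print(x_a == x_b and y_a == y_b)  # True: the same seed gives the same dataset
```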

Recurrent neural networks normally deal with sequences (of one or many elements) as input. In order to treat a summation expression as a sequence, we will convert the arrays of numbers (e.g. [[13, 8], [85, 91], ...]) into strings (e.g. ['13+8   ', '85+91  ', ...]). We also want to pad each string with spaces (" ") so that all strings have the same length. After doing that we will be able to feed our RNN character by character (e.g. 1, then 3, then +, then 8, and so on).

In [7]:
# Convert array of numbers for x and y into strings.
# Also it adds a space (" ") padding to strings to make them of the same length. 
def dataset_to_strings(x, y, max_num):
    # Initial dataset states.
    x_str, y_str = [], []
    
    sequence_length = len(x[0])
    
    # Calculate the maximum length of an equation (x) string (e.g. of "100+100").
    num_of_pluses = sequence_length - 1
    num_of_chars_per_number = math.ceil(math.log10(max_num + 1))
    max_x_length = sequence_length * num_of_chars_per_number + num_of_pluses
    
    # Calculate the maximum length of a label (y) string (e.g. of "200").
    max_y_length = math.ceil(math.log10(sequence_length * (max_num + 1)))
    
    # Add a space " " padding to equation strings to make them of the same length.
    for example in x:
        str_example = '+'.join([str(number) for number in example])
        str_example += ' ' * (max_x_length - len(str_example))
        x_str.append(str_example)
    
    # Add a space " " padding to label strings to make them of the same length.
    for label in y:
        str_example = str(label)
        str_example += ' ' * (max_y_length - len(str_example))
        y_str.append(str_example)
    
    return x_str, y_str
In [8]:
x_train_str, y_train_str = dataset_to_strings(x_train, y_train, max_num)

print('x_train_str:\n', np.array(x_train_str[:3]))
print()
print('y_train_str:\n', np.array(y_train_str[:3]))
x_train_str:
 ['8+79   ' '36+39  ' '47+15  ']

y_train_str:
 ['87 ' '75 ' '62 ']
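The same padding can be sketched more compactly, assuming non-negative integer operands: `len(str(n))` gives the number of digits directly, and `str.ljust` pads on the right with spaces. The helper name `pad_expressions` is hypothetical. For `max_num = 100` it produces the same widths as `dataset_to_strings` (7 characters for expressions such as "100+100", 3 for labels such as "200").

```python
def pad_expressions(sequences, sums, max_num):
    # Hypothetical helper; assumes all sequences have the same length.
    sequence_length = len(sequences[0])
    # Widest expression: every operand has as many digits as max_num, plus the '+' signs.
    max_x_length = sequence_length * len(str(max_num)) + (sequence_length - 1)
    # Widest label: the digits of the largest possible sum.
    max_y_length = len(str(sequence_length * max_num))
    x_str = ['+'.join(str(n) for n in seq).ljust(max_x_length) for seq in sequences]
    y_str = [str(s).ljust(max_y_length) for s in sums]
    return x_str, y_str


x_str, y_str = pad_expressions([[8, 79], [100, 100]], [87, 200], max_num=100)
print(x_str)  # ['8+79   ', '100+100']
print(y_str)  # ['87 ', '200']
```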
In [9]:
# Since we only allow digits, the plus sign and spaces, the vocabulary is pretty simple.
vocabulary = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']
In [10]:
# Python dictionary that will convert a character to its index in the vocabulary.
char_to_index = {char: index for index, char in enumerate(vocabulary)}

print(char_to_index)
{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '+': 10, ' ': 11}

Our RNN needs numbers (not strings) as input. As a first step towards converting the summation expression strings into numbers, we will replace each character of each string with the position index of that character in the vocabulary.

In [11]:
# Converts x and y arrays of strings into array of char indices.
def dataset_to_indices(x, y, vocabulary):
    x_encoded, y_encoded = [], []
    
    char_to_index = {char: index for index, char in enumerate(vocabulary)}
    
    for example in x:
        example_encoded = [char_to_index[char] for char in example]
        x_encoded.append(example_encoded)
        
    for label in y:
        label_encoded = [char_to_index[char] for char in label]
        y_encoded.append(label_encoded)
        
    return x_encoded, y_encoded
In [12]:
x_train_encoded, y_train_encoded = dataset_to_indices(
    x_train_str,
    y_train_str,
    vocabulary
)

print('x_train_encoded:\n', np.array(x_train_encoded[:3]))
print()
print('y_train_encoded:\n', np.array(y_train_encoded[:3]))
x_train_encoded:
 [[ 8 10  7  9 11 11 11]
 [ 3  6 10  3  9 11 11]
 [ 4  7 10  1  5 11 11]]

y_train_encoded:
 [[ 8  7 11]
 [ 7  5 11]
 [ 6  2 11]]
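Because `enumerate` assigns each vocabulary character a unique index, the index encoding loses no information and can be inverted exactly. A small sketch of the round trip follows; `encode_indices` and `decode_indices` are hypothetical helper names. A character outside the vocabulary raises `KeyError`, which acts as a cheap guard against malformed input.

```python
vocabulary = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']
char_to_index = {char: index for index, char in enumerate(vocabulary)}
# Built from the same enumerate order, so it is the exact inverse of char_to_index.
index_to_char = dict(enumerate(vocabulary))


def encode_indices(text):
    # Raises KeyError for any character that is not in the vocabulary.
    return [char_to_index[char] for char in text]


def decode_indices(indices):
    return ''.join(index_to_char[index] for index in indices)


encoded = encode_indices('8+79   ')
print(encoded)                  # [8, 10, 7, 9, 11, 11, 11]
print(decode_indices(encoded))  # 8+79 followed by three padding spaces
```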

Raw indices would suggest an ordering and a distance between characters that does not exist (e.g. that "+", with index 10, is much closer to "9" than to "0"). To avoid this, we will convert each index into a 0/1 one-hot vector (e.g. index 1 will be transformed into the array [0 1 0 0 0 0 0 0 0 0 0 0]).

In [13]:
# Convert x and y sets of numbers into one-hot vectors.
def dataset_to_one_hot(x, y, vocabulary):
    x_encoded, y_encoded = [], []
    
    for example in x:
        pattern = []
        for index in example:
            vector = [0 for _ in range(len(vocabulary))]
            vector[index] = 1
            pattern.append(vector)
        x_encoded.append(pattern)
            
    for label in y:
        pattern = []
        for index in label:
            vector = [0 for _ in range(len(vocabulary))]
            vector[index] = 1
            pattern.append(vector)
        y_encoded.append(pattern)
        
    return x_encoded, y_encoded
In [14]:
x_train_one_hot, y_train_one_hot = dataset_to_one_hot(
    x_train_encoded,
    y_train_encoded,
    vocabulary
)

print('x_train_one_hot:\n', np.array(x_train_one_hot[:1]))
print()
print('y_train_one_hot:\n', np.array(y_train_one_hot[:1]))
x_train_one_hot:
 [[[0 0 0 0 0 0 0 0 1 0 0 0]
  [0 0 0 0 0 0 0 0 0 0 1 0]
  [0 0 0 0 0 0 0 1 0 0 0 0]
  [0 0 0 0 0 0 0 0 0 1 0 0]
  [0 0 0 0 0 0 0 0 0 0 0 1]
  [0 0 0 0 0 0 0 0 0 0 0 1]
  [0 0 0 0 0 0 0 0 0 0 0 1]]]

y_train_one_hot:
 [[[0 0 0 0 0 0 0 0 1 0 0 0]
  [0 0 0 0 0 0 0 1 0 0 0 0]
  [0 0 0 0 0 0 0 0 0 0 0 1]]]
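One way to see why one-hot vectors are preferred over raw indices is to compare distances between characters under both encodings. The sketch below uses plain Python; `one_hot`, `index_distance` and `one_hot_distance` are illustrative helpers, not part of the notebook. With indices, "9" looks nine times closer to "+" than "1" does; with one-hot vectors, every pair of distinct characters is exactly the same Euclidean distance (the square root of 2) apart.

```python
import math

vocabulary = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']
char_to_index = {char: index for index, char in enumerate(vocabulary)}


def one_hot(char):
    # 1 at the character's vocabulary index, 0 everywhere else.
    return [int(i == char_to_index[char]) for i in range(len(vocabulary))]


def index_distance(a, b):
    return abs(char_to_index[a] - char_to_index[b])


def one_hot_distance(a, b):
    # Euclidean distance (math.dist needs Python 3.8+, the notebook runs 3.7).
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(one_hot(a), one_hot(b))))


print(index_distance('9', '+'), index_distance('1', '+'))      # 1 9
print(one_hot_distance('9', '+'), one_hot_distance('1', '+'))  # both equal math.sqrt(2)
```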

Now we can combine all these functions into a single function that does all the work of generating a dataset in the proper format for us.

In [15]:
# Generates a dataset.
def generate_dataset(dataset_size, sequence_length, max_num, vocabulary):
    # Generate integer sum sequences.
    x, y = generate_sums(dataset_size, sequence_length, max_num)
    # Convert integer sum sequences into strings.
    x, y = dataset_to_strings(x, y, max_num)
    # Encode each character as its index in the vocabulary.
    x, y = dataset_to_indices(x, y, vocabulary)
    # Encode each index into one-hot vector.
    x, y = dataset_to_one_hot(x, y, vocabulary)
    # Return the data.
    return np.array(x), np.array(y)
In [16]:
x, y = generate_dataset(
    dataset_size,
    sequence_length,
    max_num,
    vocabulary
)

print('x:\n', x[:1])
print()
print('y:\n', y[:1])
x:
 [[[0 0 0 0 0 0 0 0 0 1 0 0]
  [0 0 0 1 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0 0 1 0]
  [0 0 0 0 0 0 0 0 0 1 0 0]
  [0 0 0 0 1 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0 0 0 1]
  [0 0 0 0 0 0 0 0 0 0 0 1]]]

y:
 [[[0 1 0 0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 1 0 0 0]
  [0 0 0 0 0 0 0 1 0 0 0 0]]]
In [17]:
print('x.shape: ', x.shape) # (input_sequences_num, input_sequence_length, supported_symbols_num)
print('y.shape: ', y.shape) # (output_sequences_num, output_sequence_length, supported_symbols_num)
x.shape:  (5000, 7, 12)
y.shape:  (5000, 3, 12)
In [18]:
# How many characters each summation expression has.
input_sequence_length = x.shape[1]

# How many characters the output sequence of the RNN has.
output_sequence_length = y.shape[1]

# The length of one-hot vector for each character in the input (should be the same as vocabulary_size).
supported_symbols_num = x.shape[2]

# The number of different characters our RNN can work with (only digits, "+" and " ").
vocabulary_size = len(vocabulary)

print('input_sequence_length: ', input_sequence_length)
print('output_sequence_length: ', output_sequence_length)
print('supported_symbols_num: ', supported_symbols_num)
print('vocabulary_size: ', vocabulary_size)
input_sequence_length:  7
output_sequence_length:  3
supported_symbols_num:  12
vocabulary_size:  12
In [19]:
# Converts a sequence (array) of one-hot encoded vectors back into the string based on the provided vocabulary.
def decode(sequence, vocabulary):
    index_to_char = {index: char for index, char in enumerate(vocabulary)}
    strings = []
    for char_vector in sequence:
        char = index_to_char[np.argmax(char_vector)]
        strings.append(char)
    return ''.join(strings)
In [20]:
decode(y[0], vocabulary)
Out[20]:
'187'
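Decoding also gives a cheap sanity check for the whole data pipeline: for every example, the decoded expression should evaluate to the decoded label. The sketch below is pure Python; `decode_one_hot`, `to_one_hot` and `is_consistent` are hypothetical helpers. To check the notebook's arrays you could call it on converted rows, e.g. `all(is_consistent(xi.tolist(), yi.tolist(), vocabulary) for xi, yi in zip(x, y))`.

```python
vocabulary = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']


def decode_one_hot(sequence, vocabulary):
    # Pure-Python argmax: the position of the largest value in each vector.
    return ''.join(vocabulary[max(range(len(v)), key=v.__getitem__)] for v in sequence)


def to_one_hot(text, vocabulary):
    return [[int(char == symbol) for symbol in vocabulary] for char in text]


def is_consistent(x_one_hot, y_one_hot, vocabulary):
    # Strip the space padding, then compare the evaluated sum with the label.
    expression = decode_one_hot(x_one_hot, vocabulary).strip()
    label = decode_one_hot(y_one_hot, vocabulary).strip()
    return sum(int(term) for term in expression.split('+')) == int(label)


x_example = to_one_hot('93+94  ', vocabulary)
y_example = to_one_hot('187', vocabulary)
print(is_consistent(x_example, y_example, vocabulary))  # True
```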

Build a model

We will use a Keras Sequential model with LSTM layers. Logically, our model consists of an encoder and a decoder. The encoder LSTM reads the whole input expression and, since it does not return sequences, outputs a single fixed-size vector of numbers (128 values here). The decoder then tries to build the output result sequence, character by character, by decoding that vector.

In [21]:
epochs_num = 200
batch_size = 128
In [22]:
model = tf.keras.models.Sequential()

# Encoder
# -------

model.add(tf.keras.layers.LSTM(
    units=128,
    input_shape=(input_sequence_length, vocabulary_size),
    recurrent_initializer=tf.keras.initializers.GlorotNormal()
))

# Decoder
# -------

# We need this layer to match the encoder output shape with the decoder input shape.
# The encoder outputs ONE vector of numbers, but the decoder needs output_sequence_length vectors (one per output character).
model.add(tf.keras.layers.RepeatVector(
    n=output_sequence_length,
))

model.add(tf.keras.layers.LSTM(
    units=128,
    return_sequences=True,
    recurrent_initializer=tf.keras.initializers.GlorotNormal()
))

model.add(tf.keras.layers.TimeDistributed(
    layer=tf.keras.layers.Dense(
        units=vocabulary_size,
    )
))

model.add(tf.keras.layers.Activation(
    activation='softmax'
))
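To make the role of `RepeatVector` concrete, here is a plain-Python illustration of the shape change it performs: `(batch, features)` becomes `(batch, n, features)` by copying each sample's vector `n` times. This is only a sketch of the semantics, not Keras's implementation, and `repeat_vector` is a hypothetical name. Every decoder time step receives the same encoder summary; it is the decoder LSTM's recurrent state that lets the three output characters differ.

```python
def repeat_vector(batch, n):
    # (batch, features) -> (batch, n, features): each sample's vector is copied n times.
    return [[list(vector) for _ in range(n)] for vector in batch]


# One sample with 4 features (the encoder in the notebook outputs 128).
encoder_output = [[0.1, -0.2, 0.3, 0.4]]
decoder_input = repeat_vector(encoder_output, n=3)
print(len(decoder_input), len(decoder_input[0]), len(decoder_input[0][0]))  # 1 3 4
```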
In [23]:
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm (LSTM)                  (None, 128)               72192     
_________________________________________________________________
repeat_vector (RepeatVector) (None, 3, 128)            0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 3, 128)            131584    
_________________________________________________________________
time_distributed (TimeDistri (None, 3, 12)             1548      
_________________________________________________________________
activation (Activation)      (None, 3, 12)             0         
=================================================================
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
_________________________________________________________________
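The parameter counts in the summary can be reproduced by hand. Assuming Keras's default LSTM configuration (biases enabled), each of the four gates (input, forget, cell candidate, output) has an input kernel, a recurrent kernel and a bias. `RepeatVector` and `Activation` have no weights, and `TimeDistributed` reuses one `Dense` layer for all time steps, so its count is not multiplied by 3. The helper names below are illustrative.

```python
def lstm_params(units, input_dim):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias vector.
    return 4 * (units * (input_dim + units) + units)


def dense_params(units, input_dim):
    # Weight matrix plus one bias per unit.
    return units * input_dim + units


encoder = lstm_params(128, 12)    # 72192
decoder = lstm_params(128, 128)   # 131584
output = dense_params(12, 128)    # 1548, shared across the 3 time steps
print(encoder + decoder + output)  # 205324
```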
In [24]:
tf.keras.utils.plot_model(
    model,
    show_shapes=True,
    show_layer_names=True,
)
Out[24]:

Train a model

In [27]:
log_dir=".logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
In [30]:
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=adam_optimizer,
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)
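When the loss is categorical cross-entropy, tf.keras picks categorical accuracy for the `'accuracy'` metric. On our `(batch, 3, 12)` outputs it compares the argmax at every time step separately (this is our reading of tf.keras's behavior), so the accuracy in the training log measures correctly predicted characters, not fully correct sums. The sketch below, with hypothetical helper names and a toy 3-symbol vocabulary, contrasts per-character and whole-sequence accuracy and shows the per-time-step cross-entropy averaged over all steps. The `eps` clamp only avoids `log(0)` and is similar in spirit to, but not identical to, Keras's internal clipping.

```python
import math


def argmax(vector):
    return max(range(len(vector)), key=vector.__getitem__)


def char_and_sequence_accuracy(y_true, y_pred):
    # y_true / y_pred: [sample][time step][symbol] one-hot labels / probabilities.
    matches = [[argmax(t) == argmax(p) for t, p in zip(true_seq, pred_seq)]
               for true_seq, pred_seq in zip(y_true, y_pred)]
    char_acc = sum(sum(m) for m in matches) / sum(len(m) for m in matches)
    # A sequence counts only if every one of its characters is correct.
    seq_acc = sum(all(m) for m in matches) / len(matches)
    return char_acc, seq_acc


def mean_categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # -sum(y * log p) per time step, averaged over all samples and time steps.
    losses = [-sum(t_i * math.log(max(p_i, eps)) for t_i, p_i in zip(t, p))
              for true_seq, pred_seq in zip(y_true, y_pred)
              for t, p in zip(true_seq, pred_seq)]
    return sum(losses) / len(losses)


y_true = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
          [[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
# The second sample gets its last character wrong.
y_pred = [[[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]],
          [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.6, 0.3, 0.1]]]
print(char_and_sequence_accuracy(y_true, y_pred))  # (0.8333333333333334, 0.5)
print(mean_categorical_crossentropy(y_true, y_pred))
```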
In [31]:
history = model.fit(
    x=x,
    y=y,
    epochs=epochs_num,
    batch_size=batch_size,
    validation_split=0.1,
    callbacks=[tensorboard_callback]
)
Train on 4500 samples, validate on 500 samples
Epoch 1/200
4500/4500 [==============================] - 5s 1ms/sample - loss: 2.3158 - accuracy: 0.2441 - val_loss: 2.1193 - val_accuracy: 0.3800
Epoch 2/200
4500/4500 [==============================] - 1s 275us/sample - loss: 1.9693 - accuracy: 0.3836 - val_loss: 1.9278 - val_accuracy: 0.3827
Epoch 3/200
4500/4500 [==============================] - 2s 400us/sample - loss: 1.8441 - accuracy: 0.3830 - val_loss: 1.7944 - val_accuracy: 0.3900
Epoch 4/200
4500/4500 [==============================] - 2s 505us/sample - loss: 1.7833 - accuracy: 0.3804 - val_loss: 1.7343 - val_accuracy: 0.3880
Epoch 5/200
4500/4500 [==============================] - 1s 279us/sample - loss: 1.7441 - accuracy: 0.3870 - val_loss: 1.7016 - val_accuracy: 0.4013
Epoch 6/200
4500/4500 [==============================] - 1s 282us/sample - loss: 1.7044 - accuracy: 0.3932 - val_loss: 1.6716 - val_accuracy: 0.4080
Epoch 7/200
4500/4500 [==============================] - 1s 246us/sample - loss: 1.6768 - accuracy: 0.4010 - val_loss: 1.6591 - val_accuracy: 0.4067
Epoch 8/200
4500/4500 [==============================] - 1s 258us/sample - loss: 1.6766 - accuracy: 0.4054 - val_loss: 1.6476 - val_accuracy: 0.4133
Epoch 9/200
4500/4500 [==============================] - 1s 280us/sample - loss: 1.6474 - accuracy: 0.4137 - val_loss: 1.6307 - val_accuracy: 0.4173
Epoch 10/200
4500/4500 [==============================] - 1s 311us/sample - loss: 1.6424 - accuracy: 0.4112 - val_loss: 1.6255 - val_accuracy: 0.4267
Epoch 11/200
4500/4500 [==============================] - 1s 262us/sample - loss: 1.6164 - accuracy: 0.4192 - val_loss: 1.5914 - val_accuracy: 0.4267
Epoch 12/200
4500/4500 [==============================] - 1s 263us/sample - loss: 1.5813 - accuracy: 0.4272 - val_loss: 1.6045 - val_accuracy: 0.4200
Epoch 13/200
4500/4500 [==============================] - 1s 303us/sample - loss: 1.5218 - accuracy: 0.4425 - val_loss: 1.4479 - val_accuracy: 0.4593
Epoch 14/200
4500/4500 [==============================] - 1s 261us/sample - loss: 1.4390 - accuracy: 0.4649 - val_loss: 1.4137 - val_accuracy: 0.4633
Epoch 15/200
4500/4500 [==============================] - 1s 292us/sample - loss: 1.3753 - accuracy: 0.4837 - val_loss: 1.3286 - val_accuracy: 0.4913
Epoch 16/200
4500/4500 [==============================] - 1s 280us/sample - loss: 1.3659 - accuracy: 0.4901 - val_loss: 1.5145 - val_accuracy: 0.4467
Epoch 17/200
4500/4500 [==============================] - 1s 267us/sample - loss: 1.3072 - accuracy: 0.5157 - val_loss: 1.4779 - val_accuracy: 0.4580
Epoch 18/200
4500/4500 [==============================] - 2s 416us/sample - loss: 1.2754 - accuracy: 0.5273 - val_loss: 1.2320 - val_accuracy: 0.5387
Epoch 19/200
4500/4500 [==============================] - 2s 359us/sample - loss: 1.2228 - accuracy: 0.5448 - val_loss: 1.2073 - val_accuracy: 0.5420
Epoch 20/200
4500/4500 [==============================] - 2s 461us/sample - loss: 1.1880 - accuracy: 0.5607 - val_loss: 1.1793 - val_accuracy: 0.5527
Epoch 21/200
4500/4500 [==============================] - 2s 434us/sample - loss: 1.1723 - accuracy: 0.5654 - val_loss: 1.1488 - val_accuracy: 0.5873
Epoch 22/200
4500/4500 [==============================] - 2s 367us/sample - loss: 1.1638 - accuracy: 0.5629 - val_loss: 1.1786 - val_accuracy: 0.5687
Epoch 23/200
4500/4500 [==============================] - 2s 350us/sample - loss: 1.1214 - accuracy: 0.5963 - val_loss: 1.1130 - val_accuracy: 0.6000
Epoch 24/200
4500/4500 [==============================] - 2s 346us/sample - loss: 1.0922 - accuracy: 0.6054 - val_loss: 1.1018 - val_accuracy: 0.6047
Epoch 25/200
4500/4500 [==============================] - 1s 306us/sample - loss: 1.0732 - accuracy: 0.6150 - val_loss: 1.1317 - val_accuracy: 0.5953
Epoch 26/200
4500/4500 [==============================] - 1s 248us/sample - loss: 1.0563 - accuracy: 0.6198 - val_loss: 1.0575 - val_accuracy: 0.6187
Epoch 27/200
4500/4500 [==============================] - 1s 330us/sample - loss: 1.0480 - accuracy: 0.6173 - val_loss: 1.0423 - val_accuracy: 0.6167
Epoch 28/200
4500/4500 [==============================] - 2s 353us/sample - loss: 1.0038 - accuracy: 0.6484 - val_loss: 1.0420 - val_accuracy: 0.6127
Epoch 29/200
4500/4500 [==============================] - 1s 323us/sample - loss: 1.0229 - accuracy: 0.6237 - val_loss: 1.0975 - val_accuracy: 0.5820
Epoch 30/200
4500/4500 [==============================] - 2s 358us/sample - loss: 0.9974 - accuracy: 0.6374 - val_loss: 1.0929 - val_accuracy: 0.5780
Epoch 31/200
4500/4500 [==============================] - 2s 392us/sample - loss: 0.9672 - accuracy: 0.6513 - val_loss: 0.9835 - val_accuracy: 0.6293
Epoch 32/200
4500/4500 [==============================] - 2s 344us/sample - loss: 0.9327 - accuracy: 0.6676 - val_loss: 0.9199 - val_accuracy: 0.6687
Epoch 33/200
4500/4500 [==============================] - 2s 362us/sample - loss: 0.8927 - accuracy: 0.6875 - val_loss: 0.8986 - val_accuracy: 0.6733
Epoch 34/200
4500/4500 [==============================] - 2s 489us/sample - loss: 0.8669 - accuracy: 0.6944 - val_loss: 0.8837 - val_accuracy: 0.6727
Epoch 35/200
4500/4500 [==============================] - 2s 334us/sample - loss: 0.8370 - accuracy: 0.7062 - val_loss: 0.8414 - val_accuracy: 0.7007
Epoch 36/200
4500/4500 [==============================] - 1s 297us/sample - loss: 0.8041 - accuracy: 0.7193 - val_loss: 0.8041 - val_accuracy: 0.7220
Epoch 37/200
4500/4500 [==============================] - 2s 384us/sample - loss: 0.7697 - accuracy: 0.7292 - val_loss: 0.8067 - val_accuracy: 0.7080
Epoch 38/200
4500/4500 [==============================] - 2s 450us/sample - loss: 0.7470 - accuracy: 0.7353 - val_loss: 0.7434 - val_accuracy: 0.7353
Epoch 39/200
4500/4500 [==============================] - 1s 299us/sample - loss: 0.7102 - accuracy: 0.7525 - val_loss: 0.7253 - val_accuracy: 0.7373
Epoch 40/200
4500/4500 [==============================] - 1s 258us/sample - loss: 0.6950 - accuracy: 0.7520 - val_loss: 0.7477 - val_accuracy: 0.7200
Epoch 41/200
4500/4500 [==============================] - 2s 359us/sample - loss: 0.6823 - accuracy: 0.7551 - val_loss: 0.6489 - val_accuracy: 0.7813
Epoch 42/200
4500/4500 [==============================] - 1s 312us/sample - loss: 0.6113 - accuracy: 0.7959 - val_loss: 0.6131 - val_accuracy: 0.7960
Epoch 43/200
4500/4500 [==============================] - 2s 422us/sample - loss: 0.5712 - accuracy: 0.8155 - val_loss: 0.5764 - val_accuracy: 0.8000
Epoch 44/200
4500/4500 [==============================] - 2s 400us/sample - loss: 0.5370 - accuracy: 0.8361 - val_loss: 0.5564 - val_accuracy: 0.8067
Epoch 45/200
4500/4500 [==============================] - 1s 303us/sample - loss: 0.4961 - accuracy: 0.8553 - val_loss: 0.5110 - val_accuracy: 0.8360
Epoch 46/200
4500/4500 [==============================] - 1s 251us/sample - loss: 0.4710 - accuracy: 0.8637 - val_loss: 0.4872 - val_accuracy: 0.8453
Epoch 47/200
4500/4500 [==============================] - 1s 256us/sample - loss: 0.4383 - accuracy: 0.8778 - val_loss: 0.4712 - val_accuracy: 0.8447
Epoch 48/200
4500/4500 [==============================] - 1s 317us/sample - loss: 0.4155 - accuracy: 0.8886 - val_loss: 0.4467 - val_accuracy: 0.8600
Epoch 49/200
4500/4500 [==============================] - 2s 374us/sample - loss: 0.4020 - accuracy: 0.8912 - val_loss: 0.4096 - val_accuracy: 0.8880
Epoch 50/200
4500/4500 [==============================] - 1s 318us/sample - loss: 0.3830 - accuracy: 0.8979 - val_loss: 0.4263 - val_accuracy: 0.8733
Epoch 51/200
4500/4500 [==============================] - 1s 264us/sample - loss: 0.3631 - accuracy: 0.9044 - val_loss: 0.3707 - val_accuracy: 0.9060
Epoch 52/200
4500/4500 [==============================] - 1s 320us/sample - loss: 0.3297 - accuracy: 0.9261 - val_loss: 0.3472 - val_accuracy: 0.9207
Epoch 53/200
4500/4500 [==============================] - 1s 278us/sample - loss: 0.3064 - accuracy: 0.9381 - val_loss: 0.3332 - val_accuracy: 0.9227
Epoch 54/200
4500/4500 [==============================] - 1s 252us/sample - loss: 0.2881 - accuracy: 0.9440 - val_loss: 0.3403 - val_accuracy: 0.9240
Epoch 55/200
4500/4500 [==============================] - 2s 345us/sample - loss: 0.2742 - accuracy: 0.9452 - val_loss: 0.3117 - val_accuracy: 0.9320
Epoch 56/200
4500/4500 [==============================] - 2s 466us/sample - loss: 0.2592 - accuracy: 0.9536 - val_loss: 0.2773 - val_accuracy: 0.9413
Epoch 57/200
4500/4500 [==============================] - 2s 467us/sample - loss: 0.2463 - accuracy: 0.9549 - val_loss: 0.2671 - val_accuracy: 0.9507
Epoch 58/200
4500/4500 [==============================] - 2s 484us/sample - loss: 0.2407 - accuracy: 0.9557 - val_loss: 0.3212 - val_accuracy: 0.8940
Epoch 59/200
4500/4500 [==============================] - 2s 424us/sample - loss: 0.2791 - accuracy: 0.9250 - val_loss: 0.2557 - val_accuracy: 0.9380
Epoch 60/200
4500/4500 [==============================] - 2s 428us/sample - loss: 0.2160 - accuracy: 0.9613 - val_loss: 0.2234 - val_accuracy: 0.9600
Epoch 61/200
4500/4500 [==============================] - 2s 455us/sample - loss: 0.1951 - accuracy: 0.9708 - val_loss: 0.2309 - val_accuracy: 0.9513
Epoch 62/200
4500/4500 [==============================] - 2s 364us/sample - loss: 0.1822 - accuracy: 0.9746 - val_loss: 0.1973 - val_accuracy: 0.9653
Epoch 63/200
4500/4500 [==============================] - 2s 405us/sample - loss: 0.1660 - accuracy: 0.9798 - val_loss: 0.1984 - val_accuracy: 0.9593
Epoch 64/200
4500/4500 [==============================] - 2s 369us/sample - loss: 0.1581 - accuracy: 0.9804 - val_loss: 0.1806 - val_accuracy: 0.9673
Epoch 65/200
4500/4500 [==============================] - 2s 443us/sample - loss: 0.1504 - accuracy: 0.9800 - val_loss: 0.1723 - val_accuracy: 0.9753
Epoch 66/200
4500/4500 [==============================] - 2s 366us/sample - loss: 0.1391 - accuracy: 0.9847 - val_loss: 0.1605 - val_accuracy: 0.9753
Epoch 67/200
4500/4500 [==============================] - 2s 395us/sample - loss: 0.1336 - accuracy: 0.9836 - val_loss: 0.1561 - val_accuracy: 0.9713
Epoch 68/200
4500/4500 [==============================] - 2s 419us/sample - loss: 0.1251 - accuracy: 0.9861 - val_loss: 0.1406 - val_accuracy: 0.9773
Epoch 69/200
4500/4500 [==============================] - 2s 461us/sample - loss: 0.1158 - accuracy: 0.9882 - val_loss: 0.1333 - val_accuracy: 0.9813
Epoch 70/200
4500/4500 [==============================] - 1s 318us/sample - loss: 0.1215 - accuracy: 0.9856 - val_loss: 0.1617 - val_accuracy: 0.9627
Epoch 71/200
4500/4500 [==============================] - 2s 390us/sample - loss: 0.1783 - accuracy: 0.9524 - val_loss: 0.1554 - val_accuracy: 0.9693
Epoch 72/200
4500/4500 [==============================] - 1s 300us/sample - loss: 0.1315 - accuracy: 0.9769 - val_loss: 0.1286 - val_accuracy: 0.9760
Epoch 73/200
4500/4500 [==============================] - 1s 272us/sample - loss: 0.1033 - accuracy: 0.9891 - val_loss: 0.1184 - val_accuracy: 0.9847
Epoch 74/200
4500/4500 [==============================] - 1s 224us/sample - loss: 0.0910 - accuracy: 0.9919 - val_loss: 0.1070 - val_accuracy: 0.9847
Epoch 75/200
4500/4500 [==============================] - 2s 382us/sample - loss: 0.0829 - accuracy: 0.9939 - val_loss: 0.1076 - val_accuracy: 0.9860
Epoch 76/200
4500/4500 [==============================] - 2s 391us/sample - loss: 0.0806 - accuracy: 0.9928 - val_loss: 0.1017 - val_accuracy: 0.9900
Epoch 77/200
4500/4500 [==============================] - 1s 233us/sample - loss: 0.0783 - accuracy: 0.9933 - val_loss: 0.1271 - val_accuracy: 0.9727
Epoch 78/200
4500/4500 [==============================] - 1s 322us/sample - loss: 0.2431 - accuracy: 0.9224 - val_loss: 0.1460 - val_accuracy: 0.9673
Epoch 79/200
4500/4500 [==============================] - 1s 253us/sample - loss: 0.0956 - accuracy: 0.9849 - val_loss: 0.0942 - val_accuracy: 0.9853
Epoch 80/200
4500/4500 [==============================] - 1s 331us/sample - loss: 0.0699 - accuracy: 0.9954 - val_loss: 0.0861 - val_accuracy: 0.9907
Epoch 81/200
4500/4500 [==============================] - 2s 372us/sample - loss: 0.0629 - accuracy: 0.9960 - val_loss: 0.0819 - val_accuracy: 0.9907
Epoch 82/200
4500/4500 [==============================] - 2s 435us/sample - loss: 0.0601 - accuracy: 0.9961 - val_loss: 0.0817 - val_accuracy: 0.9867
Epoch 83/200
4500/4500 [==============================] - 2s 391us/sample - loss: 0.0570 - accuracy: 0.9966 - val_loss: 0.0733 - val_accuracy: 0.9913
Epoch 84/200
4500/4500 [==============================] - 2s 546us/sample - loss: 0.0536 - accuracy: 0.9973 - val_loss: 0.0730 - val_accuracy: 0.9900
Epoch 85/200
4500/4500 [==============================] - 2s 431us/sample - loss: 0.0508 - accuracy: 0.9976 - val_loss: 0.0716 - val_accuracy: 0.9893
Epoch 86/200
4500/4500 [==============================] - 2s 524us/sample - loss: 0.0491 - accuracy: 0.9973 - val_loss: 0.0705 - val_accuracy: 0.9900
Epoch 87/200
4500/4500 [==============================] - 2s 548us/sample - loss: 0.0477 - accuracy: 0.9976 - val_loss: 0.0680 - val_accuracy: 0.9913
Epoch 88/200
4500/4500 [==============================] - 3s 568us/sample - loss: 0.0467 - accuracy: 0.9972 - val_loss: 0.0629 - val_accuracy: 0.9920
Epoch 89/200
4500/4500 [==============================] - 2s 522us/sample - loss: 0.0426 - accuracy: 0.9981 - val_loss: 0.0616 - val_accuracy: 0.9927
Epoch 90/200
4500/4500 [==============================] - 3s 578us/sample - loss: 0.0399 - accuracy: 0.9984 - val_loss: 0.0612 - val_accuracy: 0.9907
Epoch 91/200
4500/4500 [==============================] - 2s 544us/sample - loss: 0.0384 - accuracy: 0.9982 - val_loss: 0.0584 - val_accuracy: 0.9913
Epoch 92/200
4500/4500 [==============================] - 2s 504us/sample - loss: 0.0388 - accuracy: 0.9982 - val_loss: 0.0556 - val_accuracy: 0.9920
Epoch 93/200
4500/4500 [==============================] - 2s 526us/sample - loss: 0.0374 - accuracy: 0.9981 - val_loss: 0.0543 - val_accuracy: 0.9927
Epoch 94/200
4500/4500 [==============================] - 2s 534us/sample - loss: 0.0342 - accuracy: 0.9992 - val_loss: 0.0514 - val_accuracy: 0.9933
Epoch 95/200
4500/4500 [==============================] - 3s 556us/sample - loss: 0.0321 - accuracy: 0.9989 - val_loss: 0.0534 - val_accuracy: 0.9927
Epoch 96/200
4500/4500 [==============================] - 2s 502us/sample - loss: 0.0332 - accuracy: 0.9984 - val_loss: 0.0566 - val_accuracy: 0.9907
Epoch 97/200
4500/4500 [==============================] - 2s 547us/sample - loss: 0.2775 - accuracy: 0.9104 - val_loss: 0.1785 - val_accuracy: 0.9353
Epoch 98/200
4500/4500 [==============================] - 2s 376us/sample - loss: 0.0776 - accuracy: 0.9842 - val_loss: 0.0620 - val_accuracy: 0.9900
Epoch 99/200
4500/4500 [==============================] - 2s 365us/sample - loss: 0.0341 - accuracy: 0.9988 - val_loss: 0.0495 - val_accuracy: 0.9940
Epoch 100/200
4500/4500 [==============================] - 2s 407us/sample - loss: 0.0298 - accuracy: 0.9994 - val_loss: 0.0460 - val_accuracy: 0.9953
Epoch 101/200
4500/4500 [==============================] - 2s 524us/sample - loss: 0.0282 - accuracy: 0.9993 - val_loss: 0.0455 - val_accuracy: 0.9953
Epoch 102/200
4500/4500 [==============================] - 3s 559us/sample - loss: 0.0262 - accuracy: 0.9999 - val_loss: 0.0440 - val_accuracy: 0.9947
Epoch 103/200
4500/4500 [==============================] - 2s 486us/sample - loss: 0.0256 - accuracy: 0.9996 - val_loss: 0.0422 - val_accuracy: 0.9940
Epoch 104/200
4500/4500 [==============================] - 2s 375us/sample - loss: 0.0241 - accuracy: 0.9996 - val_loss: 0.0411 - val_accuracy: 0.9947
Epoch 105/200
4500/4500 [==============================] - 2s 398us/sample - loss: 0.0236 - accuracy: 0.9996 - val_loss: 0.0414 - val_accuracy: 0.9940
Epoch 106/200
4500/4500 [==============================] - 2s 474us/sample - loss: 0.0230 - accuracy: 0.9994 - val_loss: 0.0397 - val_accuracy: 0.9953
Epoch 107/200
4500/4500 [==============================] - 2s 354us/sample - loss: 0.0215 - accuracy: 0.9996 - val_loss: 0.0379 - val_accuracy: 0.9953
Epoch 108/200
4500/4500 [==============================] - 2s 378us/sample - loss: 0.0206 - accuracy: 0.9998 - val_loss: 0.0369 - val_accuracy: 0.9953
Epoch 109/200
4500/4500 [==============================] - 2s 361us/sample - loss: 0.0205 - accuracy: 0.9994 - val_loss: 0.0359 - val_accuracy: 0.9947
Epoch 110/200
4500/4500 [==============================] - 1s 261us/sample - loss: 0.0190 - accuracy: 0.9999 - val_loss: 0.0359 - val_accuracy: 0.9947
Epoch 111/200
4500/4500 [==============================] - 1s 264us/sample - loss: 0.0185 - accuracy: 0.9999 - val_loss: 0.0347 - val_accuracy: 0.9947
Epoch 112/200
4500/4500 [==============================] - 1s 255us/sample - loss: 0.0183 - accuracy: 0.9996 - val_loss: 0.0343 - val_accuracy: 0.9953
Epoch 113/200
4500/4500 [==============================] - 1s 279us/sample - loss: 0.0177 - accuracy: 0.9997 - val_loss: 0.0352 - val_accuracy: 0.9953
Epoch 114/200
4500/4500 [==============================] - 2s 344us/sample - loss: 0.0168 - accuracy: 0.9999 - val_loss: 0.0322 - val_accuracy: 0.9947
Epoch 115/200
4500/4500 [==============================] - 1s 315us/sample - loss: 0.0161 - accuracy: 0.9998 - val_loss: 0.0316 - val_accuracy: 0.9953
Epoch 116/200
4500/4500 [==============================] - 2s 362us/sample - loss: 0.0153 - accuracy: 1.0000 - val_loss: 0.0306 - val_accuracy: 0.9947
Epoch 117/200
4500/4500 [==============================] - 2s 382us/sample - loss: 0.0149 - accuracy: 0.9999 - val_loss: 0.0303 - val_accuracy: 0.9953
Epoch 118/200
4500/4500 [==============================] - 2s 384us/sample - loss: 0.0145 - accuracy: 0.9999 - val_loss: 0.0302 - val_accuracy: 0.9953
Epoch 119/200
4500/4500 [==============================] - 2s 354us/sample - loss: 0.0138 - accuracy: 0.9999 - val_loss: 0.0291 - val_accuracy: 0.9953
Epoch 120/200
4500/4500 [==============================] - 1s 328us/sample - loss: 0.0142 - accuracy: 0.9999 - val_loss: 0.0314 - val_accuracy: 0.9947
Epoch 121/200
4500/4500 [==============================] - 1s 277us/sample - loss: 0.0133 - accuracy: 0.9999 - val_loss: 0.0287 - val_accuracy: 0.9960
Epoch 122/200
4500/4500 [==============================] - 1s 275us/sample - loss: 0.0127 - accuracy: 0.9999 - val_loss: 0.0277 - val_accuracy: 0.9953
Epoch 123/200
4500/4500 [==============================] - 1s 273us/sample - loss: 0.0123 - accuracy: 1.0000 - val_loss: 0.0282 - val_accuracy: 0.9953
Epoch 124/200
4500/4500 [==============================] - 1s 272us/sample - loss: 0.0116 - accuracy: 0.9999 - val_loss: 0.0271 - val_accuracy: 0.9953
Epoch 125/200
4500/4500 [==============================] - 1s 290us/sample - loss: 0.0116 - accuracy: 1.0000 - val_loss: 0.0270 - val_accuracy: 0.9960
Epoch 126/200
4500/4500 [==============================] - 1s 300us/sample - loss: 0.0123 - accuracy: 0.9997 - val_loss: 0.0275 - val_accuracy: 0.9960
Epoch 127/200
4500/4500 [==============================] - 1s 286us/sample - loss: 0.0111 - accuracy: 0.9999 - val_loss: 0.0262 - val_accuracy: 0.9953
Epoch 128/200
4500/4500 [==============================] - 1s 274us/sample - loss: 0.0115 - accuracy: 0.9996 - val_loss: 0.0372 - val_accuracy: 0.9900
Epoch 129/200
4500/4500 [==============================] - 1s 289us/sample - loss: 0.5044 - accuracy: 0.8552 - val_loss: 0.2934 - val_accuracy: 0.8887
Epoch 130/200
4500/4500 [==============================] - 1s 332us/sample - loss: 0.1148 - accuracy: 0.9691 - val_loss: 0.0596 - val_accuracy: 0.9927
Epoch 131/200
4500/4500 [==============================] - 1s 285us/sample - loss: 0.0319 - accuracy: 0.9979 - val_loss: 0.0359 - val_accuracy: 0.9953
Epoch 132/200
4500/4500 [==============================] - 1s 269us/sample - loss: 0.0189 - accuracy: 0.9997 - val_loss: 0.0333 - val_accuracy: 0.9947
Epoch 133/200
4500/4500 [==============================] - 1s 260us/sample - loss: 0.0155 - accuracy: 0.9999 - val_loss: 0.0300 - val_accuracy: 0.9960
Epoch 134/200
4500/4500 [==============================] - 1s 263us/sample - loss: 0.0137 - accuracy: 1.0000 - val_loss: 0.0292 - val_accuracy: 0.9953
Epoch 135/200
4500/4500 [==============================] - 1s 265us/sample - loss: 0.0127 - accuracy: 0.9999 - val_loss: 0.0281 - val_accuracy: 0.9947
Epoch 136/200
4500/4500 [==============================] - 1s 262us/sample - loss: 0.0118 - accuracy: 1.0000 - val_loss: 0.0268 - val_accuracy: 0.9953
Epoch 137/200
4500/4500 [==============================] - 1s 265us/sample - loss: 0.0111 - accuracy: 1.0000 - val_loss: 0.0256 - val_accuracy: 0.9953
Epoch 138/200
4500/4500 [==============================] - 1s 269us/sample - loss: 0.0106 - accuracy: 1.0000 - val_loss: 0.0261 - val_accuracy: 0.9947
Epoch 139/200
4500/4500 [==============================] - 1s 258us/sample - loss: 0.0103 - accuracy: 1.0000 - val_loss: 0.0248 - val_accuracy: 0.9960
Epoch 140/200
4500/4500 [==============================] - 1s 261us/sample - loss: 0.0097 - accuracy: 1.0000 - val_loss: 0.0246 - val_accuracy: 0.9953
Epoch 141/200
4500/4500 [==============================] - 1s 277us/sample - loss: 0.0093 - accuracy: 1.0000 - val_loss: 0.0238 - val_accuracy: 0.9960
Epoch 142/200
4500/4500 [==============================] - 1s 263us/sample - loss: 0.0089 - accuracy: 1.0000 - val_loss: 0.0237 - val_accuracy: 0.9960
Epoch 143/200
4500/4500 [==============================] - 1s 261us/sample - loss: 0.0087 - accuracy: 1.0000 - val_loss: 0.0231 - val_accuracy: 0.9960
Epoch 144/200
4500/4500 [==============================] - 1s 263us/sample - loss: 0.0084 - accuracy: 1.0000 - val_loss: 0.0228 - val_accuracy: 0.9960
Epoch 145/200
4500/4500 [==============================] - 1s 258us/sample - loss: 0.0081 - accuracy: 1.0000 - val_loss: 0.0223 - val_accuracy: 0.9960
Epoch 146/200
4500/4500 [==============================] - 1s 262us/sample - loss: 0.0079 - accuracy: 1.0000 - val_loss: 0.0221 - val_accuracy: 0.9960
Epoch 147/200
4500/4500 [==============================] - 1s 290us/sample - loss: 0.0076 - accuracy: 1.0000 - val_loss: 0.0218 - val_accuracy: 0.9960
Epoch 148/200
4500/4500 [==============================] - 2s 350us/sample - loss: 0.0074 - accuracy: 1.0000 - val_loss: 0.0224 - val_accuracy: 0.9947
Epoch 149/200
4500/4500 [==============================] - 2s 400us/sample - loss: 0.0071 - accuracy: 1.0000 - val_loss: 0.0216 - val_accuracy: 0.9960
Epoch 150/200
4500/4500 [==============================] - 2s 366us/sample - loss: 0.0070 - accuracy: 1.0000 - val_loss: 0.0214 - val_accuracy: 0.9960
Epoch 151/200
4500/4500 [==============================] - 2s 392us/sample - loss: 0.0069 - accuracy: 1.0000 - val_loss: 0.0210 - val_accuracy: 0.9953
Epoch 152/200
4500/4500 [==============================] - 2s 369us/sample - loss: 0.0066 - accuracy: 1.0000 - val_loss: 0.0208 - val_accuracy: 0.9960
Epoch 153/200
4500/4500 [==============================] - 2s 365us/sample - loss: 0.0065 - accuracy: 1.0000 - val_loss: 0.0209 - val_accuracy: 0.9953
Epoch 154/200
4500/4500 [==============================] - 2s 412us/sample - loss: 0.0063 - accuracy: 1.0000 - val_loss: 0.0204 - val_accuracy: 0.9960
Epoch 155/200
4500/4500 [==============================] - 2s 379us/sample - loss: 0.0060 - accuracy: 1.0000 - val_loss: 0.0203 - val_accuracy: 0.9960
Epoch 156/200
4500/4500 [==============================] - 2s 377us/sample - loss: 0.0059 - accuracy: 1.0000 - val_loss: 0.0202 - val_accuracy: 0.9960
Epoch 157/200
4500/4500 [==============================] - 2s 382us/sample - loss: 0.0058 - accuracy: 1.0000 - val_loss: 0.0200 - val_accuracy: 0.9960
Epoch 158/200
4500/4500 [==============================] - 2s 388us/sample - loss: 0.0060 - accuracy: 0.9999 - val_loss: 0.0197 - val_accuracy: 0.9960
Epoch 159/200
4500/4500 [==============================] - 2s 391us/sample - loss: 0.0055 - accuracy: 1.0000 - val_loss: 0.0198 - val_accuracy: 0.9960
Epoch 160/200
4500/4500 [==============================] - 1s 248us/sample - loss: 0.0055 - accuracy: 1.0000 - val_loss: 0.0188 - val_accuracy: 0.9960
Epoch 161/200
4500/4500 [==============================] - 1s 222us/sample - loss: 0.0053 - accuracy: 1.0000 - val_loss: 0.0191 - val_accuracy: 0.9960
Epoch 162/200
4500/4500 [==============================] - 1s 217us/sample - loss: 0.0051 - accuracy: 1.0000 - val_loss: 0.0189 - val_accuracy: 0.9960
Epoch 163/200
4500/4500 [==============================] - 1s 319us/sample - loss: 0.0049 - accuracy: 1.0000 - val_loss: 0.0188 - val_accuracy: 0.9960
Epoch 164/200
4500/4500 [==============================] - 2s 423us/sample - loss: 0.0048 - accuracy: 1.0000 - val_loss: 0.0186 - val_accuracy: 0.9960
Epoch 165/200
4500/4500 [==============================] - 2s 372us/sample - loss: 0.0046 - accuracy: 1.0000 - val_loss: 0.0187 - val_accuracy: 0.9960
Epoch 166/200
4500/4500 [==============================] - 1s 332us/sample - loss: 0.0046 - accuracy: 1.0000 - val_loss: 0.0183 - val_accuracy: 0.9960
Epoch 167/200
4500/4500 [==============================] - 1s 333us/sample - loss: 0.0044 - accuracy: 1.0000 - val_loss: 0.0198 - val_accuracy: 0.9947
Epoch 168/200
4500/4500 [==============================] - 1s 333us/sample - loss: 0.0047 - accuracy: 1.0000 - val_loss: 0.0177 - val_accuracy: 0.9960
Epoch 169/200
4500/4500 [==============================] - 2s 385us/sample - loss: 0.0043 - accuracy: 1.0000 - val_loss: 0.0175 - val_accuracy: 0.9960
Epoch 170/200
4500/4500 [==============================] - 2s 390us/sample - loss: 0.0041 - accuracy: 1.0000 - val_loss: 0.0180 - val_accuracy: 0.9960
Epoch 171/200
4500/4500 [==============================] - 2s 377us/sample - loss: 0.0040 - accuracy: 1.0000 - val_loss: 0.0178 - val_accuracy: 0.9960
Epoch 172/200
4500/4500 [==============================] - 2s 362us/sample - loss: 0.0039 - accuracy: 1.0000 - val_loss: 0.0180 - val_accuracy: 0.9960
Epoch 173/200
4500/4500 [==============================] - 1s 282us/sample - loss: 0.0039 - accuracy: 1.0000 - val_loss: 0.0177 - val_accuracy: 0.9960
Epoch 174/200
4500/4500 [==============================] - 2s 362us/sample - loss: 0.0038 - accuracy: 1.0000 - val_loss: 0.0180 - val_accuracy: 0.9953
Epoch 175/200
4500/4500 [==============================] - 1s 277us/sample - loss: 0.0037 - accuracy: 1.0000 - val_loss: 0.0170 - val_accuracy: 0.9960
Epoch 176/200
4500/4500 [==============================] - 2s 348us/sample - loss: 0.0036 - accuracy: 1.0000 - val_loss: 0.0173 - val_accuracy: 0.9960
Epoch 177/200
4500/4500 [==============================] - 2s 419us/sample - loss: 0.0035 - accuracy: 1.0000 - val_loss: 0.0168 - val_accuracy: 0.9960
Epoch 178/200
4500/4500 [==============================] - 2s 392us/sample - loss: 0.0034 - accuracy: 1.0000 - val_loss: 0.0168 - val_accuracy: 0.9960
Epoch 179/200
4500/4500 [==============================] - 1s 272us/sample - loss: 0.0032 - accuracy: 1.0000 - val_loss: 0.0170 - val_accuracy: 0.9960
Epoch 180/200
4500/4500 [==============================] - 2s 351us/sample - loss: 0.0033 - accuracy: 1.0000 - val_loss: 0.0164 - val_accuracy: 0.9960
Epoch 181/200
4500/4500 [==============================] - 1s 308us/sample - loss: 0.0031 - accuracy: 1.0000 - val_loss: 0.0164 - val_accuracy: 0.9960
Epoch 182/200
4500/4500 [==============================] - 1s 295us/sample - loss: 0.0031 - accuracy: 1.0000 - val_loss: 0.0164 - val_accuracy: 0.9960
Epoch 183/200
4500/4500 [==============================] - 2s 372us/sample - loss: 0.0030 - accuracy: 1.0000 - val_loss: 0.0159 - val_accuracy: 0.9960
Epoch 184/200
4500/4500 [==============================] - 2s 351us/sample - loss: 0.0029 - accuracy: 1.0000 - val_loss: 0.0162 - val_accuracy: 0.9960
Epoch 185/200
4500/4500 [==============================] - 1s 309us/sample - loss: 0.0029 - accuracy: 1.0000 - val_loss: 0.0160 - val_accuracy: 0.9960
Epoch 186/200
4500/4500 [==============================] - 1s 307us/sample - loss: 0.0028 - accuracy: 1.0000 - val_loss: 0.0163 - val_accuracy: 0.9953
Epoch 187/200
4500/4500 [==============================] - 1s 265us/sample - loss: 0.0028 - accuracy: 1.0000 - val_loss: 0.0162 - val_accuracy: 0.9960
Epoch 188/200
4500/4500 [==============================] - 1s 297us/sample - loss: 0.0028 - accuracy: 1.0000 - val_loss: 0.0161 - val_accuracy: 0.9960
Epoch 189/200
4500/4500 [==============================] - 1s 262us/sample - loss: 0.0027 - accuracy: 1.0000 - val_loss: 0.0158 - val_accuracy: 0.9960
Epoch 190/200
4500/4500 [==============================] - 1s 266us/sample - loss: 0.0025 - accuracy: 1.0000 - val_loss: 0.0160 - val_accuracy: 0.9960
Epoch 191/200
4500/4500 [==============================] - 1s 270us/sample - loss: 0.0026 - accuracy: 1.0000 - val_loss: 0.0156 - val_accuracy: 0.9960
Epoch 192/200
4500/4500 [==============================] - 1s 272us/sample - loss: 0.4587 - accuracy: 0.9436 - val_loss: 3.0601 - val_accuracy: 0.5467
Epoch 193/200
4500/4500 [==============================] - 1s 266us/sample - loss: 0.9706 - accuracy: 0.7185 - val_loss: 0.2870 - val_accuracy: 0.9000
Epoch 194/200
4500/4500 [==============================] - 1s 273us/sample - loss: 0.1387 - accuracy: 0.9677 - val_loss: 0.0814 - val_accuracy: 0.9860
Epoch 195/200
4500/4500 [==============================] - 1s 273us/sample - loss: 0.0551 - accuracy: 0.9939 - val_loss: 0.0473 - val_accuracy: 0.9947
Epoch 196/200
4500/4500 [==============================] - 1s 282us/sample - loss: 0.0279 - accuracy: 0.9995 - val_loss: 0.0353 - val_accuracy: 0.9960
Epoch 197/200
4500/4500 [==============================] - 1s 283us/sample - loss: 0.0206 - accuracy: 0.9998 - val_loss: 0.0321 - val_accuracy: 0.9960
Epoch 198/200
4500/4500 [==============================] - 1s 274us/sample - loss: 0.0174 - accuracy: 0.9999 - val_loss: 0.0298 - val_accuracy: 0.9960
Epoch 199/200
4500/4500 [==============================] - 1s 270us/sample - loss: 0.0151 - accuracy: 0.9999 - val_loss: 0.0276 - val_accuracy: 0.9960
Epoch 200/200
4500/4500 [==============================] - 1s 306us/sample - loss: 0.0133 - accuracy: 0.9999 - val_loss: 0.0264 - val_accuracy: 0.9960
In [32]:
# Renders the charts for training accuracy and loss.
def render_training_history(training_history):
    loss = training_history.history['loss']
    val_loss = training_history.history['val_loss']

    accuracy = training_history.history['accuracy']
    val_accuracy = training_history.history['val_accuracy']

    plt.figure(figsize=(14, 4))

    plt.subplot(1, 2, 1)
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.plot(loss, label='Training set')
    plt.plot(val_loss, label='Validation set', linestyle='--')
    plt.legend()
    plt.grid(linestyle='--', linewidth=1, alpha=0.5)

    plt.subplot(1, 2, 2)
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.plot(accuracy, label='Training set')
    plt.plot(val_accuracy, label='Validation set', linestyle='--')
    plt.legend()
    plt.grid(linestyle='--', linewidth=1, alpha=0.5)

    plt.show()
In [33]:
render_training_history(history)
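The training log above shows that the loss does not decrease monotonically. At epoch 129 the training loss jumps from 0.0115 to 0.5044, and at epoch 192 val_accuracy drops to 0.5467 before the model recovers. As a result, the weights left after the last epoch (val_loss 0.0264) are not the ones with the lowest validation loss among the epochs shown (0.0156 at epoch 191). The helper below is a minimal sketch for finding that epoch from a Keras `History.history` dictionary. `best_epoch()` is a hypothetical name and is not part of the notebook. Keras also offers built-in callbacks for this, such as `tf.keras.callbacks.ModelCheckpoint(..., save_best_only=True)` and `tf.keras.callbacks.EarlyStopping(..., restore_best_weights=True)`.

```python
def best_epoch(history_dict, monitor='val_loss', first_epoch=1):
    """Returns (epoch_number, value) for the epoch where `monitor` was lowest.

    `history_dict` is expected to look like Keras's `History.history`:
    a dict mapping metric names to per-epoch lists of values.
    """
    values = history_dict[monitor]
    if not values:
        raise ValueError('no values recorded for {!r}'.format(monitor))
    # Index of the smallest recorded value.
    best_index = min(range(len(values)), key=values.__getitem__)
    return first_epoch + best_index, values[best_index]


# val_loss for epochs 189-200, copied from the training log above.
tail = {'val_loss': [0.0158, 0.0160, 0.0156, 3.0601, 0.2870, 0.0814,
                     0.0473, 0.0353, 0.0321, 0.0298, 0.0276, 0.0264]}
print(best_epoch(tail, first_epoch=189))  # (191, 0.0156)
```

With the notebook's objects this would be called as `best_epoch(history.history)`. The helper only looks for minimum values, so it suits loss-type metrics rather than accuracy.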

Test the model

In [34]:
x_test, y_test = generate_dataset(dataset_size, sequence_length, max_num, vocabulary)

print('x_test:\n', x_test[:1])
print()
print('y_test:\n', y_test[:1])
x_test:
 [[[0 0 0 0 0 0 0 1 0 0 0 0]
  [0 0 0 0 0 0 1 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0 0 1 0]
  [0 0 0 0 0 0 0 0 1 0 0 0]
  [0 0 0 0 1 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0 0 0 1]
  [0 0 0 0 0 0 0 0 0 0 0 1]]]

y_test:
 [[[0 1 0 0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 1 0 0 0 0 0]
  [1 0 0 0 0 0 0 0 0 0 0 0]]]
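The test examples are produced by `generate_dataset()`, which is presumably the same random generator used for the training data. With operands in [1, 99] there are only 99 × 99 = 9801 distinct ordered expressions, so a freshly drawn test set can share many expressions with the training set. The sketch below estimates that overlap. It makes two assumptions that are not taken from the notebook's code: operands are drawn uniformly from 1..99, and the 4500 training examples (the size shown in the training log) and 5000 test examples are drawn independently. If these assumptions hold, roughly a third of the test expressions were already seen during training. The test results therefore measure memorization as well as generalization. Filtering test pairs against the training pairs would give a cleaner held-out set.

```python
import random

# Assumptions: each operand is uniform on 1..99 (per the experiment overview),
# and training and test sets are drawn independently of each other.
MIN_NUM, MAX_NUM = 1, 99
TRAIN_SIZE = 4500  # training samples per epoch, as shown in the training log
TEST_SIZE = 5000   # dataset_size used for x_test above


def random_pairs(size, rng):
    """Draws `size` ordered operand pairs such as (76, 84)."""
    return [(rng.randint(MIN_NUM, MAX_NUM), rng.randint(MIN_NUM, MAX_NUM))
            for _ in range(size)]


rng = random.Random(42)  # fixed seed so the sketch is reproducible
train_pairs = random_pairs(TRAIN_SIZE, rng)
test_pairs = random_pairs(TEST_SIZE, rng)

# Share of test expressions that also occur verbatim in the training data.
seen = set(train_pairs)
overlap = sum(pair in seen for pair in test_pairs) / TEST_SIZE

# Analytic estimate: probability that one fixed pair appears among the
# TRAIN_SIZE training draws, given (MAX_NUM - MIN_NUM + 1) ** 2 possible pairs.
num_pairs = (MAX_NUM - MIN_NUM + 1) ** 2
expected = 1 - (1 - 1 / num_pairs) ** TRAIN_SIZE  # about 0.37

print('simulated overlap: {:.1%}'.format(overlap))
print('analytic estimate: {:.1%}'.format(expected))
```

The sketch counts ordered pairs, so "76+84" and "84+76" are treated as different expressions, just as they are different input strings for the RNN.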
In [35]:
predictions = model.predict(x_test)

print('predictions.shape: ', predictions.shape)
print()
print('predictions[0]:\n', predictions[0])
print()
print('predictions[1]:\n', predictions[1])
predictions.shape:  (5000, 3, 12)

predictions[0]:
 [[1.56478808e-09 9.99993443e-01 1.13263614e-10 2.43536555e-13
  7.35255086e-12 1.87883997e-11 2.41102582e-10 4.59983891e-11
  1.34911238e-08 6.54733321e-06 8.69514460e-11 4.95329960e-13]
 [4.08798002e-07 2.92953946e-08 1.33197474e-07 7.34924754e-07
  1.59472929e-05 1.65897664e-02 9.81849730e-01 1.47053797e-03
  6.89475419e-05 3.72986460e-06 1.31091102e-08 1.40613343e-09]
 [9.87638593e-01 4.40192549e-03 1.97818463e-06 6.63003252e-09
  4.55207649e-09 9.64273106e-11 2.73215264e-08 1.80779409e-07
  1.56410679e-05 7.94167537e-03 1.62981628e-08 4.15896695e-09]]

predictions[1]:
 [[2.15935381e-09 9.99992371e-01 5.40815170e-09 6.49596765e-11
  9.35228422e-11 9.23694662e-11 2.05900808e-09 8.82467655e-10
  5.42079519e-08 7.62533909e-06 2.06515055e-10 2.99114275e-11]
 [3.79572568e-10 6.81736542e-11 2.66999258e-08 2.08649180e-06
  1.04209560e-03 9.93228376e-01 5.15519409e-03 5.40546724e-04
  3.17595914e-05 1.57793056e-09 1.10365084e-09 2.60308664e-09]
 [3.24377629e-06 9.20376220e-10 1.31483295e-08 9.00326036e-09
  7.24035954e-06 3.99490673e-05 3.25242523e-03 9.86836672e-01
  9.84544307e-03 3.07867026e-06 5.86528941e-08 1.19321221e-05]]
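Each prediction has shape (3, 12). Every row is a probability distribution over the 12 vocabulary characters for one output position. Taking the index of the largest value in each row gives the predicted character. We assume this is what the notebook's `decode()` does, which is consistent with the printed results. The largest value itself can serve as a rough per-character confidence. The sketch below applies this to `predictions[0]` copied from the output above. The vocabulary order (digits 0-9, then `+`, then a space) is an assumption inferred from the one-hot `x_test`/`y_test` examples printed earlier. `decode_with_confidence()` is a hypothetical helper.

```python
import numpy as np

# Assumed vocabulary order, inferred from the printed one-hot examples.
VOCABULARY = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']

# predictions[0] copied from the output above: 3 output positions x 12 characters.
prediction = np.array([
    [1.56478808e-09, 9.99993443e-01, 1.13263614e-10, 2.43536555e-13,
     7.35255086e-12, 1.87883997e-11, 2.41102582e-10, 4.59983891e-11,
     1.34911238e-08, 6.54733321e-06, 8.69514460e-11, 4.95329960e-13],
    [4.08798002e-07, 2.92953946e-08, 1.33197474e-07, 7.34924754e-07,
     1.59472929e-05, 1.65897664e-02, 9.81849730e-01, 1.47053797e-03,
     6.89475419e-05, 3.72986460e-06, 1.31091102e-08, 1.40613343e-09],
    [9.87638593e-01, 4.40192549e-03, 1.97818463e-06, 6.63003252e-09,
     4.55207649e-09, 9.64273106e-11, 2.73215264e-08, 1.80779409e-07,
     1.56410679e-05, 7.94167537e-03, 1.62981628e-08, 4.15896695e-09],
])


def decode_with_confidence(probabilities, vocabulary=VOCABULARY):
    """Returns the most likely string and the probability of each chosen character."""
    indices = probabilities.argmax(axis=-1)     # most likely character per position
    confidences = probabilities.max(axis=-1)    # its probability
    text = ''.join(vocabulary[i] for i in indices)
    return text, confidences


text, confidences = decode_with_confidence(prediction)
print(text)                                # 160
print(round(float(confidences.min()), 3))  # 0.982
```

Here the model is least certain about the tens digit. It assigns 0.982 to "6", and most of the remaining probability goes to "5" (about 0.017). Low minimum confidences could be used to flag answers worth double-checking.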
In [36]:
# Decode the one-hot inputs, the one-hot labels and the predicted probabilities back into strings.
x_decoded = [decode(example, vocabulary) for example in x_test]
y_expected = [decode(label, vocabulary) for label in y_test]
y_predicted = [decode(prediction, vocabulary) for prediction in predictions]

explore_num = 40
for example, label, prediction in list(zip(x_decoded, y_expected, y_predicted))[:explore_num]:
    checkmark = '✓' if label == prediction else ''
    print('{} = {} [predict: {}] {}'.format(example, label, prediction, checkmark))
76+84   = 160 [predict: 160] ✓
58+99   = 157 [predict: 157] ✓
36+84   = 120 [predict: 120] ✓
62+45   = 107 [predict: 107] ✓
45+66   = 111 [predict: 111] ✓
9+93    = 102 [predict: 102] ✓
62+39   = 101 [predict: 101] ✓
67+11   = 78  [predict: 78 ] ✓
61+15   = 76  [predict: 76 ] ✓
89+40   = 129 [predict: 129] ✓
81+21   = 102 [predict: 102] ✓
13+50   = 63  [predict: 63 ] ✓
48+7    = 55  [predict: 55 ] ✓
47+37   = 84  [predict: 84 ] ✓
9+29    = 38  [predict: 38 ] ✓
55+52   = 107 [predict: 107] ✓
35+75   = 110 [predict: 110] ✓
64+21   = 85  [predict: 85 ] ✓
31+14   = 45  [predict: 45 ] ✓
11+89   = 100 [predict: 100] ✓
18+66   = 84  [predict: 84 ] ✓
31+51   = 82  [predict: 82 ] ✓
9+51    = 60  [predict: 60 ] ✓
37+3    = 40  [predict: 40 ] ✓
92+89   = 181 [predict: 181] ✓
7+39    = 46  [predict: 46 ] ✓
14+73   = 87  [predict: 87 ] ✓
15+47   = 62  [predict: 62 ] ✓
34+23   = 57  [predict: 57 ] ✓
17+26   = 43  [predict: 43 ] ✓
83+10   = 93  [predict: 93 ] ✓
66+96   = 162 [predict: 162] ✓
56+71   = 127 [predict: 127] ✓
28+40   = 68  [predict: 68 ] ✓
5+77    = 82  [predict: 82 ] ✓
61+76   = 137 [predict: 137] ✓
69+47   = 116 [predict: 116] ✓
59+81   = 140 [predict: 140] ✓
91+88   = 179 [predict: 179] ✓
18+72   = 90  [predict: 90 ] ✓
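The loop above checks only the first 40 of the 5000 test examples, and it compares complete answers. The accuracy reported in the training log is most likely computed per output character and averaged over the 3 positions. This assumes the model was compiled with categorical cross-entropy and the `'accuracy'` metric. Under that metric, an answer with one wrong digit still counts as two-thirds correct. The sketch below computes both figures from lists of decoded strings. `exact_match_accuracy()` and `character_accuracy()` are hypothetical helper names, and the sample lists are made up for illustration.

```python
def exact_match_accuracy(expected, predicted):
    """Share of examples whose complete decoded answer equals the label."""
    if len(expected) != len(predicted):
        raise ValueError('expected and predicted must have the same length')
    if not expected:
        raise ValueError('no examples to compare')
    matches = sum(e == p for e, p in zip(expected, predicted))
    return matches / len(expected)


def character_accuracy(expected, predicted):
    """Share of individual output positions (characters) that match."""
    pairs = [(ce, cp) for e, p in zip(expected, predicted) for ce, cp in zip(e, p)]
    if not pairs:
        raise ValueError('no characters to compare')
    return sum(ce == cp for ce, cp in pairs) / len(pairs)


# Illustrative answers padded to 3 characters, like the model's output.
labels = ['160', '78 ', '102']
guesses = ['160', '79 ', '102']
print(exact_match_accuracy(labels, guesses))  # 0.6666666666666666
print(character_accuracy(labels, guesses))    # 0.8888888888888888
```

With the notebook's variables, `exact_match_accuracy(y_expected, y_predicted)` would summarize the whole test set instead of the 40 printed examples.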

Debugging the model with TensorBoard

TensorBoard is a tool for providing the measurements and visualizations needed during the machine learning workflow. It enables tracking experiment metrics like loss and accuracy, visualizing the model graph, projecting embeddings to a lower-dimensional space, and much more.

In [37]:
%tensorboard --logdir .logs/fit

Save the model

In [47]:
model_name = 'numbers_summation_rnn.h5'
model.save(model_name, save_format='h5')

Converting the model to web format

To use this model on the web, we need to convert it into a format that TensorFlow.js understands. To do so, we may use tfjs-converter as follows:

tensorflowjs_converter --input_format keras \
  ./experiments/numbers_summation_rnn/numbers_summation_rnn.h5 \
  ./demos/public/models/numbers_summation_rnn

You may find this experiment in the Demo app and play around with it right in your browser to see how the model performs in real life.