Chapter 15 – Processing Sequences Using RNNs and CNNs

This notebook contains all the sample code in chapter 15.

Setup

First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0.

In [102]:
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= "0.20"

try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
    IS_COLAB = True
except Exception:
    IS_COLAB = False

# TensorFlow ≥2.0 is required
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"

if not tf.config.list_physical_devices('GPU'):
    print("No GPU was detected. LSTMs and CNNs can be very slow without a GPU.")
    if IS_COLAB:
        print("Go to Runtime > Change runtime and select a GPU hardware accelerator.")

# Common imports
import numpy as np
import os
from pathlib import Path

# to make this notebook's output stable across runs
np.random.seed(42)
tf.random.set_seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "rnn"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)

Basic RNNs

Generate the Dataset

In [103]:
def generate_time_series(batch_size, n_steps):
    freq1, freq2, offsets1, offsets2 = np.random.rand(4, batch_size, 1)
    time = np.linspace(0, 1, n_steps)
    series = 0.5 * np.sin((time - offsets1) * (freq1 * 10 + 10))  #   wave 1
    series += 0.2 * np.sin((time - offsets2) * (freq2 * 20 + 20)) # + wave 2
    series += 0.1 * (np.random.rand(batch_size, n_steps) - 0.5)   # + noise
    return series[..., np.newaxis].astype(np.float32)
In [104]:
np.random.seed(42)

n_steps = 50
series = generate_time_series(10000, n_steps + 1)
X_train, y_train = series[:7000, :n_steps], series[:7000, -1]
X_valid, y_valid = series[7000:9000, :n_steps], series[7000:9000, -1]
X_test, y_test = series[9000:, :n_steps], series[9000:, -1]
In [105]:
X_train.shape, y_train.shape
Out[105]:
((7000, 50, 1), (7000, 1))
In [106]:
def plot_series(series, y=None, y_pred=None, x_label="$t$", y_label="$x(t)$"):
    plt.plot(series, ".-")
    if y is not None:
        plt.plot(n_steps, y, "bx", markersize=10)
    if y_pred is not None:
        plt.plot(n_steps, y_pred, "ro")
    plt.grid(True)
    if x_label:
        plt.xlabel(x_label, fontsize=16)
    if y_label:
        plt.ylabel(y_label, fontsize=16, rotation=0)
    plt.hlines(0, 0, 100, linewidth=1)
    plt.axis([0, n_steps + 1, -1, 1])

fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(12, 4))
for col in range(3):
    plt.sca(axes[col])
    plot_series(X_valid[col, :, 0], y_valid[col, 0],
                y_label=("$x(t)$" if col==0 else None))
save_fig("time_series_plot")
plt.show()
Saving figure time_series_plot

Computing Some Baselines

Naive predictions (just predict the last observed value):

In [107]:
y_pred = X_valid[:, -1]
np.mean(keras.losses.mean_squared_error(y_valid, y_pred))
Out[107]:
0.020211367
In [108]:
plot_series(X_valid[0, :, 0], y_valid[0, 0], y_pred[0, 0])
plt.show()

Linear predictions:

In [109]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[50, 1]),
    keras.layers.Dense(1)
])

model.compile(loss="mse", optimizer="adam")
history = model.fit(X_train, y_train, epochs=20,
                    validation_data=(X_valid, y_valid))
Epoch 1/20
219/219 [==============================] - 1s 3ms/step - loss: 0.1001 - val_loss: 0.0545
Epoch 2/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0379 - val_loss: 0.0266
Epoch 3/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0202 - val_loss: 0.0157
Epoch 4/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0131 - val_loss: 0.0116
Epoch 5/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0103 - val_loss: 0.0098
Epoch 6/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0089 - val_loss: 0.0087
Epoch 7/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0080 - val_loss: 0.0079
Epoch 8/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0073 - val_loss: 0.0071
Epoch 9/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0066 - val_loss: 0.0066
Epoch 10/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0061 - val_loss: 0.0062
Epoch 11/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0057 - val_loss: 0.0057
Epoch 12/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0054 - val_loss: 0.0055
Epoch 13/20
219/219 [==============================] - 0s 2ms/step - loss: 0.0052 - val_loss: 0.0052
Epoch 14/20
219/219 [==============================] - 0s 2ms/step - loss: 0.0049 - val_loss: 0.0049
Epoch 15/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0048 - val_loss: 0.0048
Epoch 16/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0046 - val_loss: 0.0048
Epoch 17/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0045 - val_loss: 0.0045
Epoch 18/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0044 - val_loss: 0.0044
Epoch 19/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0043 - val_loss: 0.0043
Epoch 20/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0042 - val_loss: 0.0042
In [110]:
model.evaluate(X_valid, y_valid)
63/63 [==============================] - 0s 2ms/step - loss: 0.0042
Out[110]:
0.004168085753917694
In [111]:
def plot_learning_curves(loss, val_loss):
    plt.plot(np.arange(len(loss)) + 0.5, loss, "b.-", label="Training loss")
    plt.plot(np.arange(len(val_loss)) + 1, val_loss, "r.-", label="Validation loss")
    plt.gca().xaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
    plt.axis([1, 20, 0, 0.05])
    plt.legend(fontsize=14)
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.grid(True)

plot_learning_curves(history.history["loss"], history.history["val_loss"])
plt.show()
In [112]:
y_pred = model.predict(X_valid)
plot_series(X_valid[0, :, 0], y_valid[0, 0], y_pred[0, 0])
plt.show()

Using a Simple RNN

In [113]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.SimpleRNN(1, input_shape=[None, 1])
])

optimizer = keras.optimizers.Adam(lr=0.005)
model.compile(loss="mse", optimizer=optimizer)
history = model.fit(X_train, y_train, epochs=20,
                    validation_data=(X_valid, y_valid))
Epoch 1/20
219/219 [==============================] - 8s 34ms/step - loss: 0.0967 - val_loss: 0.0489
Epoch 2/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0369 - val_loss: 0.0296
Epoch 3/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0253 - val_loss: 0.0218
Epoch 4/20
219/219 [==============================] - 8s 34ms/step - loss: 0.0198 - val_loss: 0.0177
Epoch 5/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0166 - val_loss: 0.0151
Epoch 6/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0146 - val_loss: 0.0134
Epoch 7/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0132 - val_loss: 0.0123
Epoch 8/20
219/219 [==============================] - 8s 34ms/step - loss: 0.0124 - val_loss: 0.0116
Epoch 9/20
219/219 [==============================] - 8s 35ms/step - loss: 0.0118 - val_loss: 0.0112
Epoch 10/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0116 - val_loss: 0.0110
Epoch 11/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 12/20
219/219 [==============================] - 8s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 13/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 14/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 15/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 16/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 17/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 18/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 19/20
219/219 [==============================] - 7s 34ms/step - loss: 0.0114 - val_loss: 0.0109
Epoch 20/20
219/219 [==============================] - 8s 34ms/step - loss: 0.0114 - val_loss: 0.0109
In [114]:
model.evaluate(X_valid, y_valid)
63/63 [==============================] - 0s 7ms/step - loss: 0.0109
Out[114]:
0.010881561785936356
In [115]:
plot_learning_curves(history.history["loss"], history.history["val_loss"])
plt.show()
In [116]:
y_pred = model.predict(X_valid)
plot_series(X_valid[0, :, 0], y_valid[0, 0], y_pred[0, 0])
plt.show()

Deep RNNs

In [117]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.SimpleRNN(20, return_sequences=True, input_shape=[None, 1]),
    keras.layers.SimpleRNN(20, return_sequences=True),
    keras.layers.SimpleRNN(1)
])

model.compile(loss="mse", optimizer="adam")
history = model.fit(X_train, y_train, epochs=20,
                    validation_data=(X_valid, y_valid))
Epoch 1/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0492 - val_loss: 0.0090
Epoch 2/20
219/219 [==============================] - 22s 103ms/step - loss: 0.0070 - val_loss: 0.0065
Epoch 3/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0053 - val_loss: 0.0045
Epoch 4/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0045 - val_loss: 0.0040
Epoch 5/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0042 - val_loss: 0.0040
Epoch 6/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0038 - val_loss: 0.0036
Epoch 7/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0038 - val_loss: 0.0040
Epoch 8/20
219/219 [==============================] - 22s 101ms/step - loss: 0.0037 - val_loss: 0.0033
Epoch 9/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0036 - val_loss: 0.0032
Epoch 10/20
219/219 [==============================] - 22s 99ms/step - loss: 0.0035 - val_loss: 0.0031
Epoch 11/20
219/219 [==============================] - 22s 100ms/step - loss: 0.0034 - val_loss: 0.0030
Epoch 12/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0033 - val_loss: 0.0031
Epoch 13/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0034 - val_loss: 0.0031
Epoch 14/20
219/219 [==============================] - 22s 100ms/step - loss: 0.0033 - val_loss: 0.0032
Epoch 15/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0034 - val_loss: 0.0033
Epoch 16/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0035 - val_loss: 0.0030
Epoch 17/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0033 - val_loss: 0.0029
Epoch 18/20
219/219 [==============================] - 22s 102ms/step - loss: 0.0033 - val_loss: 0.0030
Epoch 19/20
219/219 [==============================] - 22s 100ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 20/20
219/219 [==============================] - 21s 98ms/step - loss: 0.0032 - val_loss: 0.0029
In [118]:
model.evaluate(X_valid, y_valid)
63/63 [==============================] - 1s 18ms/step - loss: 0.0029
Out[118]:
0.0029105604626238346
In [119]:
plot_learning_curves(history.history["loss"], history.history["val_loss"])
plt.show()
In [120]:
y_pred = model.predict(X_valid)
plot_series(X_valid[0, :, 0], y_valid[0, 0], y_pred[0, 0])
plt.show()

Make the second SimpleRNN layer return only the last output:

In [121]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.SimpleRNN(20, return_sequences=True, input_shape=[None, 1]),
    keras.layers.SimpleRNN(20),
    keras.layers.Dense(1)
])

model.compile(loss="mse", optimizer="adam")
history = model.fit(X_train, y_train, epochs=20,
                    validation_data=(X_valid, y_valid))
Epoch 1/20
219/219 [==============================] - 16s 72ms/step - loss: 0.0232 - val_loss: 0.0052
Epoch 2/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0043 - val_loss: 0.0036
Epoch 3/20
219/219 [==============================] - 15s 68ms/step - loss: 0.0035 - val_loss: 0.0031
Epoch 4/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0033 - val_loss: 0.0033
Epoch 5/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0033 - val_loss: 0.0034
Epoch 6/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0031 - val_loss: 0.0029
Epoch 7/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0031 - val_loss: 0.0034
Epoch 8/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0032 - val_loss: 0.0028
Epoch 9/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0031 - val_loss: 0.0028
Epoch 10/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0030 - val_loss: 0.0029
Epoch 11/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0029 - val_loss: 0.0027
Epoch 12/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0030 - val_loss: 0.0031
Epoch 13/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0030 - val_loss: 0.0031
Epoch 14/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0030 - val_loss: 0.0030
Epoch 15/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0030 - val_loss: 0.0030
Epoch 16/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0030 - val_loss: 0.0027
Epoch 17/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0029 - val_loss: 0.0028
Epoch 18/20
219/219 [==============================] - 15s 68ms/step - loss: 0.0030 - val_loss: 0.0027
Epoch 19/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0029 - val_loss: 0.0028
Epoch 20/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0029 - val_loss: 0.0026
In [122]:
model.evaluate(X_valid, y_valid)
63/63 [==============================] - 1s 15ms/step - loss: 0.0026
Out[122]:
0.002623623935505748
In [123]:
plot_learning_curves(history.history["loss"], history.history["val_loss"])
plt.show()
In [124]:
y_pred = model.predict(X_valid)
plot_series(X_valid[0, :, 0], y_valid[0, 0], y_pred[0, 0])
plt.show()

Forecasting Several Steps Ahead

In [125]:
np.random.seed(43) # not 42, as it would give the first series in the train set

series = generate_time_series(1, n_steps + 10)
X_new, Y_new = series[:, :n_steps], series[:, n_steps:]
X = X_new
for step_ahead in range(10):
    y_pred_one = model.predict(X[:, step_ahead:])[:, np.newaxis, :]
    X = np.concatenate([X, y_pred_one], axis=1)

Y_pred = X[:, n_steps:]
In [126]:
Y_pred.shape
Out[126]:
(1, 10, 1)
In [127]:
def plot_multiple_forecasts(X, Y, Y_pred):
    n_steps = X.shape[1]
    ahead = Y.shape[1]
    plot_series(X[0, :, 0])
    plt.plot(np.arange(n_steps, n_steps + ahead), Y[0, :, 0], "ro-", label="Actual")
    plt.plot(np.arange(n_steps, n_steps + ahead), Y_pred[0, :, 0], "bx-", label="Forecast", markersize=10)
    plt.axis([0, n_steps + ahead, -1, 1])
    plt.legend(fontsize=14)

plot_multiple_forecasts(X_new, Y_new, Y_pred)
save_fig("forecast_ahead_plot")
plt.show()
Saving figure forecast_ahead_plot

Now let's use this model to predict the next 10 values. We first need to regenerate the sequences with 9 more time steps.

In [128]:
np.random.seed(42)

n_steps = 50
series = generate_time_series(10000, n_steps + 10)
X_train, Y_train = series[:7000, :n_steps], series[:7000, -10:, 0]
X_valid, Y_valid = series[7000:9000, :n_steps], series[7000:9000, -10:, 0]
X_test, Y_test = series[9000:, :n_steps], series[9000:, -10:, 0]

Now let's predict the next 10 values one by one:

In [129]:
X = X_valid
for step_ahead in range(10):
    y_pred_one = model.predict(X)[:, np.newaxis, :]
    X = np.concatenate([X, y_pred_one], axis=1)

Y_pred = X[:, n_steps:, 0]
In [130]:
Y_pred.shape
Out[130]:
(2000, 10)
In [131]:
np.mean(keras.metrics.mean_squared_error(Y_valid, Y_pred))
Out[131]:
0.02751083

Let's compare this performance with some baselines: naive predictions and a simple linear model:

In [132]:
Y_naive_pred = Y_valid[:, -1:]
np.mean(keras.metrics.mean_squared_error(Y_valid, Y_naive_pred))
Out[132]:
0.22278848
In [133]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[50, 1]),
    keras.layers.Dense(10)
])

model.compile(loss="mse", optimizer="adam")
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 1s 3ms/step - loss: 0.1343 - val_loss: 0.0606
Epoch 2/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0496 - val_loss: 0.0425
Epoch 3/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0385 - val_loss: 0.0353
Epoch 4/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0331 - val_loss: 0.0311
Epoch 5/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0298 - val_loss: 0.0283
Epoch 6/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0273 - val_loss: 0.0264
Epoch 7/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0256 - val_loss: 0.0249
Epoch 8/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0244 - val_loss: 0.0237
Epoch 9/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0234 - val_loss: 0.0229
Epoch 10/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0227 - val_loss: 0.0222
Epoch 11/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0220 - val_loss: 0.0216
Epoch 12/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0215 - val_loss: 0.0212
Epoch 13/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0210 - val_loss: 0.0208
Epoch 14/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0207 - val_loss: 0.0207
Epoch 15/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0203 - val_loss: 0.0202
Epoch 16/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0200 - val_loss: 0.0199
Epoch 17/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0197 - val_loss: 0.0195
Epoch 18/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0193 - val_loss: 0.0192
Epoch 19/20
219/219 [==============================] - 1s 2ms/step - loss: 0.0191 - val_loss: 0.0189
Epoch 20/20
219/219 [==============================] - 1s 3ms/step - loss: 0.0188 - val_loss: 0.0187

Now let's create an RNN that predicts all 10 next values at once:

In [134]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.SimpleRNN(20, return_sequences=True, input_shape=[None, 1]),
    keras.layers.SimpleRNN(20),
    keras.layers.Dense(10)
])

model.compile(loss="mse", optimizer="adam")
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0669 - val_loss: 0.0317
Epoch 2/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0265 - val_loss: 0.0200
Epoch 3/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0183 - val_loss: 0.0160
Epoch 4/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0155 - val_loss: 0.0144
Epoch 5/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0139 - val_loss: 0.0118
Epoch 6/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0128 - val_loss: 0.0112
Epoch 7/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0122 - val_loss: 0.0110
Epoch 8/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0115 - val_loss: 0.0103
Epoch 9/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0111 - val_loss: 0.0112
Epoch 10/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0110 - val_loss: 0.0100
Epoch 11/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0108 - val_loss: 0.0103
Epoch 12/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0102 - val_loss: 0.0096
Epoch 13/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0104 - val_loss: 0.0100
Epoch 14/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0098 - val_loss: 0.0103
Epoch 15/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0095 - val_loss: 0.0107
Epoch 16/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0092 - val_loss: 0.0089
Epoch 17/20
219/219 [==============================] - 15s 68ms/step - loss: 0.0094 - val_loss: 0.0111
Epoch 18/20
219/219 [==============================] - 15s 68ms/step - loss: 0.0095 - val_loss: 0.0094
Epoch 19/20
219/219 [==============================] - 15s 67ms/step - loss: 0.0093 - val_loss: 0.0083
Epoch 20/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0094 - val_loss: 0.0085
In [135]:
np.random.seed(43)

series = generate_time_series(1, 50 + 10)
X_new, Y_new = series[:, :50, :], series[:, -10:, :]
Y_pred = model.predict(X_new)[..., np.newaxis]
In [136]:
plot_multiple_forecasts(X_new, Y_new, Y_pred)
plt.show()

Now let's create an RNN that predicts the next 10 steps at each time step. That is, instead of just forecasting time steps 50 to 59 based on time steps 0 to 49, it will forecast time steps 1 to 10 at time step 0, then time steps 2 to 11 at time step 1, and so on, and finally it will forecast time steps 50 to 59 at the last time step. Notice that the model is causal: when it makes predictions at any time step, it can only see past time steps.

In [137]:
np.random.seed(42)

n_steps = 50
series = generate_time_series(10000, n_steps + 10)
X_train = series[:7000, :n_steps]
X_valid = series[7000:9000, :n_steps]
X_test = series[9000:, :n_steps]
Y = np.empty((10000, n_steps, 10))
for step_ahead in range(1, 10 + 1):
    Y[..., step_ahead - 1] = series[..., step_ahead:step_ahead + n_steps, 0]
Y_train = Y[:7000]
Y_valid = Y[7000:9000]
Y_test = Y[9000:]
In [138]:
X_train.shape, Y_train.shape
Out[138]:
((7000, 50, 1), (7000, 50, 10))
In [139]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.SimpleRNN(20, return_sequences=True, input_shape=[None, 1]),
    keras.layers.SimpleRNN(20, return_sequences=True),
    keras.layers.TimeDistributed(keras.layers.Dense(10))
])

def last_time_step_mse(Y_true, Y_pred):
    return keras.metrics.mean_squared_error(Y_true[:, -1], Y_pred[:, -1])

model.compile(loss="mse", optimizer=keras.optimizers.Adam(lr=0.01), metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0508 - last_time_step_mse: 0.0400 - val_loss: 0.0429 - val_last_time_step_mse: 0.0324
Epoch 2/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0395 - last_time_step_mse: 0.0283 - val_loss: 0.0351 - val_last_time_step_mse: 0.0243
Epoch 3/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0324 - last_time_step_mse: 0.0213 - val_loss: 0.0301 - val_last_time_step_mse: 0.0185
Epoch 4/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0294 - last_time_step_mse: 0.0182 - val_loss: 0.0270 - val_last_time_step_mse: 0.0152
Epoch 5/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0271 - last_time_step_mse: 0.0152 - val_loss: 0.0256 - val_last_time_step_mse: 0.0144
Epoch 6/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0251 - last_time_step_mse: 0.0124 - val_loss: 0.0231 - val_last_time_step_mse: 0.0095
Epoch 7/20
219/219 [==============================] - 15s 71ms/step - loss: 0.0229 - last_time_step_mse: 0.0100 - val_loss: 0.0219 - val_last_time_step_mse: 0.0088
Epoch 8/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0216 - last_time_step_mse: 0.0086 - val_loss: 0.0219 - val_last_time_step_mse: 0.0093
Epoch 9/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0212 - last_time_step_mse: 0.0085 - val_loss: 0.0207 - val_last_time_step_mse: 0.0084
Epoch 10/20
219/219 [==============================] - 15s 69ms/step - loss: 0.0211 - last_time_step_mse: 0.0086 - val_loss: 0.0209 - val_last_time_step_mse: 0.0090
Epoch 11/20
219/219 [==============================] - 15s 71ms/step - loss: 0.0205 - last_time_step_mse: 0.0081 - val_loss: 0.0198 - val_last_time_step_mse: 0.0075
Epoch 12/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0202 - last_time_step_mse: 0.0079 - val_loss: 0.0191 - val_last_time_step_mse: 0.0067
Epoch 13/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0198 - last_time_step_mse: 0.0076 - val_loss: 0.0213 - val_last_time_step_mse: 0.0089
Epoch 14/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0195 - last_time_step_mse: 0.0073 - val_loss: 0.0192 - val_last_time_step_mse: 0.0069
Epoch 15/20
219/219 [==============================] - 15s 67ms/step - loss: 0.0194 - last_time_step_mse: 0.0073 - val_loss: 0.0187 - val_last_time_step_mse: 0.0075
Epoch 16/20
219/219 [==============================] - 15s 71ms/step - loss: 0.0188 - last_time_step_mse: 0.0067 - val_loss: 0.0195 - val_last_time_step_mse: 0.0092
Epoch 17/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0187 - last_time_step_mse: 0.0067 - val_loss: 0.0185 - val_last_time_step_mse: 0.0072
Epoch 18/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0187 - last_time_step_mse: 0.0070 - val_loss: 0.0179 - val_last_time_step_mse: 0.0064
Epoch 19/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0185 - last_time_step_mse: 0.0069 - val_loss: 0.0173 - val_last_time_step_mse: 0.0057
Epoch 20/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0183 - last_time_step_mse: 0.0069 - val_loss: 0.0203 - val_last_time_step_mse: 0.0097
In [140]:
np.random.seed(43)

series = generate_time_series(1, 50 + 10)
X_new, Y_new = series[:, :50, :], series[:, 50:, :]
Y_pred = model.predict(X_new)[:, -1][..., np.newaxis]
In [141]:
plot_multiple_forecasts(X_new, Y_new, Y_pred)
plt.show()

Deep RNN with Batch Norm

In [142]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.SimpleRNN(20, return_sequences=True, input_shape=[None, 1]),
    keras.layers.BatchNormalization(),
    keras.layers.SimpleRNN(20, return_sequences=True),
    keras.layers.BatchNormalization(),
    keras.layers.TimeDistributed(keras.layers.Dense(10))
])

model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 16s 73ms/step - loss: 0.1929 - last_time_step_mse: 0.1902 - val_loss: 0.0877 - val_last_time_step_mse: 0.0832
Epoch 2/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0537 - last_time_step_mse: 0.0449 - val_loss: 0.0549 - val_last_time_step_mse: 0.0462
Epoch 3/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0471 - last_time_step_mse: 0.0375 - val_loss: 0.0451 - val_last_time_step_mse: 0.0358
Epoch 4/20
219/219 [==============================] - 16s 72ms/step - loss: 0.0437 - last_time_step_mse: 0.0337 - val_loss: 0.0418 - val_last_time_step_mse: 0.0314
Epoch 5/20
219/219 [==============================] - 15s 71ms/step - loss: 0.0409 - last_time_step_mse: 0.0306 - val_loss: 0.0391 - val_last_time_step_mse: 0.0287
Epoch 6/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0385 - last_time_step_mse: 0.0275 - val_loss: 0.0379 - val_last_time_step_mse: 0.0273
Epoch 7/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0366 - last_time_step_mse: 0.0254 - val_loss: 0.0367 - val_last_time_step_mse: 0.0248
Epoch 8/20
219/219 [==============================] - 16s 72ms/step - loss: 0.0349 - last_time_step_mse: 0.0235 - val_loss: 0.0363 - val_last_time_step_mse: 0.0249
Epoch 9/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0338 - last_time_step_mse: 0.0221 - val_loss: 0.0332 - val_last_time_step_mse: 0.0208
Epoch 10/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0329 - last_time_step_mse: 0.0214 - val_loss: 0.0335 - val_last_time_step_mse: 0.0214
Epoch 11/20
219/219 [==============================] - 16s 72ms/step - loss: 0.0322 - last_time_step_mse: 0.0206 - val_loss: 0.0323 - val_last_time_step_mse: 0.0203
Epoch 12/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0316 - last_time_step_mse: 0.0198 - val_loss: 0.0333 - val_last_time_step_mse: 0.0210
Epoch 13/20
219/219 [==============================] - 16s 72ms/step - loss: 0.0310 - last_time_step_mse: 0.0191 - val_loss: 0.0310 - val_last_time_step_mse: 0.0187
Epoch 14/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0305 - last_time_step_mse: 0.0186 - val_loss: 0.0310 - val_last_time_step_mse: 0.0189
Epoch 15/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0302 - last_time_step_mse: 0.0182 - val_loss: 0.0298 - val_last_time_step_mse: 0.0178
Epoch 16/20
219/219 [==============================] - 15s 71ms/step - loss: 0.0296 - last_time_step_mse: 0.0176 - val_loss: 0.0293 - val_last_time_step_mse: 0.0174
Epoch 17/20
219/219 [==============================] - 15s 70ms/step - loss: 0.0293 - last_time_step_mse: 0.0172 - val_loss: 0.0315 - val_last_time_step_mse: 0.0200
Epoch 18/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0289 - last_time_step_mse: 0.0168 - val_loss: 0.0295 - val_last_time_step_mse: 0.0174
Epoch 19/20
219/219 [==============================] - 16s 71ms/step - loss: 0.0286 - last_time_step_mse: 0.0168 - val_loss: 0.0290 - val_last_time_step_mse: 0.0163
Epoch 20/20
219/219 [==============================] - 16s 72ms/step - loss: 0.0281 - last_time_step_mse: 0.0161 - val_loss: 0.0288 - val_last_time_step_mse: 0.0164

Deep RNNs with Layer Norm

In [143]:
from tensorflow.keras.layers import LayerNormalization
In [144]:
class LNSimpleRNNCell(keras.layers.Layer):
    def __init__(self, units, activation="tanh", **kwargs):
        super().__init__(**kwargs)
        self.state_size = units
        self.output_size = units
        self.simple_rnn_cell = keras.layers.SimpleRNNCell(units,
                                                          activation=None)
        self.layer_norm = LayerNormalization()
        self.activation = keras.activations.get(activation)
    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        if inputs is not None:
            batch_size = tf.shape(inputs)[0]
            dtype = inputs.dtype
        return [tf.zeros([batch_size, self.state_size], dtype=dtype)]
    def call(self, inputs, states):
        outputs, new_states = self.simple_rnn_cell(inputs, states)
        norm_outputs = self.activation(self.layer_norm(outputs))
        return norm_outputs, [norm_outputs]
In [145]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.RNN(LNSimpleRNNCell(20), return_sequences=True,
                     input_shape=[None, 1]),
    keras.layers.RNN(LNSimpleRNNCell(20), return_sequences=True),
    keras.layers.TimeDistributed(keras.layers.Dense(10))
])

model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 37s 171ms/step - loss: 0.1592 - last_time_step_mse: 0.1557 - val_loss: 0.0746 - val_last_time_step_mse: 0.0684
Epoch 2/20
219/219 [==============================] - 36s 166ms/step - loss: 0.0663 - last_time_step_mse: 0.0579 - val_loss: 0.0591 - val_last_time_step_mse: 0.0485
Epoch 3/20
219/219 [==============================] - 36s 163ms/step - loss: 0.0551 - last_time_step_mse: 0.0441 - val_loss: 0.0512 - val_last_time_step_mse: 0.0381
Epoch 4/20
219/219 [==============================] - 36s 166ms/step - loss: 0.0485 - last_time_step_mse: 0.0359 - val_loss: 0.0454 - val_last_time_step_mse: 0.0327
Epoch 5/20
219/219 [==============================] - 37s 168ms/step - loss: 0.0437 - last_time_step_mse: 0.0310 - val_loss: 0.0418 - val_last_time_step_mse: 0.0284
Epoch 6/20
219/219 [==============================] - 36s 165ms/step - loss: 0.0399 - last_time_step_mse: 0.0268 - val_loss: 0.0376 - val_last_time_step_mse: 0.0229
Epoch 7/20
219/219 [==============================] - 37s 168ms/step - loss: 0.0363 - last_time_step_mse: 0.0226 - val_loss: 0.0345 - val_last_time_step_mse: 0.0206
Epoch 8/20
219/219 [==============================] - 37s 168ms/step - loss: 0.0333 - last_time_step_mse: 0.0199 - val_loss: 0.0326 - val_last_time_step_mse: 0.0192
Epoch 9/20
219/219 [==============================] - 37s 168ms/step - loss: 0.0314 - last_time_step_mse: 0.0181 - val_loss: 0.0304 - val_last_time_step_mse: 0.0174
Epoch 10/20
219/219 [==============================] - 37s 167ms/step - loss: 0.0300 - last_time_step_mse: 0.0170 - val_loss: 0.0290 - val_last_time_step_mse: 0.0154
Epoch 11/20
219/219 [==============================] - 37s 169ms/step - loss: 0.0290 - last_time_step_mse: 0.0162 - val_loss: 0.0284 - val_last_time_step_mse: 0.0153
Epoch 12/20
219/219 [==============================] - 37s 169ms/step - loss: 0.0285 - last_time_step_mse: 0.0159 - val_loss: 0.0278 - val_last_time_step_mse: 0.0149
Epoch 13/20
219/219 [==============================] - 37s 167ms/step - loss: 0.0276 - last_time_step_mse: 0.0150 - val_loss: 0.0269 - val_last_time_step_mse: 0.0139
Epoch 14/20
219/219 [==============================] - 37s 169ms/step - loss: 0.0272 - last_time_step_mse: 0.0146 - val_loss: 0.0265 - val_last_time_step_mse: 0.0140
Epoch 15/20
219/219 [==============================] - 37s 168ms/step - loss: 0.0268 - last_time_step_mse: 0.0143 - val_loss: 0.0261 - val_last_time_step_mse: 0.0133
Epoch 16/20
219/219 [==============================] - 37s 168ms/step - loss: 0.0265 - last_time_step_mse: 0.0141 - val_loss: 0.0257 - val_last_time_step_mse: 0.0129
Epoch 17/20
219/219 [==============================] - 37s 169ms/step - loss: 0.0260 - last_time_step_mse: 0.0136 - val_loss: 0.0256 - val_last_time_step_mse: 0.0132
Epoch 18/20
219/219 [==============================] - 37s 170ms/step - loss: 0.0256 - last_time_step_mse: 0.0133 - val_loss: 0.0251 - val_last_time_step_mse: 0.0128
Epoch 19/20
219/219 [==============================] - 37s 169ms/step - loss: 0.0255 - last_time_step_mse: 0.0133 - val_loss: 0.0248 - val_last_time_step_mse: 0.0120
Epoch 20/20
219/219 [==============================] - 37s 169ms/step - loss: 0.0250 - last_time_step_mse: 0.0128 - val_loss: 0.0246 - val_last_time_step_mse: 0.0122

Creating a Custom RNN Class

In [146]:
class MyRNN(keras.layers.Layer):
    def __init__(self, cell, return_sequences=False, **kwargs):
        super().__init__(**kwargs)
        self.cell = cell
        self.return_sequences = return_sequences
        self.get_initial_state = getattr(
            self.cell, "get_initial_state", self.fallback_initial_state)
    def fallback_initial_state(self, inputs):
        return [tf.zeros([self.cell.state_size], dtype=inputs.dtype)]
    @tf.function
    def call(self, inputs):
        states = self.get_initial_state(inputs)
        n_steps = tf.shape(inputs)[1]
        if self.return_sequences:
            sequences = tf.TensorArray(inputs.dtype, size=n_steps)
        outputs = tf.zeros(shape=[n_steps, self.cell.output_size], dtype=inputs.dtype)
        for step in tf.range(n_steps):
            outputs, states = self.cell(inputs[:, step], states)
            if self.return_sequences:
                sequences = sequences.write(step, outputs)
        if self.return_sequences:
            return sequences.stack()
        else:
            return outputs
In [147]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    MyRNN(LNSimpleRNNCell(20), return_sequences=True,
          input_shape=[None, 1]),
    MyRNN(LNSimpleRNNCell(20), return_sequences=True),
    keras.layers.TimeDistributed(keras.layers.Dense(10))
])

model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 40s 182ms/step - loss: 0.2026 - last_time_step_mse: 0.2011 - val_loss: 0.0780 - val_last_time_step_mse: 0.0735
Epoch 2/20
219/219 [==============================] - 40s 181ms/step - loss: 0.0699 - last_time_step_mse: 0.0638 - val_loss: 0.0637 - val_last_time_step_mse: 0.0577
Epoch 3/20
219/219 [==============================] - 39s 180ms/step - loss: 0.0608 - last_time_step_mse: 0.0539 - val_loss: 0.0577 - val_last_time_step_mse: 0.0512
Epoch 4/20
219/219 [==============================] - 39s 180ms/step - loss: 0.0558 - last_time_step_mse: 0.0488 - val_loss: 0.0532 - val_last_time_step_mse: 0.0465
Epoch 5/20
219/219 [==============================] - 40s 181ms/step - loss: 0.0520 - last_time_step_mse: 0.0447 - val_loss: 0.0503 - val_last_time_step_mse: 0.0433
Epoch 6/20
219/219 [==============================] - 39s 180ms/step - loss: 0.0490 - last_time_step_mse: 0.0414 - val_loss: 0.0470 - val_last_time_step_mse: 0.0396
Epoch 7/20
219/219 [==============================] - 40s 180ms/step - loss: 0.0457 - last_time_step_mse: 0.0371 - val_loss: 0.0437 - val_last_time_step_mse: 0.0335
Epoch 8/20
219/219 [==============================] - 39s 179ms/step - loss: 0.0418 - last_time_step_mse: 0.0312 - val_loss: 0.0404 - val_last_time_step_mse: 0.0294
Epoch 9/20
219/219 [==============================] - 39s 179ms/step - loss: 0.0390 - last_time_step_mse: 0.0272 - val_loss: 0.0378 - val_last_time_step_mse: 0.0251
Epoch 10/20
219/219 [==============================] - 39s 180ms/step - loss: 0.0366 - last_time_step_mse: 0.0246 - val_loss: 0.0351 - val_last_time_step_mse: 0.0229
Epoch 11/20
219/219 [==============================] - 39s 179ms/step - loss: 0.0346 - last_time_step_mse: 0.0229 - val_loss: 0.0335 - val_last_time_step_mse: 0.0208
Epoch 12/20
219/219 [==============================] - 39s 179ms/step - loss: 0.0334 - last_time_step_mse: 0.0216 - val_loss: 0.0325 - val_last_time_step_mse: 0.0202
Epoch 13/20
219/219 [==============================] - 39s 178ms/step - loss: 0.0325 - last_time_step_mse: 0.0206 - val_loss: 0.0321 - val_last_time_step_mse: 0.0200
Epoch 14/20
219/219 [==============================] - 40s 180ms/step - loss: 0.0318 - last_time_step_mse: 0.0200 - val_loss: 0.0320 - val_last_time_step_mse: 0.0206
Epoch 15/20
219/219 [==============================] - 39s 178ms/step - loss: 0.0313 - last_time_step_mse: 0.0194 - val_loss: 0.0311 - val_last_time_step_mse: 0.0193
Epoch 16/20
219/219 [==============================] - 39s 178ms/step - loss: 0.0310 - last_time_step_mse: 0.0194 - val_loss: 0.0306 - val_last_time_step_mse: 0.0188
Epoch 17/20
219/219 [==============================] - 39s 180ms/step - loss: 0.0304 - last_time_step_mse: 0.0187 - val_loss: 0.0298 - val_last_time_step_mse: 0.0174
Epoch 18/20
219/219 [==============================] - 39s 180ms/step - loss: 0.0298 - last_time_step_mse: 0.0180 - val_loss: 0.0296 - val_last_time_step_mse: 0.0179
Epoch 19/20
219/219 [==============================] - 40s 180ms/step - loss: 0.0295 - last_time_step_mse: 0.0177 - val_loss: 0.0291 - val_last_time_step_mse: 0.0169
Epoch 20/20
219/219 [==============================] - 39s 180ms/step - loss: 0.0291 - last_time_step_mse: 0.0171 - val_loss: 0.0290 - val_last_time_step_mse: 0.0167

LSTMs

In [148]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.LSTM(20, return_sequences=True, input_shape=[None, 1]),
    keras.layers.LSTM(20, return_sequences=True),
    keras.layers.TimeDistributed(keras.layers.Dense(10))
])

model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 4s 20ms/step - loss: 0.0760 - last_time_step_mse: 0.0615 - val_loss: 0.0554 - val_last_time_step_mse: 0.0364
Epoch 2/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0480 - last_time_step_mse: 0.0283 - val_loss: 0.0427 - val_last_time_step_mse: 0.0222
Epoch 3/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0391 - last_time_step_mse: 0.0181 - val_loss: 0.0367 - val_last_time_step_mse: 0.0157
Epoch 4/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0350 - last_time_step_mse: 0.0151 - val_loss: 0.0334 - val_last_time_step_mse: 0.0132
Epoch 5/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0325 - last_time_step_mse: 0.0133 - val_loss: 0.0314 - val_last_time_step_mse: 0.0121
Epoch 6/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0308 - last_time_step_mse: 0.0122 - val_loss: 0.0298 - val_last_time_step_mse: 0.0112
Epoch 7/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0297 - last_time_step_mse: 0.0118 - val_loss: 0.0291 - val_last_time_step_mse: 0.0120
Epoch 8/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0286 - last_time_step_mse: 0.0109 - val_loss: 0.0278 - val_last_time_step_mse: 0.0099
Epoch 9/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0280 - last_time_step_mse: 0.0108 - val_loss: 0.0278 - val_last_time_step_mse: 0.0113
Epoch 10/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0273 - last_time_step_mse: 0.0105 - val_loss: 0.0268 - val_last_time_step_mse: 0.0101
Epoch 11/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0269 - last_time_step_mse: 0.0102 - val_loss: 0.0263 - val_last_time_step_mse: 0.0096
Epoch 12/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0264 - last_time_step_mse: 0.0101 - val_loss: 0.0263 - val_last_time_step_mse: 0.0105
Epoch 13/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0259 - last_time_step_mse: 0.0097 - val_loss: 0.0257 - val_last_time_step_mse: 0.0100
Epoch 14/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0257 - last_time_step_mse: 0.0096 - val_loss: 0.0252 - val_last_time_step_mse: 0.0091
Epoch 15/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0253 - last_time_step_mse: 0.0095 - val_loss: 0.0251 - val_last_time_step_mse: 0.0092
Epoch 16/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0251 - last_time_step_mse: 0.0095 - val_loss: 0.0248 - val_last_time_step_mse: 0.0089
Epoch 17/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0248 - last_time_step_mse: 0.0094 - val_loss: 0.0248 - val_last_time_step_mse: 0.0098
Epoch 18/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0245 - last_time_step_mse: 0.0093 - val_loss: 0.0246 - val_last_time_step_mse: 0.0091
Epoch 19/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0242 - last_time_step_mse: 0.0091 - val_loss: 0.0238 - val_last_time_step_mse: 0.0085
Epoch 20/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0239 - last_time_step_mse: 0.0089 - val_loss: 0.0238 - val_last_time_step_mse: 0.0086
In [149]:
model.evaluate(X_valid, Y_valid)
63/63 [==============================] - 0s 7ms/step - loss: 0.0238 - last_time_step_mse: 0.0086
Out[149]:
[0.023788688704371452, 0.008560807444155216]
In [150]:
plot_learning_curves(history.history["loss"], history.history["val_loss"])
plt.show()
In [151]:
np.random.seed(43)

series = generate_time_series(1, 50 + 10)
X_new, Y_new = series[:, :50, :], series[:, 50:, :]
Y_pred = model.predict(X_new)[:, -1][..., np.newaxis]
In [152]:
plot_multiple_forecasts(X_new, Y_new, Y_pred)
plt.show()

GRUs

In [153]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.GRU(20, return_sequences=True, input_shape=[None, 1]),
    keras.layers.GRU(20, return_sequences=True),
    keras.layers.TimeDistributed(keras.layers.Dense(10))
])

model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 4s 19ms/step - loss: 0.0738 - last_time_step_mse: 0.0655 - val_loss: 0.0538 - val_last_time_step_mse: 0.0450
Epoch 2/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0476 - last_time_step_mse: 0.0367 - val_loss: 0.0441 - val_last_time_step_mse: 0.0326
Epoch 3/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0417 - last_time_step_mse: 0.0301 - val_loss: 0.0390 - val_last_time_step_mse: 0.0275
Epoch 4/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0368 - last_time_step_mse: 0.0243 - val_loss: 0.0339 - val_last_time_step_mse: 0.0202
Epoch 5/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0326 - last_time_step_mse: 0.0180 - val_loss: 0.0312 - val_last_time_step_mse: 0.0164
Epoch 6/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0306 - last_time_step_mse: 0.0155 - val_loss: 0.0294 - val_last_time_step_mse: 0.0143
Epoch 7/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0295 - last_time_step_mse: 0.0145 - val_loss: 0.0300 - val_last_time_step_mse: 0.0162
Epoch 8/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0283 - last_time_step_mse: 0.0135 - val_loss: 0.0278 - val_last_time_step_mse: 0.0130
Epoch 9/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0276 - last_time_step_mse: 0.0130 - val_loss: 0.0273 - val_last_time_step_mse: 0.0127
Epoch 10/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0269 - last_time_step_mse: 0.0125 - val_loss: 0.0264 - val_last_time_step_mse: 0.0121
Epoch 11/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0265 - last_time_step_mse: 0.0121 - val_loss: 0.0268 - val_last_time_step_mse: 0.0135
Epoch 12/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0263 - last_time_step_mse: 0.0123 - val_loss: 0.0261 - val_last_time_step_mse: 0.0123
Epoch 13/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0258 - last_time_step_mse: 0.0116 - val_loss: 0.0254 - val_last_time_step_mse: 0.0116
Epoch 14/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0256 - last_time_step_mse: 0.0117 - val_loss: 0.0254 - val_last_time_step_mse: 0.0116
Epoch 15/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0253 - last_time_step_mse: 0.0114 - val_loss: 0.0250 - val_last_time_step_mse: 0.0112
Epoch 16/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0251 - last_time_step_mse: 0.0114 - val_loss: 0.0250 - val_last_time_step_mse: 0.0114
Epoch 17/20
219/219 [==============================] - 4s 16ms/step - loss: 0.0248 - last_time_step_mse: 0.0112 - val_loss: 0.0249 - val_last_time_step_mse: 0.0118
Epoch 18/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0245 - last_time_step_mse: 0.0110 - val_loss: 0.0244 - val_last_time_step_mse: 0.0108
Epoch 19/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0243 - last_time_step_mse: 0.0108 - val_loss: 0.0240 - val_last_time_step_mse: 0.0105
Epoch 20/20
219/219 [==============================] - 3s 16ms/step - loss: 0.0240 - last_time_step_mse: 0.0106 - val_loss: 0.0238 - val_last_time_step_mse: 0.0103
In [154]:
model.evaluate(X_valid, Y_valid)
63/63 [==============================] - 0s 7ms/step - loss: 0.0238 - last_time_step_mse: 0.0103
Out[154]:
[0.02378549985587597, 0.010262805968523026]
In [155]:
plot_learning_curves(history.history["loss"], history.history["val_loss"])
plt.show()
In [156]:
np.random.seed(43)

series = generate_time_series(1, 50 + 10)
X_new, Y_new = series[:, :50, :], series[:, 50:, :]
Y_pred = model.predict(X_new)[:, -1][..., np.newaxis]
WARNING:tensorflow:5 out of the last 508 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7ff642f3fe18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
In [157]:
plot_multiple_forecasts(X_new, Y_new, Y_pred)
plt.show()

Using One-Dimensional Convolutional Layers to Process Sequences

1D conv layer with kernel size 4, stride 2, VALID padding:

              |-----2-----|     |-----5---...------|     |-----23----|
        |-----1-----|     |-----4-----|   ...      |-----22----|
  |-----0----|      |-----3-----|     |---...|-----21----|
X: 0  1  2  3  4  5  6  7  8  9  10 11 12 ... 42 43 44 45 46 47 48 49
Y: 1  2  3  4  5  6  7  8  9  10 11 12 13 ... 43 44 45 46 47 48 49 50
  /10 11 12 13 14 15 16 17 18 19 20 21 22 ... 52 53 54 55 56 57 58 59

Output:

X:     0/3   2/5   4/7   6/9   8/11 10/13 .../43 42/45 44/47 46/49
Y:     4/13  6/15  8/17 10/19 12/21 14/23 .../53 46/55 48/57 50/59
In [158]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.Conv1D(filters=20, kernel_size=4, strides=2, padding="valid",
                        input_shape=[None, 1]),
    keras.layers.GRU(20, return_sequences=True),
    keras.layers.GRU(20, return_sequences=True),
    keras.layers.TimeDistributed(keras.layers.Dense(10))
])

model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train[:, 3::2], epochs=20,
                    validation_data=(X_valid, Y_valid[:, 3::2]))
Epoch 1/20
219/219 [==============================] - 3s 15ms/step - loss: 0.0681 - last_time_step_mse: 0.0601 - val_loss: 0.0477 - val_last_time_step_mse: 0.0396
Epoch 2/20
219/219 [==============================] - 3s 11ms/step - loss: 0.0414 - last_time_step_mse: 0.0340 - val_loss: 0.0367 - val_last_time_step_mse: 0.0285
Epoch 3/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0338 - last_time_step_mse: 0.0257 - val_loss: 0.0307 - val_last_time_step_mse: 0.0218
Epoch 4/20
219/219 [==============================] - 3s 12ms/step - loss: 0.0282 - last_time_step_mse: 0.0184 - val_loss: 0.0259 - val_last_time_step_mse: 0.0152
Epoch 5/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0249 - last_time_step_mse: 0.0143 - val_loss: 0.0246 - val_last_time_step_mse: 0.0141
Epoch 6/20
219/219 [==============================] - 3s 12ms/step - loss: 0.0234 - last_time_step_mse: 0.0125 - val_loss: 0.0227 - val_last_time_step_mse: 0.0115
Epoch 7/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0226 - last_time_step_mse: 0.0117 - val_loss: 0.0225 - val_last_time_step_mse: 0.0116
Epoch 8/20
219/219 [==============================] - 3s 11ms/step - loss: 0.0220 - last_time_step_mse: 0.0111 - val_loss: 0.0216 - val_last_time_step_mse: 0.0105
Epoch 9/20
219/219 [==============================] - 3s 11ms/step - loss: 0.0216 - last_time_step_mse: 0.0108 - val_loss: 0.0217 - val_last_time_step_mse: 0.0109
Epoch 10/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0213 - last_time_step_mse: 0.0106 - val_loss: 0.0210 - val_last_time_step_mse: 0.0102
Epoch 11/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0210 - last_time_step_mse: 0.0102 - val_loss: 0.0208 - val_last_time_step_mse: 0.0100
Epoch 12/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0208 - last_time_step_mse: 0.0102 - val_loss: 0.0208 - val_last_time_step_mse: 0.0102
Epoch 13/20
219/219 [==============================] - 3s 12ms/step - loss: 0.0205 - last_time_step_mse: 0.0098 - val_loss: 0.0206 - val_last_time_step_mse: 0.0101
Epoch 14/20
219/219 [==============================] - 3s 12ms/step - loss: 0.0204 - last_time_step_mse: 0.0099 - val_loss: 0.0204 - val_last_time_step_mse: 0.0099
Epoch 15/20
219/219 [==============================] - 3s 11ms/step - loss: 0.0202 - last_time_step_mse: 0.0097 - val_loss: 0.0199 - val_last_time_step_mse: 0.0093
Epoch 16/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0200 - last_time_step_mse: 0.0097 - val_loss: 0.0201 - val_last_time_step_mse: 0.0095
Epoch 17/20
219/219 [==============================] - 2s 11ms/step - loss: 0.0196 - last_time_step_mse: 0.0093 - val_loss: 0.0197 - val_last_time_step_mse: 0.0091
Epoch 18/20
219/219 [==============================] - 3s 12ms/step - loss: 0.0194 - last_time_step_mse: 0.0090 - val_loss: 0.0192 - val_last_time_step_mse: 0.0086
Epoch 19/20
219/219 [==============================] - 3s 12ms/step - loss: 0.0190 - last_time_step_mse: 0.0088 - val_loss: 0.0188 - val_last_time_step_mse: 0.0084
Epoch 20/20
219/219 [==============================] - 3s 11ms/step - loss: 0.0186 - last_time_step_mse: 0.0083 - val_loss: 0.0184 - val_last_time_step_mse: 0.0080

WaveNet

C2  /\ /\ /\ /\ /\ /\ /\ /\ /\ /\ /\ /\ /\.../\ /\ /\ /\ /\ /\
   \  /  \  /  \  /  \  /  \  /  \  /  \       /  \  /  \  /  \
     /    \      /    \      /    \                 /    \
C1  /\ /\ /\ /\ /\ /\ /\ /\ /\ /\ /\  /\ /.../\ /\ /\ /\ /\ /\ /\
X: 0  1  2  3  4  5  6  7  8  9  10 11 12 ... 43 44 45 46 47 48 49
Y: 1  2  3  4  5  6  7  8  9  10 11 12 13 ... 44 45 46 47 48 49 50
  /10 11 12 13 14 15 16 17 18 19 20 21 22 ... 53 54 55 56 57 58 59
In [159]:
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential()
model.add(keras.layers.InputLayer(input_shape=[None, 1]))
for rate in (1, 2, 4, 8) * 2:
    model.add(keras.layers.Conv1D(filters=20, kernel_size=2, padding="causal",
                                  activation="relu", dilation_rate=rate))
model.add(keras.layers.Conv1D(filters=10, kernel_size=1))
model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=20,
                    validation_data=(X_valid, Y_valid))
Epoch 1/20
219/219 [==============================] - 2s 9ms/step - loss: 0.0668 - last_time_step_mse: 0.0543 - val_loss: 0.0365 - val_last_time_step_mse: 0.0230
Epoch 2/20
219/219 [==============================] - 2s 9ms/step - loss: 0.0323 - last_time_step_mse: 0.0193 - val_loss: 0.0294 - val_last_time_step_mse: 0.0167
Epoch 3/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0283 - last_time_step_mse: 0.0156 - val_loss: 0.0269 - val_last_time_step_mse: 0.0144
Epoch 4/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0261 - last_time_step_mse: 0.0136 - val_loss: 0.0254 - val_last_time_step_mse: 0.0130
Epoch 5/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0248 - last_time_step_mse: 0.0124 - val_loss: 0.0245 - val_last_time_step_mse: 0.0122
Epoch 6/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0240 - last_time_step_mse: 0.0117 - val_loss: 0.0234 - val_last_time_step_mse: 0.0108
Epoch 7/20
219/219 [==============================] - 2s 7ms/step - loss: 0.0234 - last_time_step_mse: 0.0112 - val_loss: 0.0230 - val_last_time_step_mse: 0.0108
Epoch 8/20
219/219 [==============================] - 2s 7ms/step - loss: 0.0227 - last_time_step_mse: 0.0105 - val_loss: 0.0232 - val_last_time_step_mse: 0.0110
Epoch 9/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0222 - last_time_step_mse: 0.0101 - val_loss: 0.0221 - val_last_time_step_mse: 0.0101
Epoch 10/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0219 - last_time_step_mse: 0.0099 - val_loss: 0.0213 - val_last_time_step_mse: 0.0090
Epoch 11/20
219/219 [==============================] - 2s 9ms/step - loss: 0.0213 - last_time_step_mse: 0.0092 - val_loss: 0.0210 - val_last_time_step_mse: 0.0089
Epoch 12/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0210 - last_time_step_mse: 0.0089 - val_loss: 0.0213 - val_last_time_step_mse: 0.0097
Epoch 13/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0206 - last_time_step_mse: 0.0086 - val_loss: 0.0203 - val_last_time_step_mse: 0.0082
Epoch 14/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0203 - last_time_step_mse: 0.0082 - val_loss: 0.0202 - val_last_time_step_mse: 0.0082
Epoch 15/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0201 - last_time_step_mse: 0.0081 - val_loss: 0.0198 - val_last_time_step_mse: 0.0078
Epoch 16/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0200 - last_time_step_mse: 0.0080 - val_loss: 0.0197 - val_last_time_step_mse: 0.0079
Epoch 17/20
219/219 [==============================] - 2s 9ms/step - loss: 0.0196 - last_time_step_mse: 0.0077 - val_loss: 0.0193 - val_last_time_step_mse: 0.0075
Epoch 18/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0193 - last_time_step_mse: 0.0073 - val_loss: 0.0191 - val_last_time_step_mse: 0.0073
Epoch 19/20
219/219 [==============================] - 2s 9ms/step - loss: 0.0192 - last_time_step_mse: 0.0072 - val_loss: 0.0187 - val_last_time_step_mse: 0.0069
Epoch 20/20
219/219 [==============================] - 2s 8ms/step - loss: 0.0190 - last_time_step_mse: 0.0070 - val_loss: 0.0186 - val_last_time_step_mse: 0.0069

Here is the original WaveNet defined in the paper: it uses Gated Activation Units instead of ReLU and parametrized skip connections, plus it pads with zeros on the left to avoid getting shorter and shorter sequences:

In [160]:
class GatedActivationUnit(keras.layers.Layer):
    def __init__(self, activation="tanh", **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
    def call(self, inputs):
        n_filters = inputs.shape[-1] // 2
        linear_output = self.activation(inputs[..., :n_filters])
        gate = keras.activations.sigmoid(inputs[..., n_filters:])
        return self.activation(linear_output) * gate
In [161]:
def wavenet_residual_block(inputs, n_filters, dilation_rate):
    z = keras.layers.Conv1D(2 * n_filters, kernel_size=2, padding="causal",
                            dilation_rate=dilation_rate)(inputs)
    z = GatedActivationUnit()(z)
    z = keras.layers.Conv1D(n_filters, kernel_size=1)(z)
    return keras.layers.Add()([z, inputs]), z
In [162]:
keras.backend.clear_session()
np.random.seed(42)
tf.random.set_seed(42)

n_layers_per_block = 3 # 10 in the paper
n_blocks = 1 # 3 in the paper
n_filters = 32 # 128 in the paper
n_outputs = 10 # 256 in the paper

inputs = keras.layers.Input(shape=[None, 1])
z = keras.layers.Conv1D(n_filters, kernel_size=2, padding="causal")(inputs)
skip_to_last = []
for dilation_rate in [2**i for i in range(n_layers_per_block)] * n_blocks:
    z, skip = wavenet_residual_block(z, n_filters, dilation_rate)
    skip_to_last.append(skip)
z = keras.activations.relu(keras.layers.Add()(skip_to_last))
z = keras.layers.Conv1D(n_filters, kernel_size=1, activation="relu")(z)
Y_proba = keras.layers.Conv1D(n_outputs, kernel_size=1, activation="softmax")(z)

model = keras.models.Model(inputs=[inputs], outputs=[Y_proba])
In [163]:
model.compile(loss="mse", optimizer="adam", metrics=[last_time_step_mse])
history = model.fit(X_train, Y_train, epochs=2,
                    validation_data=(X_valid, Y_valid))
Epoch 1/2
219/219 [==============================] - 2s 11ms/step - loss: 0.1300 - last_time_step_mse: 0.1260 - val_loss: 0.1229 - val_last_time_step_mse: 0.1199
Epoch 2/2
219/219 [==============================] - 2s 9ms/step - loss: 0.1222 - last_time_step_mse: 0.1178 - val_loss: 0.1217 - val_last_time_step_mse: 0.1189

In this chapter we explored the fundamentals of RNNs and used them to process sequences (namely, time series). In the process we also looked at other ways to process sequences, including CNNs. In the next chapter we will use RNNs for Natural Language Processing, and we will learn more about RNNs (bidirectional RNNs, stateful vs stateless RNNs, Encoder–Decoders, and Attention-augmented Encoder-Decoders). We will also look at the Transformer, an Attention-only architecture.

Exercise solutions

1. to 8.

See Appendix A.

9. Tackling the SketchRNN Dataset

Exercise: Train a classification model for the SketchRNN dataset, available in TensorFlow Datasets.

The dataset is not available in TFDS yet, the pull request is still work in progress. Luckily, the data is conveniently available as TFRecords, so let's download it (it might take a while, as it's about 1 GB large, with 3,450,000 training sketches and 345,000 test sketches):

In [164]:
DOWNLOAD_ROOT = "http://download.tensorflow.org/data/"
FILENAME = "quickdraw_tutorial_dataset_v1.tar.gz"
filepath = keras.utils.get_file(FILENAME,
                                DOWNLOAD_ROOT + FILENAME,
                                cache_subdir="datasets/quickdraw",
                                extract=True)
In [165]:
quickdraw_dir = Path(filepath).parent
train_files = sorted([str(path) for path in quickdraw_dir.glob("training.tfrecord-*")])
eval_files = sorted([str(path) for path in quickdraw_dir.glob("eval.tfrecord-*")])
In [166]:
train_files
Out[166]:
['/home/work/.keras/datasets/quickdraw/training.tfrecord-00000-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00001-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00002-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00003-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00004-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00005-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00006-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00007-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00008-of-00010',
 '/home/work/.keras/datasets/quickdraw/training.tfrecord-00009-of-00010']
In [167]:
eval_files
Out[167]:
['/home/work/.keras/datasets/quickdraw/eval.tfrecord-00000-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00001-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00002-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00003-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00004-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00005-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00006-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00007-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00008-of-00010',
 '/home/work/.keras/datasets/quickdraw/eval.tfrecord-00009-of-00010']
In [168]:
with open(quickdraw_dir / "eval.tfrecord.classes") as test_classes_file:
    test_classes = test_classes_file.readlines()
    
with open(quickdraw_dir / "training.tfrecord.classes") as train_classes_file:
    train_classes = train_classes_file.readlines()
In [169]:
assert train_classes == test_classes
class_names = [name.strip().lower() for name in train_classes]
In [170]:
sorted(class_names)
Out[170]:
['aircraft carrier',
 'airplane',
 'alarm clock',
 'ambulance',
 'angel',
 'animal migration',
 'ant',
 'anvil',
 'apple',
 'arm',
 'asparagus',
 'axe',
 'backpack',
 'banana',
 'bandage',
 'barn',
 'baseball',
 'baseball bat',
 'basket',
 'basketball',
 'bat',
 'bathtub',
 'beach',
 'bear',
 'beard',
 'bed',
 'bee',
 'belt',
 'bench',
 'bicycle',
 'binoculars',
 'bird',
 'birthday cake',
 'blackberry',
 'blueberry',
 'book',
 'boomerang',
 'bottlecap',
 'bowtie',
 'bracelet',
 'brain',
 'bread',
 'bridge',
 'broccoli',
 'broom',
 'bucket',
 'bulldozer',
 'bus',
 'bush',
 'butterfly',
 'cactus',
 'cake',
 'calculator',
 'calendar',
 'camel',
 'camera',
 'camouflage',
 'campfire',
 'candle',
 'cannon',
 'canoe',
 'car',
 'carrot',
 'castle',
 'cat',
 'ceiling fan',
 'cell phone',
 'cello',
 'chair',
 'chandelier',
 'church',
 'circle',
 'clarinet',
 'clock',
 'cloud',
 'coffee cup',
 'compass',
 'computer',
 'cookie',
 'cooler',
 'couch',
 'cow',
 'crab',
 'crayon',
 'crocodile',
 'crown',
 'cruise ship',
 'cup',
 'diamond',
 'dishwasher',
 'diving board',
 'dog',
 'dolphin',
 'donut',
 'door',
 'dragon',
 'dresser',
 'drill',
 'drums',
 'duck',
 'dumbbell',
 'ear',
 'elbow',
 'elephant',
 'envelope',
 'eraser',
 'eye',
 'eyeglasses',
 'face',
 'fan',
 'feather',
 'fence',
 'finger',
 'fire hydrant',
 'fireplace',
 'firetruck',
 'fish',
 'flamingo',
 'flashlight',
 'flip flops',
 'floor lamp',
 'flower',
 'flying saucer',
 'foot',
 'fork',
 'frog',
 'frying pan',
 'garden',
 'garden hose',
 'giraffe',
 'goatee',
 'golf club',
 'grapes',
 'grass',
 'guitar',
 'hamburger',
 'hammer',
 'hand',
 'harp',
 'hat',
 'headphones',
 'hedgehog',
 'helicopter',
 'helmet',
 'hexagon',
 'hockey puck',
 'hockey stick',
 'horse',
 'hospital',
 'hot air balloon',
 'hot dog',
 'hot tub',
 'hourglass',
 'house',
 'house plant',
 'hurricane',
 'ice cream',
 'jacket',
 'jail',
 'kangaroo',
 'key',
 'keyboard',
 'knee',
 'knife',
 'ladder',
 'lantern',
 'laptop',
 'leaf',
 'leg',
 'light bulb',
 'lighter',
 'lighthouse',
 'lightning',
 'line',
 'lion',
 'lipstick',
 'lobster',
 'lollipop',
 'mailbox',
 'map',
 'marker',
 'matches',
 'megaphone',
 'mermaid',
 'microphone',
 'microwave',
 'monkey',
 'moon',
 'mosquito',
 'motorbike',
 'mountain',
 'mouse',
 'moustache',
 'mouth',
 'mug',
 'mushroom',
 'nail',
 'necklace',
 'nose',
 'ocean',
 'octagon',
 'octopus',
 'onion',
 'oven',
 'owl',
 'paint can',
 'paintbrush',
 'palm tree',
 'panda',
 'pants',
 'paper clip',
 'parachute',
 'parrot',
 'passport',
 'peanut',
 'pear',
 'peas',
 'pencil',
 'penguin',
 'piano',
 'pickup truck',
 'picture frame',
 'pig',
 'pillow',
 'pineapple',
 'pizza',
 'pliers',
 'police car',
 'pond',
 'pool',
 'popsicle',
 'postcard',
 'potato',
 'power outlet',
 'purse',
 'rabbit',
 'raccoon',
 'radio',
 'rain',
 'rainbow',
 'rake',
 'remote control',
 'rhinoceros',
 'rifle',
 'river',
 'roller coaster',
 'rollerskates',
 'sailboat',
 'sandwich',
 'saw',
 'saxophone',
 'school bus',
 'scissors',
 'scorpion',
 'screwdriver',
 'sea turtle',
 'see saw',
 'shark',
 'sheep',
 'shoe',
 'shorts',
 'shovel',
 'sink',
 'skateboard',
 'skull',
 'skyscraper',
 'sleeping bag',
 'smiley face',
 'snail',
 'snake',
 'snorkel',
 'snowflake',
 'snowman',
 'soccer ball',
 'sock',
 'speedboat',
 'spider',
 'spoon',
 'spreadsheet',
 'square',
 'squiggle',
 'squirrel',
 'stairs',
 'star',
 'steak',
 'stereo',
 'stethoscope',
 'stitches',
 'stop sign',
 'stove',
 'strawberry',
 'streetlight',
 'string bean',
 'submarine',
 'suitcase',
 'sun',
 'swan',
 'sweater',
 'swing set',
 'sword',
 'syringe',
 't-shirt',
 'table',
 'teapot',
 'teddy-bear',
 'telephone',
 'television',
 'tennis racquet',
 'tent',
 'the eiffel tower',
 'the great wall of china',
 'the mona lisa',
 'tiger',
 'toaster',
 'toe',
 'toilet',
 'tooth',
 'toothbrush',
 'toothpaste',
 'tornado',
 'tractor',
 'traffic light',
 'train',
 'tree',
 'triangle',
 'trombone',
 'truck',
 'trumpet',
 'umbrella',
 'underwear',
 'van',
 'vase',
 'violin',
 'washing machine',
 'watermelon',
 'waterslide',
 'whale',
 'wheel',
 'windmill',
 'wine bottle',
 'wine glass',
 'wristwatch',
 'yoga',
 'zebra',
 'zigzag']
In [171]:
def parse(data_batch):
    feature_descriptions = {
        "ink": tf.io.VarLenFeature(dtype=tf.float32),
        "shape": tf.io.FixedLenFeature([2], dtype=tf.int64),
        "class_index": tf.io.FixedLenFeature([1], dtype=tf.int64)
    }
    examples = tf.io.parse_example(data_batch, feature_descriptions)
    flat_sketches = tf.sparse.to_dense(examples["ink"])
    sketches = tf.reshape(flat_sketches, shape=[tf.size(data_batch), -1, 3])
    lengths = examples["shape"][:, 0]
    labels = examples["class_index"][:, 0]
    return sketches, lengths, labels
In [172]:
def quickdraw_dataset(filepaths, batch_size=32, shuffle_buffer_size=None,
                      n_parse_threads=5, n_read_threads=5, cache=False):
    dataset = tf.data.TFRecordDataset(filepaths,
                                      num_parallel_reads=n_read_threads)
    if cache:
        dataset = dataset.cache()
    if shuffle_buffer_size:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(parse, num_parallel_calls=n_parse_threads)
    return dataset.prefetch(1)
In [173]:
train_set = quickdraw_dataset(train_files, shuffle_buffer_size=10000)
valid_set = quickdraw_dataset(eval_files[:5])
test_set = quickdraw_dataset(eval_files[5:])
In [174]:
for sketches, lengths, labels in train_set.take(1):
    print("sketches =", sketches)
    print("lengths =", lengths)
    print("labels =", labels)
sketches = tf.Tensor(
[[[-0.08627451  0.11764706  0.        ]
  [-0.01176471  0.16806725  0.        ]
  [ 0.02352941  0.07563025  0.        ]
  ...
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]]

 [[-0.04705882 -0.06696428  0.        ]
  [-0.09019607 -0.07142857  0.        ]
  [-0.0862745  -0.04464286  0.        ]
  ...
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]]

 [[ 0.          0.          1.        ]
  [ 0.          0.          0.        ]
  [ 0.00784314  0.11320752  0.        ]
  ...
  [ 0.11764708  0.01886791  0.        ]
  [-0.03529412  0.12264156  0.        ]
  [-0.19215688  0.33962262  1.        ]]

 ...

 [[-0.21276593 -0.01960784  0.        ]
  [-0.31382978  0.00784314  0.        ]
  [-0.37234044  0.13725491  0.        ]
  ...
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]]

 [[ 0.          0.4677419   0.        ]
  [-0.01176471  0.15053767  0.        ]
  [ 0.16470589  0.05376345  0.        ]
  ...
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]]

 [[-0.04819274  0.01568627  0.        ]
  [-0.07228917 -0.01176471  0.        ]
  [-0.05622491 -0.03921568  0.        ]
  ...
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]]], shape=(32, 104, 3), dtype=float32)
lengths = tf.Tensor(
[ 29  48 104  34  29  35  28  40  95  26  23  41  47  17  37  47  12  13
  17  41  36  23   8  15  60  32  54  38  68  30  89  36], shape=(32,), dtype=int64)
labels = tf.Tensor(
[ 95 190 163  12  77 213 216 278  25 202 310  33 327 204 260 181 337 233
 299 186  61 157 274 150   7  34  47 319 213 292 312 282], shape=(32,), dtype=int64)
In [175]:
def draw_sketch(sketch, label=None):
    origin = np.array([[0., 0., 0.]])
    sketch = np.r_[origin, sketch]
    stroke_end_indices = np.argwhere(sketch[:, -1]==1.)[:, 0]
    coordinates = np.cumsum(sketch[:, :2], axis=0)
    strokes = np.split(coordinates, stroke_end_indices + 1)
    title = class_names[label.numpy()] if label is not None else "Try to guess"
    plt.title(title)
    plt.plot(coordinates[:, 0], -coordinates[:, 1], "y:")
    for stroke in strokes:
        plt.plot(stroke[:, 0], -stroke[:, 1], ".-")
    plt.axis("off")

def draw_sketches(sketches, lengths, labels):
    n_sketches = len(sketches)
    n_cols = 4
    n_rows = (n_sketches - 1) // n_cols + 1
    plt.figure(figsize=(n_cols * 3, n_rows * 3.5))
    for index, sketch, length, label in zip(range(n_sketches), sketches, lengths, labels):
        plt.subplot(n_rows, n_cols, index + 1)
        draw_sketch(sketch[:length], label)
    plt.show()

for sketches, lengths, labels in train_set.take(1):
    draw_sketches(sketches, lengths, labels)