Kuzushiji-MNIST character recognition (neural network)

See the README.md in the GitHub repository for more information.

In [32]:
# Imports
import requests
import os
import copy
import logging
import warnings
import umap
import numpy as np
import numpy.matlib
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.model_selection import train_test_split

# Disable TensorFlow log messages
logging.getLogger('tensorflow').disabled = True
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Suppress warnings (e.g. the Numba warnings raised by UMAP)
warnings.filterwarnings('ignore')

# Set seeds for reproducible randomness
tf.random.set_random_seed(71)
np.random.seed(71)

Downloading the training and test data

The Kuzushiji-MNIST dataset is provided by the Center for Open Data in the Humanities (CODH) at the Research Organization of Information and Systems (ROIS).

The data (in .npz format) can be downloaded with simple HTTP requests and deserialized into NumPy arrays:

In [33]:
BASE_URL = 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/'

urls = {
    'train': {'images': BASE_URL + 'kmnist-train-imgs.npz', 'labels': BASE_URL + 'kmnist-train-labels.npz'},
    'test': {'images': BASE_URL + 'kmnist-test-imgs.npz', 'labels': BASE_URL + 'kmnist-test-labels.npz'}
}

data = {'train': {}, 'validation': {}, 'test': {}}

for split, d in urls.items():
    for _type, url in d.items():
        path = os.path.join(os.getcwd(), 'temp.npz')
        print(f"Downloading {split} {_type} from {url}")
        response = requests.get(url)
        with open(path, 'wb') as file:
            file.write(response.content)
        # Load the archive only once the file handle has been closed (and therefore flushed)
        data[split][_type] = np.load(path)['arr_0']

os.remove(path)
        
print("\nFinished downloading training and test data.")
Downloading train images from http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz
Downloading train labels from http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-labels.npz
Downloading test images from http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-imgs.npz
Downloading test labels from http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-labels.npz

Finished downloading training and test data.
In [34]:
# A mapping from character indices to their respective unicode codepoints
char_map = ['U+304A', 'U+304D', 'U+3059', 'U+3064', 'U+306A', 'U+306F', 'U+307E', 'U+3084', 'U+308C', 'U+3092']
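
For reference, these 'U+XXXX' codepoint strings can be converted to the characters they denote with Python's built-in chr; a quick check (not used elsewhere in the notebook):

# Convert each 'U+XXXX' codepoint string to its character
print([chr(int(cp[2:], 16)) for cp in char_map])
# ['お', 'き', 'す', 'つ', 'な', 'は', 'ま', 'や', 'れ', 'を']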
In [35]:
# Divide the training data into training and validation data
data['train']['images'], data['validation']['images'], data['train']['labels'], data['validation']['labels'] = train_test_split(
    data['train']['images'], data['train']['labels'], test_size=10000
)
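
As a quick sanity check, we can print the resulting split sizes; with test_size=10000 this should leave 50,000 training examples, 10,000 validation examples, and the untouched 10,000 test examples:

# Print the image and label array shapes for each split
for split in data:
    print(split, data[split]['images'].shape, data[split]['labels'].shape)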

Examining the data

The Kuzushiji-MNIST dataset consists of handwritten characters, each represented as a $28 \times 28$ pixel grayscale image.

Below are ten samples of these images from the training set:

In [36]:
subset_indices = np.random.randint(50000, size=10)

fig, axs = plt.subplots(2, 5, sharex=True, sharey=True, figsize=(14, 6))
fig.subplots_adjust(hspace=.35)
cbar_ax = fig.add_axes([.91, .3, .03, .4])

for i, ax in enumerate(axs.flat):
    index = subset_indices[i]
    ax.set_title(f"Training example {index}\nCodepoint: {char_map[data['train']['labels'][index]]}")
    ax.tick_params(left=False, bottom=False)
    sns.heatmap(data['train']['images'][index], ax=ax, cbar=(i==0), cbar_ax=None if i else cbar_ax, cmap="binary")
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

From these images alone, it is clear that the classification task will be more challenging than MNIST, since the individual characters are considerably more intricate.

We can create UMAP embeddings (a reduced two-dimensional representation) of the examples in each dataset. Compared with MNIST, these embeddings show how much harder the Kuzushiji-MNIST classes are to separate.

In [37]:
# Load the MNIST dataset
tf_mnist = tfds.load(name="mnist", split=tfds.Split.TRAIN, batch_size=-1, as_dataset_kwargs={'shuffle_files': True})
np_mnist = tfds.as_numpy(tf_mnist)

# Create a UMAP dimensionality reduction object
reducer = umap.UMAP(random_state=74)

# Generate embeddings for examples of the MNIST and Kuzushiji-MNIST datasets
kmnist_emb = reducer.fit_transform(data['train']['images'].reshape(-1, 784))
mnist_emb = reducer.fit_transform(np_mnist['image'].reshape(-1, 784))
In [38]:
# Create a figure with two subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

# Create the scatter subplot for the Kuzushiji-MNIST embeddings
im1 = ax1.scatter(kmnist_emb[:, 0], kmnist_emb[:, 1], c=data['train']['labels'], cmap='Spectral', s=5)
ax1.set_title("UMAP projection of the Kuzushiji-MNIST dataset")
fig.colorbar(im1, ax=ax1, boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))

# Create the scatter subplot for the MNIST embeddings
im2 = ax2.scatter(mnist_emb[:, 0], mnist_emb[:, 1], c=np_mnist['label'], cmap='Spectral', s=5)
ax2.set_title("UMAP projection of the MNIST dataset")
fig.colorbar(im2, ax=ax2, boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))

# Fix layout padding
plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)

Preprocessing the data

When it comes to preprocessing these images, there isn't much that needs to be done, since:

  • the dataset contains no outliers, and
  • pixel values are already bounded to the known range 0–255, so standardization is not strictly necessary.

However, potential preprocessing methods could include:

  • Thresholding
    Setting each pixel to 0 or 255 depending on whether its value is below or above the mean pixel value of its image.
  • PCA
    Working with a selected number of principal components (the leading components of digit images are sometimes called eigendigits); a minimal sketch follows this list.
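
For illustration only, a minimal PCA sketch using scikit-learn (already a dependency via train_test_split) might look like the following; the choice of 50 components is arbitrary here, and none of this is used later in the notebook:

from sklearn.decomposition import PCA

# Illustrative only: project flattened training images onto their leading principal components
pca = PCA(n_components=50)
flattened = data['train']['images'].reshape(-1, 784)
reduced = pca.fit_transform(flattened)       # shape: (n_examples, 50)
restored = pca.inverse_transform(reduced)    # approximate reconstructions, shape: (n_examples, 784)
print(pca.explained_variance_ratio_.sum())   # fraction of variance captured by the 50 components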

We won't investigate either of these options further. However, one form of preprocessing we must do is to flatten each image from a $28 \times 28$ matrix into a $784$-dimensional vector. We will define a function to do this for us:

In [39]:
def preprocess(images):
    # Flatten each 28x28 image into a 784-dimensional vector of grayscale pixel values
    return images.reshape(-1, 784)

Thresholding

Thresholding discards information, and as a result is rarely used in handwritten character/digit recognition.

However, it is possible to visualize this loss of information if we display the thresholded characters. The function defined below will threshold the pixel values after flattening the image:

In [40]:
def preprocess_with_threshold(images):
    # Flatten each 28x28 image into a 784-dimensional vector of grayscale pixel values
    images = images.reshape(-1, 784)

    # Create a mean pixel threshold matrix for each image
    threshold = np.matlib.repmat(np.mean(images, axis=1), 784, 1).T

    # Apply the threshold
    images[images >= threshold] = 255
    images[images < threshold] = 0
    
    return images
In [41]:
# Apply thresholding to flattened image vector
thresholded = preprocess_with_threshold(copy.copy(data['train']['images']))

# Reshape vector to 28x28 image for visualization purposes
thresholded = thresholded.reshape(-1, 28, 28)

fig, axs = plt.subplots(2, 5, sharex=True, sharey=True, figsize=(14, 6))
fig.subplots_adjust(hspace=.35)
cbar_ax = fig.add_axes([.91, .3, .03, .4])

for i, ax in enumerate(axs.flat):
    index = subset_indices[i] if i < 5 else subset_indices[i-5]
    
    image = data['train']['images'][index] if i < 5 else thresholded[index]
    label = data['train']['labels'][index]
    
    if i == 9:
        # The last heatmap renders flipped vertically for reasons unknown, so flip it back
        image = np.flip(image, axis=0) 
    
    ax.set_title(f"{'Raw' if i < 5 else 'Threshold'} ({index})\nCodepoint: {char_map[label]}")
    ax.tick_params(left=False, bottom=False)
    sns.heatmap(image, ax=ax, cbar=(i==0), cbar_ax=None if i else cbar_ax, cmap="binary")
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

While thresholding may seem promising, in practice it led to lower validation and test accuracies. The only preprocessing we will apply is therefore flattening the images.

In [42]:
# Flatten images
data['train']['images'] = preprocess(data['train']['images'])
data['validation']['images'] = preprocess(data['validation']['images'])
data['test']['images'] = preprocess(data['test']['images'])

Creating the neural network

As described in the README.md, the network consists of three hidden ReLU layers and one softmax output layer for classification.

The number of neurons in each hidden layer was chosen so that the widths taper conveniently from the $784$ input neurons down to the $10$ output neurons. This is not necessarily the best network topology; cross-validation or neuroevolution (evolutionary algorithms) would likely yield better-performing networks, as sketched below.
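
As a rough sketch of what such a search could look like against the validation split (the helper name build_model, the candidate widths, and the short epoch count are all illustrative choices, not part of the original experiment):

def build_model(widths):
    # Illustrative helper: a ReLU network with the given hidden-layer widths
    layers = [tf.keras.layers.Flatten(input_shape=(784,))]
    layers += [tf.keras.layers.Dense(w, activation='relu') for w in widths]
    layers += [tf.keras.layers.Dense(10, activation='softmax')]
    candidate = tf.keras.Sequential(layers)
    candidate.compile(optimizer=tf.keras.optimizers.SGD(lr=0.003),
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])
    return candidate

# Compare a few candidate topologies by their validation accuracy
for widths in [(512, 256, 128), (256, 128, 64), (1024, 512, 256)]:
    candidate = build_model(widths)
    candidate.fit(data['train']['images'], data['train']['labels'], epochs=5, verbose=0)
    _, accuracy = candidate.evaluate(data['validation']['images'], data['validation']['labels'], verbose=0)
    print(widths, accuracy)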

But first, we must convert the NumPy arrays of the training set into a batched TensorFlow Dataset object:

In [43]:
SHUFFLE_BUFFER_SIZE = 100
BATCH_SIZE = 20

# Transform the examples and labels into a batched TensorFlow Dataset object
train_dataset = tf.data.Dataset.from_tensor_slices((data['train']['images'], data['train']['labels']))
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE).repeat()
In [44]:
LEARNING_RATE = 0.003

# Create model structure
model = tf.keras.Sequential([
    # Input layer: 784 units (one for each pixel)
    tf.keras.layers.Flatten(input_shape=(784,), name='pixels'),
    # Hidden layers
    tf.keras.layers.Dense(512, activation='relu', name='hidden1'),
    tf.keras.layers.Dense(256, activation='relu', name='hidden2'),
    tf.keras.layers.Dense(128, activation='relu', name='hidden3'),
    # Output layer: 10 units (one for each character)
    tf.keras.layers.Dense(10, activation='softmax', name='output')
], name='Neural network')

model.compile(
    optimizer=tf.keras.optimizers.SGD(lr=LEARNING_RATE),
    # Multinomial cross-entropy between labels and softmax output layer activations
    loss='sparse_categorical_crossentropy',
    # Categorical accuracy with integer (not one-hot encoded) output classes
    metrics=['sparse_categorical_accuracy']
)

model.summary()
Model: "Neural network"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
pixels (Flatten)             (None, 784)               0         
_________________________________________________________________
hidden1 (Dense)              (None, 512)               401920    
_________________________________________________________________
hidden2 (Dense)              (None, 256)               131328    
_________________________________________________________________
hidden3 (Dense)              (None, 128)               32896     
_________________________________________________________________
output (Dense)               (None, 10)                1290      
=================================================================
Total params: 567,434
Trainable params: 567,434
Non-trainable params: 0
_________________________________________________________________
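
For reference, the sparse categorical cross-entropy used above is the mean negative log-probability that the softmax output assigns to each true class: for a batch of $N$ examples with integer labels $y_i$ and predicted class probabilities $p_{i,c}$, the loss is $-\frac{1}{N}\sum_{i=1}^{N} \log p_{i,y_i}$.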

Fitting the model

In [45]:
EPOCHS = 75
TRAIN_SIZE = 50000

model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=TRAIN_SIZE//BATCH_SIZE);
Epoch 1/75
2500/2500 [==============================] - 11s 4ms/step - loss: 3.9881 - sparse_categorical_accuracy: 0.6871
Epoch 2/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.4558 - sparse_categorical_accuracy: 0.8609
Epoch 3/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.3316 - sparse_categorical_accuracy: 0.9005
Epoch 4/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.2624 - sparse_categorical_accuracy: 0.9210
Epoch 5/75
2500/2500 [==============================] - 12s 5ms/step - loss: 0.2114 - sparse_categorical_accuracy: 0.9367
Epoch 6/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.1757 - sparse_categorical_accuracy: 0.9471
Epoch 7/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.1458 - sparse_categorical_accuracy: 0.9557
Epoch 8/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.1223 - sparse_categorical_accuracy: 0.9620
Epoch 9/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.1054 - sparse_categorical_accuracy: 0.9675
Epoch 10/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0886 - sparse_categorical_accuracy: 0.9729
Epoch 11/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0765 - sparse_categorical_accuracy: 0.9771
Epoch 12/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0715 - sparse_categorical_accuracy: 0.9778
Epoch 13/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0600 - sparse_categorical_accuracy: 0.9814
Epoch 14/75
2500/2500 [==============================] - 7s 3ms/step - loss: 0.0499 - sparse_categorical_accuracy: 0.9850
Epoch 15/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0459 - sparse_categorical_accuracy: 0.9856
Epoch 16/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9851
Epoch 17/75
2500/2500 [==============================] - 7s 3ms/step - loss: 0.0449 - sparse_categorical_accuracy: 0.9861
Epoch 18/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0379 - sparse_categorical_accuracy: 0.9879
Epoch 19/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0336 - sparse_categorical_accuracy: 0.9894
Epoch 20/75
2500/2500 [==============================] - 12s 5ms/step - loss: 0.0308 - sparse_categorical_accuracy: 0.9904
Epoch 21/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.0309 - sparse_categorical_accuracy: 0.9902
Epoch 22/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0270 - sparse_categorical_accuracy: 0.9914
Epoch 23/75
2500/2500 [==============================] - 19s 8ms/step - loss: 0.0269 - sparse_categorical_accuracy: 0.9915
Epoch 24/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.0249 - sparse_categorical_accuracy: 0.9921
Epoch 25/75
2500/2500 [==============================] - 14s 6ms/step - loss: 0.0271 - sparse_categorical_accuracy: 0.9915
Epoch 26/75
2500/2500 [==============================] - 14s 6ms/step - loss: 0.0323 - sparse_categorical_accuracy: 0.9900
Epoch 27/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0266 - sparse_categorical_accuracy: 0.9914
Epoch 28/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0214 - sparse_categorical_accuracy: 0.9932
Epoch 29/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0240 - sparse_categorical_accuracy: 0.9926
Epoch 30/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9897
Epoch 31/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0234 - sparse_categorical_accuracy: 0.9926
Epoch 32/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0176 - sparse_categorical_accuracy: 0.9942
Epoch 33/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0235 - sparse_categorical_accuracy: 0.9933
Epoch 34/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0151 - sparse_categorical_accuracy: 0.9952
Epoch 35/75
2500/2500 [==============================] - 7s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9959
Epoch 36/75
2500/2500 [==============================] - 13s 5ms/step - loss: 0.0111 - sparse_categorical_accuracy: 0.9964
Epoch 37/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0106 - sparse_categorical_accuracy: 0.9967
Epoch 38/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0089 - sparse_categorical_accuracy: 0.9974
Epoch 39/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0063 - sparse_categorical_accuracy: 0.9978
Epoch 40/75
2500/2500 [==============================] - 7s 3ms/step - loss: 0.0052 - sparse_categorical_accuracy: 0.9982
Epoch 41/75
2500/2500 [==============================] - 8s 3ms/step - loss: 0.0048 - sparse_categorical_accuracy: 0.9983
Epoch 42/75
2500/2500 [==============================] - 12s 5ms/step - loss: 0.0048 - sparse_categorical_accuracy: 0.9983
Epoch 43/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0057 - sparse_categorical_accuracy: 0.9981
Epoch 44/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0049 - sparse_categorical_accuracy: 0.9983
Epoch 45/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0046 - sparse_categorical_accuracy: 0.9983
Epoch 46/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0051 - sparse_categorical_accuracy: 0.9981
Epoch 47/75
2500/2500 [==============================] - 9s 3ms/step - loss: 0.0045 - sparse_categorical_accuracy: 0.9983
Epoch 48/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0043 - sparse_categorical_accuracy: 0.9983
Epoch 49/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0044 - sparse_categorical_accuracy: 0.9984
Epoch 50/75
2500/2500 [==============================] - 9s 3ms/step - loss: 0.0041 - sparse_categorical_accuracy: 0.9984
Epoch 51/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0041 - sparse_categorical_accuracy: 0.9984
Epoch 52/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.0041 - sparse_categorical_accuracy: 0.9984
Epoch 53/75
2500/2500 [==============================] - 16s 6ms/step - loss: 0.0042 - sparse_categorical_accuracy: 0.9985
Epoch 54/75
2500/2500 [==============================] - 13s 5ms/step - loss: 0.0041 - sparse_categorical_accuracy: 0.9984
Epoch 55/75
2500/2500 [==============================] - 11s 5ms/step - loss: 0.0039 - sparse_categorical_accuracy: 0.9985
Epoch 56/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0038 - sparse_categorical_accuracy: 0.9985
Epoch 57/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0039 - sparse_categorical_accuracy: 0.9985
Epoch 58/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.0038 - sparse_categorical_accuracy: 0.9985
Epoch 59/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.0037 - sparse_categorical_accuracy: 0.9986
Epoch 60/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0037 - sparse_categorical_accuracy: 0.9986
Epoch 61/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0036 - sparse_categorical_accuracy: 0.9986
Epoch 62/75
2500/2500 [==============================] - 12s 5ms/step - loss: 0.0035 - sparse_categorical_accuracy: 0.9986
Epoch 63/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0037 - sparse_categorical_accuracy: 0.9986
Epoch 64/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0035 - sparse_categorical_accuracy: 0.9986
Epoch 65/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0036 - sparse_categorical_accuracy: 0.9987
Epoch 66/75
2500/2500 [==============================] - 11s 4ms/step - loss: 0.0034 - sparse_categorical_accuracy: 0.9987
Epoch 67/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0033 - sparse_categorical_accuracy: 0.9987
Epoch 68/75
2500/2500 [==============================] - 10s 4ms/step - loss: 0.0033 - sparse_categorical_accuracy: 0.9987
Epoch 69/75
2500/2500 [==============================] - 14s 6ms/step - loss: 0.0034 - sparse_categorical_accuracy: 0.9987
Epoch 70/75
2500/2500 [==============================] - 12s 5ms/step - loss: 0.0045 - sparse_categorical_accuracy: 0.9982
Epoch 71/75
2500/2500 [==============================] - 11s 5ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9924
Epoch 72/75
2500/2500 [==============================] - 9s 3ms/step - loss: 0.0532 - sparse_categorical_accuracy: 0.9856
Epoch 73/75
2500/2500 [==============================] - 9s 3ms/step - loss: 0.0418 - sparse_categorical_accuracy: 0.9877
Epoch 74/75
2500/2500 [==============================] - 9s 4ms/step - loss: 0.0311 - sparse_categorical_accuracy: 0.9910
Epoch 75/75
2500/2500 [==============================] - 14s 6ms/step - loss: 0.0244 - sparse_categorical_accuracy: 0.9922

Evaluating the model

In [46]:
# Evaluate the performance of the model on the validation set
validation_dataset = tf.data.Dataset.from_tensor_slices((data['validation']['images'], data['validation']['labels']))
validation_dataset = validation_dataset.batch(BATCH_SIZE)
print("\n\033[1;32mEvaluating performance on validation set:\033[0m")
model.evaluate(validation_dataset)

# Evaluate the performance of the model on the test set
test_dataset = tf.data.Dataset.from_tensor_slices((data['test']['images'], data['test']['labels']))
test_dataset = test_dataset.batch(BATCH_SIZE)
print("\n\033[1;32mEvaluating performance on test set:\033[0m")
model.evaluate(test_dataset);
Evaluating performance on validation set:
500/500 [==============================] - 1s 3ms/step - loss: 0.4802 - sparse_categorical_accuracy: 0.9368

Evaluating performance on test set:
500/500 [==============================] - 1s 3ms/step - loss: 1.1447 - sparse_categorical_accuracy: 0.8619
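
The drop from validation to test accuracy suggests the model generalizes less well to the held-out test set. A natural next step would be a per-class error breakdown; a minimal sketch using scikit-learn's confusion_matrix (not run as part of this notebook) could look like:

from sklearn.metrics import confusion_matrix

# Predicted class = argmax over the softmax outputs
predictions = np.argmax(model.predict(data['test']['images']), axis=1)
matrix = confusion_matrix(data['test']['labels'], predictions)

# Visualize which codepoints are most often confused with one another
plt.figure(figsize=(8, 6))
sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=char_map, yticklabels=char_map)
plt.xlabel("Predicted codepoint")
plt.ylabel("True codepoint")
plt.show()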