Sketch Recognition Convolutional Neural Network (CNN)

Experiment overview

In this experiment we will build a Convolutional Neural Network (CNN) model with TensorFlow that recognizes hand-drawn sketches, trained on the Quick, Draw! dataset.

sketch_recognition_cnn.png

Import dependencies

In [2]:
# Selecting Tensorflow version v2 (the command is relevant for Colab only).
# %tensorflow_version 2.x
In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import math
import datetime
import platform
import pathlib
import random

# Record the runtime versions so results can be reproduced later.
for component, version in (
    ('Python', platform.python_version()),
    ('Tensorflow', tf.__version__),
    ('Keras', tf.keras.__version__),
):
    print(component + ' version:', version)
Python version: 3.7.6
Tensorflow version: 2.1.0
Keras version: 2.2.4-tf
In [4]:
# Local folder where TFDS stores downloaded / prepared dataset files.
cache_dir = 'tmp'
In [5]:
# Create the cache folder named by `cache_dir`. exist_ok=True makes this a
# no-op when the folder already exists, so the cell is safe to re-run (the
# previous `!mkdir tmp` errored on re-runs and ignored `cache_dir`).
pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
mkdir: tmp: File exists

Load dataset

In [6]:
# List every dataset builder registered in TFDS to confirm the exact name
# under which the Quick, Draw! bitmaps are published ('quickdraw_bitmap').
tfds.list_builders()
Out[6]:
['abstract_reasoning',
 'aeslc',
 'aflw2k3d',
 'amazon_us_reviews',
 'arc',
 'bair_robot_pushing_small',
 'big_patent',
 'bigearthnet',
 'billsum',
 'binarized_mnist',
 'binary_alpha_digits',
 'c4',
 'caltech101',
 'caltech_birds2010',
 'caltech_birds2011',
 'cars196',
 'cassava',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'chexpert',
 'cifar10',
 'cifar100',
 'cifar10_1',
 'cifar10_corrupted',
 'citrus_leaves',
 'cityscapes',
 'civil_comments',
 'clevr',
 'cmaterdb',
 'cnn_dailymail',
 'coco',
 'coil100',
 'colorectal_histology',
 'colorectal_histology_large',
 'cos_e',
 'curated_breast_imaging_ddsm',
 'cycle_gan',
 'deep_weeds',
 'definite_pronoun_resolution',
 'diabetic_retinopathy_detection',
 'dmlab',
 'downsampled_imagenet',
 'dsprites',
 'dtd',
 'duke_ultrasound',
 'dummy_dataset_shared_generator',
 'dummy_mnist',
 'emnist',
 'esnli',
 'eurosat',
 'fashion_mnist',
 'flic',
 'flores',
 'food101',
 'gap',
 'gigaword',
 'glue',
 'groove',
 'higgs',
 'horses_or_humans',
 'i_naturalist2017',
 'image_label_folder',
 'imagenet2012',
 'imagenet2012_corrupted',
 'imagenet_resized',
 'imagenette',
 'imdb_reviews',
 'iris',
 'kitti',
 'kmnist',
 'lfw',
 'lm1b',
 'lost_and_found',
 'lsun',
 'malaria',
 'math_dataset',
 'mnist',
 'mnist_corrupted',
 'movie_rationales',
 'moving_mnist',
 'multi_news',
 'multi_nli',
 'multi_nli_mismatch',
 'newsroom',
 'nsynth',
 'omniglot',
 'open_images_v4',
 'oxford_flowers102',
 'oxford_iiit_pet',
 'para_crawl',
 'patch_camelyon',
 'pet_finder',
 'places365_small',
 'plant_leaves',
 'plant_village',
 'plantae_k',
 'quickdraw_bitmap',
 'reddit_tifu',
 'resisc45',
 'rock_paper_scissors',
 'rock_you',
 'scan',
 'scene_parse150',
 'scicite',
 'scientific_papers',
 'shapes3d',
 'smallnorb',
 'snli',
 'so2sat',
 'squad',
 'stanford_dogs',
 'stanford_online_products',
 'starcraft_video',
 'sun397',
 'super_glue',
 'svhn_cropped',
 'ted_hrlr_translate',
 'ted_multi_translate',
 'tf_flowers',
 'the300w_lp',
 'titanic',
 'trivia_qa',
 'uc_merced',
 'ucf101',
 'vgg_face2',
 'visual_domain_decathlon',
 'voc',
 'wider_face',
 'wikihow',
 'wikipedia',
 'wmt14_translate',
 'wmt15_translate',
 'wmt16_translate',
 'wmt17_translate',
 'wmt18_translate',
 'wmt19_translate',
 'wmt_t2t_translate',
 'wmt_translate',
 'xnli',
 'xsum']
In [7]:
DATASET_NAME = 'quickdraw_bitmap'

# Load only the TRAIN split (the only split this dataset provides, per the
# metadata below) plus its DatasetInfo, storing prepared files in cache_dir.
dataset, dataset_info = tfds.load(
    name=DATASET_NAME,
    split=tfds.Split.TRAIN,
    with_info=True,
    data_dir=cache_dir,
)

Explore dataset

In [166]:
# Show the dataset metadata: features, number of examples, splits, citation.
print(dataset_info)
tfds.core.DatasetInfo(
    name='quickdraw_bitmap',
    version=3.0.0,
    description='The Quick Draw Dataset is a collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. The bitmap dataset contains these drawings converted from vector format into 28x28 grayscale images',
    homepage='https://github.com/googlecreativelab/quickdraw-dataset',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=345),
    }),
    total_num_examples=50426266,
    splits={
        'train': 50426266,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{DBLP:journals/corr/HaE17,
      author    = {David Ha and
                   Douglas Eck},
      title     = {A Neural Representation of Sketch Drawings},
      journal   = {CoRR},
      volume    = {abs/1704.03477},
      year      = {2017},
      url       = {http://arxiv.org/abs/1704.03477},
      archivePrefix = {arXiv},
      eprint    = {1704.03477},
      timestamp = {Mon, 13 Aug 2018 16:48:30 +0200},
      biburl    = {https://dblp.org/rec/bib/journals/corr/HaE17},
      bibsource = {dblp computer science bibliography, https://dblp.org}
    }""",
    redistribution_info=,
)

In [167]:
# Basic dataset statistics pulled from the TFDS metadata.
image_shape = dataset_info.features['image'].shape
num_classes = dataset_info.features['label'].num_classes
num_examples = dataset_info.splits['train'].num_examples

for stat_name, stat_value in (
    ('num_examples', num_examples),
    ('image_shape', image_shape),
    ('num_classes', num_classes),
):
    print(stat_name + ': ', stat_value)
num_examples:  50426266
image_shape:  (28, 28, 1)
num_classes:  345
In [168]:
# Maps an integer label to its human-readable class name.
label_index_to_string = dataset_info.features['label'].int2str

# Class names ordered by label index, so classes[i] is the name of label i.
classes = [label_index_to_string(label_index) for label_index in range(num_classes)]

print('classes num:', len(classes))
print('classes:\n\n', classes)
classes num: 345
classes:

 ['aircraft carrier', 'airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball bat', 'baseball', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden hose', 'garden', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house plant', 'house', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 
'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'rifle', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 'syringe', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']
In [162]:
# Show the element structure (keys, shapes, dtypes) of the raw dataset.
print(dataset)
<DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
In [163]:
# Render a grid of sample images using the TFDS built-in helper.
fig = tfds.show_examples(dataset_info, dataset)
In [203]:
def dataset_preview(dataset, image_shape, preview_images_num=100):
    """Plot a square grid of raw examples, each captioned with its class.

    Args:
        dataset: tf.data.Dataset whose elements are dicts with an 'image'
            tensor and an integer 'label' tensor.
        image_shape: Image shape; only image_shape[0] (the side length) is
            used, so images are assumed to be square.
        preview_images_num: Number of examples to draw. The grid is
            ceil(sqrt(preview_images_num)) cells on each side.
    """
    grid_side = math.ceil(math.sqrt(preview_images_num))
    side_px = image_shape[0]
    plt.figure(figsize=(17, 17))

    for cell_number, example in enumerate(dataset.take(preview_images_num), start=1):
        class_index = example['label'].numpy()
        pixels = np.reshape(example['image'], (side_px, side_px))

        plt.subplot(grid_side, grid_side, cell_number)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # 'binary' colormap draws strokes dark on a light background.
        plt.imshow(pixels, cmap=plt.cm.binary)
        plt.xlabel('{} ({})'.format(classes[class_index], class_index))
    plt.show()
In [213]:
def dataset_normalized_preview(dataset, image_shape, preview_images_num=100):
    """Plot a square grid of normalized (image, one_hot_label) pairs.

    Each element's second item is a one-hot label vector, so the class index
    is recovered with argmax before looking up the class name.

    Args:
        dataset: tf.data.Dataset yielding (image, one_hot_label) tuples.
        image_shape: Image shape; only image_shape[0] (the side length) is
            used, so images are assumed to be square.
        preview_images_num: Number of examples to draw. The grid is
            ceil(sqrt(preview_images_num)) cells on each side.
    """
    grid_side = math.ceil(math.sqrt(preview_images_num))
    side_px = image_shape[0]
    plt.figure(figsize=(17, 17))

    for cell_number, (image, one_hot_label) in enumerate(
            dataset.take(preview_images_num), start=1):
        class_index = tf.math.argmax(one_hot_label).numpy()
        pixels = np.reshape(image, (side_px, side_px))

        plt.subplot(grid_side, grid_side, cell_number)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(pixels, cmap=plt.cm.binary)
        plt.xlabel('{} ({})'.format(label_index_to_string(class_index), class_index))
    plt.show()
In [205]:
def dataset_head(ds, examples_num=1):
    """Print the first few raw examples of a dataset as text.

    For each example this prints the class name with its index, the image
    shape, and the pixel matrix with the trailing channel axis dropped.
    The matrix size is taken from the image itself instead of being
    hard-coded to 28x28.

    Args:
        ds: tf.data.Dataset whose elements are dicts with an 'image' tensor of
            shape (height, width, 1) and an integer 'label' tensor.
        examples_num: How many examples to print (default 1, as before).
    """
    for example in ds.take(examples_num):
        image = example['image']
        class_index = example['label'].numpy()
        class_name = label_index_to_string(class_index)

        print('{} ({})'.format(class_name, class_index), '\n')
        print('Image shape: ', image.shape, '\n')
        # Drop the single channel axis: (H, W, 1) -> (H, W).
        print(np.squeeze(image.numpy(), axis=-1), '\n')
In [206]:
# Print the first raw example: its label and its pixel matrix (uint8 values).
dataset_head(dataset)
backpack (12) 

Image shape:  (28, 28, 1) 

[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   3  31  63  95 115
   86  50   2   0   0   0   0   0   0   0]
 [  0   0   0   0   0  75 183 176 152  10  54 182 222 250 255 255 255 255
  255 255 231 151  61   0   0   0   0   0]
 [  0   0   0   0 129 255 216 202 255  98 220 222 157 125  93  60  28   8
   38  82 162 239 254 108   0   0   0   0]
 [  0   0   0  33 250 158   4   0 204 209 255 236 105   0   0   0   0   0
    0   0   0   6 164 254  88   0   0   0]
 [  0   0   0 158 243  30 155 235 211 255 254 185 255  97   0   0   0   0
    0   0   0   0   3 193 229  10   0   0]
 [  0   0  38 251 134 153 251 162 252 245 225   1 180 241   6   0   0   0
    0   0   0   0   0  69 255  62   0   0]
 [  0   0 150 240  26 245 145   0 116 255 149   0  75 255  58   0   0   0
    0   0   0   0   0  29 255  94   0   0]
 [  0   0 209 173  77 255  62   0  62 255  68   0  14 251 122   0   0   0
    0   0   0   0   0   2 249 128   0   0]
 [  0   0 240 138 151 236   4   0  76 255  46   0   0 221 157   0   0   0
    0   0   0   0   0   0 218 160   0   0]
 [  0   0 251 127 188 188   0   0  76 255  46   0   0 213 164   0   4  45
    0   0   0   0   0   0 186 193   0   0]
 [  0   5 255 117 199 176   0   0  76 255  46   0   0 205 173   0  75 251
    9   0  44 154   2   0 158 226   0   0]
 [  0  14 255 108 209 166   0   0  76 255  46   0   0 160 243 102  44  61
   34  35  98 199 132 170 251 252   6   0]
 [  0  23 255  99 220 156   0   0  76 255  46   0   0 120 249 255 255 255
  255 255 255 255 247 214 202 255  24   0]
 [  0  17 255 108 221 162   0   0  72 255  51   0   0 129 247  31  82  85
   85  85  62  28   1   0 108 255  13   0]
 [  0   0 237 143 159 232   3   0  38 255  87   0   0 129 247   0   0   0
    0   0   0   0   0   0 121 253   2   0]
 [  0   0 198 199  73 255  78   0   4 248 127   0   0 129 247   0   0   0
   20  63  39   8   0   0 134 242   0   0]
 [  0   0  96 255 101 220 233  96   1 213 166   0   0 129 247   0  95 204
  253 255 255 255 230 199 218 232   0   0]
 [  0   0   1 180 252 136 189 255 106 173 207   0   0 129 247   0 239 228
  116  63  84 115 147 185 255 231   0   0]
 [  0   0   0   8 161 255 163  76  22 133 246   2   0 129 247   0 198 177
    0   0   0   0   0  49 255 220   0   0]
 [  0   0   0   0   0 113 251 242 196 240 255  37   0 129 247   0 208 168
    0   0   0   0   0 130 254 190   0   0]
 [  0   0   0   0   0   0  56 141 184 165 248 159   0 126 250   0 174 224
   13   0   0   0   0 219 255 115   0   0]
 [  0   0   0   0   0   0   0   0   0   0 127 254  60 116 255   5  69 252
  230  96   2   0  46 255 253  30   0   0]
 [  0   0   0   0   0   0   0   0   0   0   9 215 241 161 255  15   0  46
  190 255 206 169 244 255 117   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  23 203 255 255 229 217 204
  191 227 255 255 251 156   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   1 108 255 183 159 170
  185 170 112  65   4   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   3  97  23   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]] 

In [207]:
# Plot the first 100 raw examples (the function's default) with their labels.
dataset_preview(dataset, image_shape)

Normalize and augment dataset

In [208]:
def normalize_example(example):
    """Turn a raw TFDS example dict into an (image, one_hot_label) pair.

    The uint8 image is cast to float32 and divided by 255, giving pixel
    values in [0, 1]. The integer label becomes a one-hot vector of length
    len(classes).
    """
    scaled_image = tf.cast(example['image'], tf.float32) / 255
    one_hot_label = tf.one_hot(example['label'], len(classes))
    return (scaled_image, one_hot_label)
In [209]:
def augment_example(image, label):
    """Randomly mirror the image horizontally; the label passes through unchanged."""
    flipped_image = tf.image.random_flip_left_right(image)
    return (flipped_image, label)
In [210]:
# Scale pixels and one-hot the labels, then randomly mirror every example.
dataset_normalized = (
    dataset
    .map(normalize_example)
    .map(augment_example)
)
In [211]:
# Peek at the first normalized example. The label is now one-hot, so the class
# index is recovered with argmax; pixel values are floats in [0, 1].
for (image, label) in dataset_normalized.take(1):
    class_index = tf.math.argmax(label).numpy()
    class_name = label_index_to_string(class_index)

    print('{} ({})'.format(class_name, class_index), '\n')
    print('Image shape: ', image.shape, '\n')
    # Drop the single channel axis instead of hard-coding a 28x28 reshape.
    print(np.squeeze(image.numpy(), axis=-1), '\n')
backpack (12) 

Image shape:  (28, 28, 1) 

[[0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.01176471 0.12156863 0.24705882 0.37254903 0.4509804
  0.3372549  0.19607843 0.00784314 0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.29411766
  0.7176471  0.6901961  0.59607846 0.03921569 0.21176471 0.7137255
  0.87058824 0.98039216 1.         1.         1.         1.
  1.         1.         0.90588236 0.5921569  0.23921569 0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.5058824  1.
  0.84705883 0.7921569  1.         0.38431373 0.8627451  0.87058824
  0.6156863  0.49019608 0.3647059  0.23529412 0.10980392 0.03137255
  0.14901961 0.32156864 0.63529414 0.9372549  0.99607843 0.42352942
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.12941177 0.98039216 0.61960787
  0.01568628 0.         0.8        0.81960785 1.         0.9254902
  0.4117647  0.         0.         0.         0.         0.
  0.         0.         0.         0.02352941 0.6431373  0.99607843
  0.34509805 0.         0.         0.        ]
 [0.         0.         0.         0.61960787 0.9529412  0.11764706
  0.60784316 0.92156863 0.827451   1.         0.99607843 0.7254902
  1.         0.38039216 0.         0.         0.         0.
  0.         0.         0.         0.         0.01176471 0.75686276
  0.8980392  0.03921569 0.         0.        ]
 [0.         0.         0.14901961 0.9843137  0.5254902  0.6
  0.9843137  0.63529414 0.9882353  0.9607843  0.88235295 0.00392157
  0.7058824  0.94509804 0.02352941 0.         0.         0.
  0.         0.         0.         0.         0.         0.27058825
  1.         0.24313726 0.         0.        ]
 [0.         0.         0.5882353  0.9411765  0.10196079 0.9607843
  0.5686275  0.         0.45490196 1.         0.58431375 0.
  0.29411766 1.         0.22745098 0.         0.         0.
  0.         0.         0.         0.         0.         0.11372549
  1.         0.36862746 0.         0.        ]
 [0.         0.         0.81960785 0.6784314  0.3019608  1.
  0.24313726 0.         0.24313726 1.         0.26666668 0.
  0.05490196 0.9843137  0.47843137 0.         0.         0.
  0.         0.         0.         0.         0.         0.00784314
  0.9764706  0.5019608  0.         0.        ]
 [0.         0.         0.9411765  0.5411765  0.5921569  0.9254902
  0.01568628 0.         0.29803923 1.         0.18039216 0.
  0.         0.8666667  0.6156863  0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.85490197 0.627451   0.         0.        ]
 [0.         0.         0.9843137  0.49803922 0.7372549  0.7372549
  0.         0.         0.29803923 1.         0.18039216 0.
  0.         0.8352941  0.6431373  0.         0.01568628 0.1764706
  0.         0.         0.         0.         0.         0.
  0.7294118  0.75686276 0.         0.        ]
 [0.         0.01960784 1.         0.45882353 0.78039217 0.6901961
  0.         0.         0.29803923 1.         0.18039216 0.
  0.         0.8039216  0.6784314  0.         0.29411766 0.9843137
  0.03529412 0.         0.17254902 0.6039216  0.00784314 0.
  0.61960787 0.8862745  0.         0.        ]
 [0.         0.05490196 1.         0.42352942 0.81960785 0.6509804
  0.         0.         0.29803923 1.         0.18039216 0.
  0.         0.627451   0.9529412  0.4        0.17254902 0.23921569
  0.13333334 0.13725491 0.38431373 0.78039217 0.5176471  0.6666667
  0.9843137  0.9882353  0.02352941 0.        ]
 [0.         0.09019608 1.         0.3882353  0.8627451  0.6117647
  0.         0.         0.29803923 1.         0.18039216 0.
  0.         0.47058824 0.9764706  1.         1.         1.
  1.         1.         1.         1.         0.96862745 0.8392157
  0.7921569  1.         0.09411765 0.        ]
 [0.         0.06666667 1.         0.42352942 0.8666667  0.63529414
  0.         0.         0.28235295 1.         0.2        0.
  0.         0.5058824  0.96862745 0.12156863 0.32156864 0.33333334
  0.33333334 0.33333334 0.24313726 0.10980392 0.00392157 0.
  0.42352942 1.         0.05098039 0.        ]
 [0.         0.         0.92941177 0.56078434 0.62352943 0.9098039
  0.01176471 0.         0.14901961 1.         0.34117648 0.
  0.         0.5058824  0.96862745 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.4745098  0.99215686 0.00784314 0.        ]
 [0.         0.         0.7764706  0.78039217 0.28627452 1.
  0.30588236 0.         0.01568628 0.972549   0.49803922 0.
  0.         0.5058824  0.96862745 0.         0.         0.
  0.07843138 0.24705882 0.15294118 0.03137255 0.         0.
  0.5254902  0.9490196  0.         0.        ]
 [0.         0.         0.3764706  1.         0.39607844 0.8627451
  0.9137255  0.3764706  0.00392157 0.8352941  0.6509804  0.
  0.         0.5058824  0.96862745 0.         0.37254903 0.8
  0.99215686 1.         1.         1.         0.9019608  0.78039217
  0.85490197 0.9098039  0.         0.        ]
 [0.         0.         0.00392157 0.7058824  0.9882353  0.53333336
  0.7411765  1.         0.41568628 0.6784314  0.8117647  0.
  0.         0.5058824  0.96862745 0.         0.9372549  0.89411765
  0.45490196 0.24705882 0.32941177 0.4509804  0.5764706  0.7254902
  1.         0.90588236 0.         0.        ]
 [0.         0.         0.         0.03137255 0.6313726  1.
  0.6392157  0.29803923 0.08627451 0.52156866 0.9647059  0.00784314
  0.         0.5058824  0.96862745 0.         0.7764706  0.69411767
  0.         0.         0.         0.         0.         0.19215687
  1.         0.8627451  0.         0.        ]
 [0.         0.         0.         0.         0.         0.44313726
  0.9843137  0.9490196  0.76862746 0.9411765  1.         0.14509805
  0.         0.5058824  0.96862745 0.         0.8156863  0.65882355
  0.         0.         0.         0.         0.         0.50980395
  0.99607843 0.74509805 0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.21960784 0.5529412  0.72156864 0.64705884 0.972549   0.62352943
  0.         0.49411765 0.98039216 0.         0.68235296 0.8784314
  0.05098039 0.         0.         0.         0.         0.85882354
  1.         0.4509804  0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.49803922 0.99607843
  0.23529412 0.45490196 1.         0.01960784 0.27058825 0.9882353
  0.9019608  0.3764706  0.00784314 0.         0.18039216 1.
  0.99215686 0.11764706 0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.03529412 0.84313726
  0.94509804 0.6313726  1.         0.05882353 0.         0.18039216
  0.74509805 1.         0.80784315 0.6627451  0.95686275 1.
  0.45882353 0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.09019608
  0.79607844 1.         1.         0.8980392  0.8509804  0.8
  0.7490196  0.8901961  1.         1.         0.9843137  0.6117647
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.00392157 0.42352942 1.         0.7176471  0.62352943 0.6666667
  0.7254902  0.6666667  0.4392157  0.25490198 0.01568628 0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.01176471 0.38039216 0.09019608 0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]] 

In [214]:
# Plot the first 100 normalized, randomly flipped examples; labels are
# decoded from their one-hot vectors.
dataset_normalized_preview(dataset_normalized, image_shape)

Prepare Train/Validation/Test dataset splits

In [215]:
# Toy illustration of the take/skip split applied to the real data below:
# the first 2 items become "test", the next 3 "validation", the rest "train".
tmp_ds = tf.data.Dataset.range(10)
tmp_ds_test = tmp_ds.take(2)
tmp_ds_val = tmp_ds.skip(2).take(3)
tmp_ds_train = tmp_ds.skip(2 + 3)

for split_label, split_ds in (
    ('tmp_ds', tmp_ds),
    ('tmp_ds_test', tmp_ds_test),
    ('tmp_ds_val', tmp_ds_val),
    ('tmp_ds_train', tmp_ds_train),
):
    print(split_label + ':', list(split_ds.as_numpy_iterator()))
tmp_ds: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
tmp_ds_test: [0, 1]
tmp_ds_val: [2, 3, 4]
tmp_ds_train: [5, 6, 7, 8, 9]
In [378]:
# --- Dataset split: number of *batches* reserved for each hold-out set ---
test_dataset_batches = 1
val_dataset_batches = 1

# --- Batching / input pipeline ---
batch_size = 2000              # examples per batch
prefetch_buffer_batches = 10   # batches buffered ahead of the training step

# --- Training schedule ---
epochs = 40
steps_per_epoch = 500          # batches per epoch (500 * 2000 = 1,000,000 examples)
In [379]:
# Group the normalized examples into batches of `batch_size`.
dataset_batched = dataset_normalized.batch(batch_size)
In [380]:
# TEST / VALIDATION datasets.
# The hold-out batches are taken from the front of the stream. They are built
# from a separate normalize-only pipeline, so evaluation runs on the original
# drawings; the random flips in `dataset_batched` are reserved for training.
# NOTE: this relies on `dataset` yielding examples in the same order on every
# pass (tfds.load is called without shuffle_files). That way batch k here and
# batch k of `dataset_batched` contain the same examples, and the train
# split's .skip() below excludes exactly the hold-out examples.
dataset_eval_batched = dataset \
    .map(normalize_example) \
    .batch(batch_size=batch_size)

dataset_test = dataset_eval_batched \
    .take(test_dataset_batches)

dataset_val = dataset_eval_batched \
    .skip(test_dataset_batches) \
    .take(val_dataset_batches)

# TRAIN dataset (augmented). prefetch() goes last so that the whole pipeline,
# including repeat(), overlaps with model execution.
dataset_train = dataset_batched \
    .skip(test_dataset_batches + val_dataset_batches) \
    .repeat() \
    .prefetch(buffer_size=prefetch_buffer_batches)
In [381]:
# Sanity-check one batch from each split: labels should be
# (batch_size, num_classes) and images (batch_size, H, W, 1).
# The loop variables are deliberately left bound at module level.
for (image_test, label_test) in dataset_test.take(1):
    print('label_test.shape: ', label_test.shape)
    print('image_test.shape: ', image_test.shape)
    
print()    
    
for (image_val, label_val) in dataset_val.take(1):
    print('label_val.shape: ', label_val.shape)
    print('image_val.shape: ', image_val.shape)    
    
print()    
    
# dataset_train is repeated, so take(1) always yields a batch.
for (image_train, label_train) in dataset_train.take(1):
    print('label_train.shape: ', label_train.shape)
    print('image_train.shape: ', image_train.shape)    
label_test.shape:  (2000, 345)
image_test.shape:  (2000, 28, 28, 1)

label_val.shape:  (2000, 345)
image_val.shape:  (2000, 28, 28, 1)

label_train.shape:  (2000, 345)
image_train.shape:  (2000, 28, 28, 1)
In [382]:
# Calculate how many times the network will "see" each class during one epoch of training
# given specific dataset (batches) and number of steps per epoch.
def get_dataset_classes_hist(dataset, classes, batches_num):
    """Count how often each class occurs in the first `batches_num` batches.

    Args:
        dataset: Batched dataset yielding (examples, one_hot_labels) pairs,
            where one_hot_labels has one row per example and one column per
            class.
        classes: Class names ordered by label index.
        batches_num: Number of batches to consume from `dataset`.

    Returns:
        Dict mapping every class name to its (possibly zero) count.
    """
    counts = np.zeros(len(classes), dtype=np.int64)
    for _, labels in dataset.take(batches_num):
        # One vectorized argmax + bincount per batch, instead of a TF op and a
        # .numpy() round-trip for every single label.
        class_indices = np.argmax(np.asarray(labels), axis=-1)
        counts += np.bincount(class_indices, minlength=len(classes))
    return {class_name: int(count) for class_name, count in zip(classes, counts)}
In [383]:
# Tally class frequencies over exactly one epoch's worth of training batches.
mentions = get_dataset_classes_hist(dataset_train, classes, steps_per_epoch)
In [384]:
# Print per-class counts. Names are padded to the longest class name, so long
# names (e.g. 'The Great Wall of China') no longer break the column alignment
# the way the fixed 15-character width did.
name_width = max((len(class_name) for class_name in mentions), default=0)
for class_name, count in mentions.items():
    print('{:{width}s}: {}'.format(class_name, count, width=name_width))
aircraft carrier: 2295
airplane       : 3030
alarm clock    : 2458
ambulance      : 2842
angel          : 2926
animal migration: 2706
ant            : 2548
anvil          : 2553
apple          : 2904
arm            : 2373
asparagus      : 3351
axe            : 2473
backpack       : 2429
banana         : 6052
bandage        : 2935
barn           : 2956
baseball bat   : 2381
baseball       : 2668
basket         : 2278
basketball     : 2662
bat            : 2371
bathtub        : 3429
beach          : 2515
bear           : 2679
beard          : 3318
bed            : 2346
bee            : 2417
belt           : 3876
bench          : 2531
bicycle        : 2501
binoculars     : 2432
bird           : 2619
birthday cake  : 3036
blackberry     : 2570
blueberry      : 2537
book           : 2389
boomerang      : 2918
bottlecap      : 3091
bowtie         : 2549
bracelet       : 2394
brain          : 2770
bread          : 2356
bridge         : 2638
broccoli       : 2613
broom          : 2347
bucket         : 2423
bulldozer      : 3645
bus            : 3293
bush           : 2425
butterfly      : 2333
cactus         : 2663
cake           : 2415
calculator     : 2543
calendar       : 6369
camel          : 2429
camera         : 2521
camouflage     : 3377
campfire       : 2694
candle         : 2764
cannon         : 2768
canoe          : 2500
car            : 3560
carrot         : 2643
castle         : 2380
cat            : 2385
ceiling fan    : 2282
cell phone     : 2415
cello          : 3012
chair          : 4409
chandelier     : 3392
church         : 3305
circle         : 2414
clarinet       : 2518
clock          : 2371
cloud          : 2465
coffee cup     : 3613
compass        : 2589
computer       : 2450
cookie         : 2664
cooler         : 5334
couch          : 2439
cow            : 2458
crab           : 2519
crayon         : 2650
crocodile      : 2568
crown          : 2745
cruise ship    : 2515
cup            : 2624
diamond        : 2617
dishwasher     : 3346
diving board   : 5817
dog            : 3000
dolphin        : 2446
donut          : 2806
door           : 2388
dragon         : 2434
dresser        : 2469
drill          : 2673
drums          : 2746
duck           : 2658
dumbbell       : 3085
ear            : 2474
elbow          : 2553
elephant       : 2467
envelope       : 2644
eraser         : 2339
eye            : 2579
eyeglasses     : 4420
face           : 3250
fan            : 2657
feather        : 2352
fence          : 2584
finger         : 3272
fire hydrant   : 2710
fireplace      : 3101
firetruck      : 4426
fish           : 2589
flamingo       : 2463
flashlight     : 4687
flip flops     : 2406
floor lamp     : 3254
flower         : 2948
flying saucer  : 2939
foot           : 3890
fork           : 2543
frog           : 3120
frying pan     : 2392
garden hose    : 2384
garden         : 3225
giraffe        : 2511
goatee         : 3803
golf club      : 3867
grapes         : 3037
grass          : 2417
guitar         : 2380
hamburger      : 2604
hammer         : 2313
hand           : 5825
harp           : 5799
hat            : 4451
headphones     : 2320
hedgehog       : 2352
helicopter     : 3123
helmet         : 2452
hexagon        : 2874
hockey puck    : 4059
hockey stick   : 2494
horse          : 3555
hospital       : 3317
hot air balloon: 2532
hot dog        : 3639
hot tub        : 2438
hourglass      : 2705
house plant    : 2422
house          : 2700
hurricane      : 2659
ice cream      : 2488
jacket         : 4193
jail           : 2410
kangaroo       : 3356
key            : 3121
keyboard       : 3697
knee           : 5312
knife          : 3459
ladder         : 2419
lantern        : 2968
laptop         : 5222
leaf           : 2486
leg            : 2329
light bulb     : 2441
lighter        : 2353
lighthouse     : 3259
lightning      : 3075
line           : 2831
lion           : 2396
lipstick       : 2607
lobster        : 2780
lollipop       : 2525
mailbox        : 2579
map            : 2366
marker         : 6420
matches        : 2820
megaphone      : 2790
mermaid        : 3598
microphone     : 2374
microwave      : 2629
monkey         : 2488
moon           : 2394
mosquito       : 2402
motorbike      : 3379
mountain       : 2463
mouse          : 3548
moustache      : 3572
mouth          : 2591
mug            : 3131
mushroom       : 2815
nail           : 3209
necklace       : 2360
nose           : 3790
ocean          : 2555
octagon        : 3217
octopus        : 2965
onion          : 2659
oven           : 4101
owl            : 3321
paint can      : 2341
paintbrush     : 3673
palm tree      : 2384
panda          : 2177
pants          : 2892
paper clip     : 2511
parachute      : 2499
parrot         : 3648
passport       : 2981
peanut         : 2518
pear           : 2332
peas           : 3167
pencil         : 2523
penguin        : 5005
piano          : 2270
pickup truck   : 2634
picture frame  : 2455
pig            : 3771
pillow         : 2338
pineapple      : 2504
pizza          : 2636
pliers         : 3528
police car     : 2724
pond           : 2381
pool           : 2637
popsicle       : 2486
postcard       : 2515
potato         : 6675
power outlet   : 3284
purse          : 2421
rabbit         : 3032
raccoon        : 2289
radio          : 2653
rain           : 2624
rainbow        : 2489
rake           : 3160
remote control : 2320
rhinoceros     : 3760
rifle          : 3453
river          : 2629
roller coaster : 2885
rollerskates   : 2340
sailboat       : 2706
sandwich       : 2607
saw            : 2456
saxophone      : 2244
school bus     : 2417
scissors       : 2451
scorpion       : 3334
screwdriver    : 2344
sea turtle     : 2347
see saw        : 2599
shark          : 2529
sheep          : 2519
shoe           : 2435
shorts         : 2460
shovel         : 2338
sink           : 4133
skateboard     : 2572
skull          : 2513
skyscraper     : 3606
sleeping bag   : 2278
smiley face    : 2512
snail          : 2616
snake          : 2410
snorkel        : 3065
snowflake      : 2312
snowman        : 6759
soccer ball    : 2457
sock           : 4093
speedboat      : 3768
spider         : 4084
spoon          : 2531
spreadsheet    : 3312
square         : 2414
squiggle       : 2370
squirrel       : 2943
stairs         : 2599
star           : 2706
steak          : 2410
stereo         : 2518
stethoscope    : 3000
stitches       : 2519
stop sign      : 2339
stove          : 2348
strawberry     : 2351
streetlight    : 2533
string bean    : 2411
submarine      : 2522
suitcase       : 2513
sun            : 2716
swan           : 3096
sweater        : 2426
swing set      : 2379
sword          : 2489
syringe        : 2654
t-shirt        : 2458
table          : 2561
teapot         : 2525
teddy-bear     : 3568
telephone      : 2559
television     : 2456
tennis racquet : 4605
tent           : 2589
The Eiffel Tower: 2668
The Great Wall of China: 3839
The Mona Lisa  : 2320
tiger          : 2382
toaster        : 2400
toe            : 2973
toilet         : 2559
tooth          : 2495
toothbrush     : 2478
toothpaste     : 2552
tornado        : 2821
tractor        : 2248
traffic light  : 2466
train          : 2559
tree           : 2732
triangle       : 2414
trombone       : 3699
truck          : 2603
trumpet        : 3398
umbrella       : 2510
underwear      : 2502
van            : 3241
vase           : 2455
violin         : 4292
washing machine: 2434
watermelon     : 2609
waterslide     : 3768
whale          : 2301
wheel          : 2708
windmill       : 2360
wine bottle    : 2473
wine glass     : 2712
wristwatch     : 3311
yoga           : 5547
zebra          : 2917
zigzag         : 2319
In [385]:
# Bar chart of how many examples each class has, with bar i = classes[i].
# Heights are looked up in `classes` order; iterating `mentions` directly (as
# before) follows the dict's own key order, which need not match `classes`,
# so bars could be attributed to the wrong class index.
mantions_x = list(range(len(classes)))
# A class absent from `mentions` has no examples, so its bar height is 0.
mantions_bars = [mentions.get(class_name, 0) for class_name in classes]

plt.bar(mantions_x, mantions_bars)
plt.title('Examples per class')
plt.xlabel('Class index')
plt.ylabel('Items per class')
plt.show()

Create model

In [386]:
# CNN: three conv -> 2x2 max-pool stages (32, 32, 64 filters), then a
# 512-unit ReLU dense layer and a softmax over `num_classes`.
# 'same' padding keeps the spatial size through each convolution and every
# stride-2 pool halves it (28 -> 14 -> 7 -> 3 for a 28x28 input).
model = tf.keras.models.Sequential([
    # Stage 1: the only layer that declares the input shape.
    tf.keras.layers.Convolution2D(
        filters=32,
        kernel_size=5,
        padding='same',
        activation=tf.keras.activations.relu,
        input_shape=image_shape,
    ),
    tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),

    # Stage 2.
    tf.keras.layers.Convolution2D(
        filters=32,
        kernel_size=3,
        padding='same',
        activation=tf.keras.activations.relu,
    ),
    tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),

    # Stage 3.
    tf.keras.layers.Convolution2D(
        filters=64,
        kernel_size=3,
        padding='same',
        activation=tf.keras.activations.relu,
    ),
    tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),

    # Classifier head.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=512, activation=tf.keras.activations.relu),
    tf.keras.layers.Dense(units=num_classes, activation=tf.keras.activations.softmax),
])
In [387]:
# Print each layer's output shape and parameter count.
model.summary()
Model: "sequential_19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_57 (Conv2D)           (None, 28, 28, 32)        832       
_________________________________________________________________
max_pooling2d_57 (MaxPooling (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_58 (Conv2D)           (None, 14, 14, 32)        9248      
_________________________________________________________________
max_pooling2d_58 (MaxPooling (None, 7, 7, 32)          0         
_________________________________________________________________
conv2d_59 (Conv2D)           (None, 7, 7, 64)          18496     
_________________________________________________________________
max_pooling2d_59 (MaxPooling (None, 3, 3, 64)          0         
_________________________________________________________________
flatten_19 (Flatten)         (None, 576)               0         
_________________________________________________________________
dense_43 (Dense)             (None, 512)               295424    
_________________________________________________________________
dense_44 (Dense)             (None, 345)               176985    
=================================================================
Total params: 500,985
Trainable params: 500,985
Non-trainable params: 0
_________________________________________________________________
In [388]:
# Render the layer graph, annotated with layer names and tensor shapes, as the
# cell's output image.
# NOTE(review): plot_model relies on pydot + graphviz being installed.
tf.keras.utils.plot_model(model, show_layer_names=True, show_shapes=True)
Out[388]:
In [389]:
# Candidate optimizers. Only Adam is passed to compile(); the RMSprop and SGD
# instances are created but not used in this cell.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.003)
rms_prop_optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Categorical cross-entropy pairs with the softmax output layer; labels are
# presumably one-hot encoded (later cells take argmax of them).
model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer=adam_optimizer,
    metrics=['accuracy'],
)

Train model

In [390]:
# Stop training once validation accuracy has not improved for 5 epochs and,
# when stopping, restore the weights from the best-scoring epoch.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=5, restore_best_weights=True, verbose=1
)
In [ ]:
# Train for up to `epochs` epochs of `steps_per_epoch` batches each (the
# explicit step count suggests dataset_train repeats — confirm), validating on
# dataset_val after every epoch; early stopping may end training sooner.
training_history = model.fit(
    x=dataset_train,
    validation_data=dataset_val,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=[early_stopping_callback],
)
In [393]:
# Renders the charts for training accuracy and loss.
def _plot_metric_curves(axis, train_values, val_values, metric_name):
    """Draw one metric's training (solid) and validation (dashed) curves on `axis`."""
    axis.set_title(metric_name)
    axis.set_xlabel('Epoch')
    axis.set_ylabel(metric_name)
    axis.plot(train_values, label='Training set')
    # These values come from fit()'s `validation_data`, not from the held-out
    # test set, so they are labelled as validation (previously 'Test set').
    axis.plot(val_values, label='Validation set', linestyle='--')
    axis.legend()
    axis.grid(linestyle='--', linewidth=1, alpha=0.5)


def render_training_history(training_history):
    """Plot per-epoch loss and accuracy, training vs. validation, side by side.

    Args:
        training_history: The History object returned by `model.fit`. Its
            `.history` dict must contain 'loss', 'val_loss', 'accuracy' and
            'val_accuracy' (recorded when fit() receives validation_data and
            the model is compiled with metrics=['accuracy']).
    """
    history = training_history.history
    _, (loss_axis, accuracy_axis) = plt.subplots(1, 2, figsize=(14, 4))

    _plot_metric_curves(loss_axis, history['loss'], history['val_loss'], 'Loss')
    _plot_metric_curves(
        accuracy_axis, history['accuracy'], history['val_accuracy'], 'Accuracy'
    )

    plt.show()
In [394]:
# Plot the loss/accuracy curves recorded by model.fit.
render_training_history(training_history)

Evaluate model accuracy

Training set accuracy

In [395]:
%%capture
train_loss, train_accuracy = model.evaluate(dataset_train.take(1))
In [396]:
# Training-set metrics, rounded to two decimals.
print('Train loss: ', f'{train_loss:.2f}')
print('Train accuracy: ', f'{train_accuracy:.2f}')
Train loss:  1.36
Train accuracy:  0.66

Validation set accuracy

In [397]:
%%capture
val_loss, val_accuracy = model.evaluate(dataset_val)
In [398]:
# Validation-set metrics, rounded to two decimals.
print('Validation loss: ', f'{val_loss:.2f}')
print('Validation accuracy: ', f'{val_accuracy:.2f}')
Validation loss:  1.43
Validation accuracy:  0.65

Test set accuracy

In [399]:
%%capture
test_loss, test_accuracy = model.evaluate(dataset_test)
In [400]:
# Test-set metrics, rounded to two decimals.
print('Test loss: ', f'{test_loss:.2f}')
print('Test accuracy: ', f'{test_accuracy:.2f}')
Test loss:  1.40
Test accuracy:  0.67

Visualizing Predictions

In [422]:
def visualize_predictions(model, dataset, numbers_to_display=64):
    """Plot a grid of images from one batch of `dataset` with true vs. predicted labels.

    Each cell is captioned "<correct label> --> <predicted label>" and drawn
    with the 'Greens' colormap when the prediction is correct, 'Reds' otherwise.

    Args:
        model: Trained Keras model whose `predict` returns per-class scores.
        dataset: tf.data.Dataset yielding (images, labels) batches. Labels are
            assumed one-hot (argmax recovers the class index).
        numbers_to_display: Maximum number of images to draw (default 64).
            Clamped to the batch size, so small batches no longer raise
            IndexError.
    """
    # Materialise ONE batch and reuse it for both prediction and plotting.
    # Iterating `dataset.take(1)` twice (once inside predict, once in a loop)
    # can yield two different batches if the dataset reshuffles per iteration,
    # which would pair images with predictions made for other images.
    images, labels = next(iter(dataset))
    images = images.numpy()
    labels = labels.numpy()

    predictions = np.argmax(model.predict(images), axis=1)
    correct_indices = np.argmax(labels, axis=1)

    numbers_to_display = max(0, min(numbers_to_display, len(images)))
    num_cells = math.ceil(math.sqrt(numbers_to_display))
    figure = plt.figure(figsize=(15, 15))

    for image_index in range(numbers_to_display):
        y_correct = correct_indices[image_index]
        y_predicted = predictions[image_index]
        # Drop singleton axes such as a trailing channel, e.g. (28, 28, 1) ->
        # (28, 28), instead of hard-coding the image size.
        pixels = np.squeeze(images[image_index])

        # Create the subplot first so tick/grid settings apply to this cell,
        # not to the previously created one.
        axis = figure.add_subplot(num_cells, num_cells, image_index + 1)
        axis.set_xticks([])
        axis.set_yticks([])
        axis.grid(False)
        color_map = 'Greens' if y_correct == y_predicted else 'Reds'
        axis.imshow(pixels, cmap=color_map)
        axis.set_xlabel(classes[y_correct] + ' --> ' + classes[y_predicted])

    plt.subplots_adjust(hspace=1, wspace=0.5)
    plt.show()
In [423]:
# Predictions on a batch from the training dataset (green = correct, red = wrong).
visualize_predictions(model, dataset_train)
In [424]:
# Predictions on a batch from the test dataset (green = correct, red = wrong).
visualize_predictions(model, dataset_test)