In this guide we will design and train an image recognition model that recognizes 17 different species of flowers. We will use a technique called transfer learning, combining a pretrained model called VGG19, trained on ImageNet, with our own flower classification subnet. The guide is an extension of the workshop found here, and thus assumes a basic understanding of machine learning theory and programming in Python. Running the code has two prerequisites:
A Python environment with TensorFlow and Keras installed:
import tensorflow
import keras
# If you get a FutureWarning from running this cell,
# don't worry, everything is as it should be!
Using TensorFlow backend.
Downloading and restructuring the dataset in a folder 'flowers' as defined here:
import os
print('Root dir exists: {}'.format(os.path.isdir('flowers')))
print('Train dir exists: {}'.format(
    os.path.isdir(os.path.join('flowers', 'train'))))
print('Validation dir exists: {}'.format(
    os.path.isdir(os.path.join('flowers', 'val'))))
Root dir exists: True
Train dir exists: True
Validation dir exists: True
To serve images to our model during training we will use a Python generator. Given that our images are structured as previously stated (a folder with the training set and a folder with the validation set, both with subfolders for each category of flowers), we can use a prebuilt Keras generator called ImageDataGenerator. The generator acts as a list and serves what are known as batches of tuples, where the first element of the tuple contains the images of the batch and the second element contains the corresponding labels.
from keras.preprocessing.image import ImageDataGenerator
generator = ImageDataGenerator()
batches = generator.flow_from_directory('flowers/train', batch_size=4)
batches
Found 1105 images belonging to 17 classes.
<keras_preprocessing.image.directory_iterator.DirectoryIterator at 0x11360a470>
At this point it is usually a good idea to do a sanity check. This typically includes verifying that our images are served in the correct format, and that the images and labels are still correctly matched. We first fetch the reverse encoding from the generator so we can decode the one-hot encoded labels given in the batches.
indices = batches.class_indices
labels = [None] * 17
for key in indices:
    labels[indices[key]] = key
labels
['bluebell', 'buttercup', 'colts_foot', 'cowslip', 'crocus', 'daffodil', 'daisy', 'dandelion', 'fritillary', 'iris', 'lily_valley', 'pansy', 'snowdrop', 'sunflower', 'tigerlily', 'tulip', 'windflower']
We can then use matplotlib to visualize the first batch.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
for X, y in batches:
    fig, ax = plt.subplots(1, 4, figsize=(10, 10))
    for i in range(len(X)):
        img = X[i].astype(np.uint8)
        label = labels[np.argmax(y[i])]
        ax[i].imshow(img)
        ax[i].set_title(label)
        ax[i].set_xticks([])
        ax[i].set_yticks([])
    plt.show()
    break  # We only need the first batch
Once we know our dataset is being served correctly, we can start setting up the base model that will be the core of our flower classification model. As previously mentioned this will be VGG19, a fairly simple model which yields relatively good results. Like the generator, it exists as a prebuilt module in Keras, namely in the applications module.
from keras.applications.vgg19 import VGG19
When initializing the model we need to specify that we want to use the weights trained on ImageNet and that we want the entire model including the top layers. We also specify the image size we are going to use, for explicitness.
model = VGG19(weights='imagenet', include_top=True,
              input_shape=(224, 224, 3))
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080
block3_conv4 (Conv2D)        (None, 56, 56, 256)       590080
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808
block4_conv4 (Conv2D)        (None, 28, 28, 512)       2359808
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv4 (Conv2D)        (None, 14, 14, 512)       2359808
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0
flatten (Flatten)            (None, 25088)             0
fc1 (Dense)                  (None, 4096)              102764544
fc2 (Dense)                  (None, 4096)              16781312
predictions (Dense)          (None, 1000)              4097000
=================================================================
Total params: 143,667,240
Trainable params: 143,667,240
Non-trainable params: 0
_________________________________________________________________
We can sanity check this step by predicting the label for an image from our generator. Note that the predictions we make now will use the labels from ImageNet, as this is what the model currently recognizes, not the labels from our dataset. We start by reinitializing the generator with the correct image size and a batch size of 1. We also seed numpy's random number generator to control the order of the images.
np.random.seed(1234)
generator = ImageDataGenerator()
batches = generator.flow_from_directory('flowers/train',
                                        target_size=(224, 224),
                                        batch_size=1)
Found 1105 images belonging to 17 classes.
We can then run a prediction on the first batch, which contains a single image. To decode the prediction into ImageNet labels we can use a predefined function found in the same module as the model.
from keras.applications.vgg19 import decode_predictions
for X, y in batches:
    preds = model.predict(X)
    decoded_preds = decode_predictions(preds, top=1)
    fig = plt.figure()
    img = X[0].astype(np.uint8)
    label = labels[np.argmax(y[0])]
    predicted = decoded_preds[0]
    plt.imshow(img)
    fig.suptitle('Truth: {}, Predicted: {}'.format(label, predicted))
    plt.show()
    break
We now know that both our generator and model are set up, and we are able to make predictions. The predictions, however, do not necessarily look very good. This is because we have skipped a step called preprocessing: a set of transformations applied to the images before training, giving the model the best possible foundation to learn what it needs. Typical preprocessing includes rescaling the values of the data, shifting their range, and other numerical operations. Luckily, in Keras, the module which contains a model also contains the preprocessing function used when training that model. We can fetch this function and feed it to our generator to ensure all images are preprocessed before they are served to the model.
from keras.applications.vgg19 import preprocess_input
np.random.seed(1234)
generator = ImageDataGenerator(preprocessing_function=preprocess_input)
batches = generator.flow_from_directory('flowers/train',
                                        target_size=(224, 224),
                                        batch_size=1)
Found 1105 images belonging to 17 classes.
Once we have reinitialized the generator correctly, we can rerun our predictions to see if they improve.
for X, y in batches:
    preds = model.predict(X)
    decoded_preds = decode_predictions(preds, top=1)
    fig = plt.figure()
    img = X[0].astype(np.uint8)
    label = labels[np.argmax(y[0])]
    predicted = decoded_preds[0]
    plt.imshow(img)
    fig.suptitle('Truth: {}, Predicted: {}'.format(label, predicted))
    plt.show()
    break
Note that seeing improvements in the predictions is not a given even though the images are now preprocessed correctly. The label we are looking for might not be part of the original dataset the model was trained on, or it might simply be a case of a bad prediction where the model misses. Still, running a sanity check (preferably over more images) is a good habit on the way to the best results.
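To make such a check quantitative over a larger set of images, we can compare predicted and true classes directly. Below is a minimal sketch of a helper (the function name and the toy arrays are our own, not part of the workshop code) that computes top-1 accuracy from predictions and one-hot labels with NumPy:

```python
import numpy as np

def top1_accuracy(preds, onehot_labels):
    # Fraction of samples where the most probable predicted class
    # matches the class marked in the one-hot encoded label
    return float(np.mean(np.argmax(preds, axis=1) ==
                         np.argmax(onehot_labels, axis=1)))

# Toy example: three samples, three classes; two predictions are correct
preds = np.array([[0.8, 0.1, 0.1],
                  [0.2, 0.7, 0.1],
                  [0.4, 0.3, 0.3]])
truth = np.array([[1, 0, 0],
                  [0, 1, 0],
                  [0, 0, 1]])
print(top1_accuracy(preds, truth))  # 2 of 3 correct
```

Accumulating `model.predict(X)` and `y` over a number of batches and feeding the stacked arrays to this helper gives a quick accuracy estimate.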
Once we are happy with the behaviour of our base model, we can start setting up our own custom model for solving the problem we are interested in, in this case classifying flower species. The first step is to be a bit more restrictive with what we use from the pretrained model, only picking out the parts we need. We do this by dropping the top layers used for predictions, and instead performing a pooling operation on the final convolutional layer. Once we have initialized it we can fetch the input and the output of the pretrained model using properties found in Keras' Model class.
pretrained = VGG19(include_top=False, input_shape=(224, 224, 3),
                   weights='imagenet', pooling='max')
inputs = pretrained.input
outputs = pretrained.output
As we do not want the weights in this part of the final model to change during training, we freeze them.
for layer in pretrained.layers:
    layer.trainable = False
We can then create our own custom layers for performing our own task. We will use a hidden fully connected layer with 128 neurons, and a final prediction layer with 17 neurons, one per species in our dataset. Note that the hidden layer takes the output of the pretrained model as its input.
from keras.layers import Dense
hidden = Dense(128, activation='relu')(outputs)
preds = Dense(17, activation='softmax')(hidden)
Once we have all our layers set up we can wrap them in a Model, and compile the model using a fairly standard set of hyperparameters.
from keras.models import Model
from keras.optimizers import Adam
model = Model(inputs, preds)
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(lr=1e-4),
              metrics=['acc'])
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         (None, 224, 224, 3)       0
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080
block3_conv4 (Conv2D)        (None, 56, 56, 256)       590080
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808
block4_conv4 (Conv2D)        (None, 28, 28, 512)       2359808
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv4 (Conv2D)        (None, 14, 14, 512)       2359808
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0
global_max_pooling2d_1 (Glob (None, 512)               0
dense_1 (Dense)              (None, 128)               65664
dense_2 (Dense)              (None, 17)                2193
=================================================================
Total params: 20,092,241
Trainable params: 67,857
Non-trainable params: 20,024,384
_________________________________________________________________
We can then set up generators like we did before, one for the training data and one for validation, and train the model using fit_generator. Note that training here is set to a single epoch, simply for example purposes.
np.random.seed(1234)
# If you run into memory errors, try reducing this
batch_size = 32
train_generator = ImageDataGenerator(
    preprocessing_function=preprocess_input)
train_batches = train_generator.flow_from_directory('flowers/train',
                                                    target_size=(224, 224),
                                                    batch_size=batch_size)
val_generator = ImageDataGenerator(
    preprocessing_function=preprocess_input)
val_batches = val_generator.flow_from_directory('flowers/val',
                                                target_size=(224, 224),
                                                batch_size=batch_size)
# Note that training is set to 1 epoch,
# to avoid unintentionally locking up computers
model.fit_generator(train_batches,
                    epochs=1,
                    validation_data=val_batches,
                    steps_per_epoch=len(train_batches),
                    validation_steps=len(val_batches))
Found 1105 images belonging to 17 classes.
Found 255 images belonging to 17 classes.
Epoch 1/1
35/35 [==============================] - 742s 21s/step - loss: 14.3701 - acc: 0.0739 - val_loss: 14.0535 - val_acc: 0.0667
<keras.callbacks.History at 0x12ac06438>
If you run the training above for a larger number of epochs, you will typically achieve a very decent result on the training data but a considerably worse one on the validation data. This is an example of overfitting: the model starts memorizing specifics of the training set instead of learning the general features we are interested in. We handle this by introducing regularization, forcing the model to generalize. In image recognition this is typically done using dropout layers, which during training randomly set the output of a subset of neurons to 0. We can introduce dropout to our model by inserting a Dropout layer which drops 30% of the neurons between the two final layers of our model.
from keras.layers import Dropout
hidden = Dense(128, activation='relu')(outputs)
dropout = Dropout(.3)(hidden)
preds = Dense(17, activation='softmax')(dropout)
model = Model(inputs, preds)
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(lr=1e-4),
              metrics=['acc'])
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         (None, 224, 224, 3)       0
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080
block3_conv4 (Conv2D)        (None, 56, 56, 256)       590080
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808
block4_conv4 (Conv2D)        (None, 28, 28, 512)       2359808
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808
block5_conv4 (Conv2D)        (None, 14, 14, 512)       2359808
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0
global_max_pooling2d_1 (Glob (None, 512)               0
dense_3 (Dense)              (None, 128)               65664
dropout_1 (Dropout)          (None, 128)               0
dense_4 (Dense)              (None, 17)                2193
=================================================================
Total params: 20,092,241
Trainable params: 67,857
Non-trainable params: 20,024,384
_________________________________________________________________
We can recompile the model and restart training to achieve what should be a better result
np.random.seed(1234)
# If you run into memory errors, try reducing this
batch_size = 32
train_generator = ImageDataGenerator(
    preprocessing_function=preprocess_input)
train_batches = train_generator.flow_from_directory('flowers/train',
                                                    target_size=(224, 224),
                                                    batch_size=batch_size)
val_generator = ImageDataGenerator(
    preprocessing_function=preprocess_input)
val_batches = val_generator.flow_from_directory('flowers/val',
                                                target_size=(224, 224),
                                                batch_size=batch_size)
# Note that training is set to 1 epoch,
# to avoid unintentionally locking up computers
model.fit_generator(train_batches,
                    epochs=1,
                    validation_data=val_batches,
                    steps_per_epoch=len(train_batches),
                    validation_steps=len(val_batches))
Found 1105 images belonging to 17 classes.
Found 255 images belonging to 17 classes.
Epoch 1/1
35/35 [==============================] - 703s 20s/step - loss: 14.7023 - acc: 0.0625 - val_loss: 14.3602 - val_acc: 0.0706
<keras.callbacks.History at 0x129c17630>
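Once the model is trained on our flower dataset, decode_predictions no longer applies, since it only knows the ImageNet labels. But because we already recovered the labels list from the generator, decoding our own predictions takes only a couple of lines of NumPy. A minimal sketch (the helper name and the synthetic prediction vector are our own, for illustration):

```python
import numpy as np

# The label list recovered from the generator's class_indices earlier
labels = ['bluebell', 'buttercup', 'colts_foot', 'cowslip', 'crocus',
          'daffodil', 'daisy', 'dandelion', 'fritillary', 'iris',
          'lily_valley', 'pansy', 'snowdrop', 'sunflower', 'tigerlily',
          'tulip', 'windflower']

def decode_flower_predictions(pred, labels, top=3):
    # Return the `top` most probable (label, probability) pairs
    # for a single softmax output vector
    order = np.argsort(pred)[::-1][:top]
    return [(labels[i], float(pred[i])) for i in order]

# Synthetic softmax output peaking at 'daffodil'
pred = np.full(17, 0.01)
pred[labels.index('daffodil')] = 0.84
print(decode_flower_predictions(pred, labels, top=1))
```

In practice `pred` would be one row of `model.predict(X)` for a batch served by the validation generator.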
A second technique for avoiding overfitting is augmenting the images, whose goal is to take the existing data points in our dataset and create brand new samples. It works by modifying an image in a way that changes it while maintaining the truthfulness of the corresponding label. An example in our case is mirroring the images horizontally, which can be implemented directly in the Keras generator.
np.random.seed(1234)
generator = ImageDataGenerator(horizontal_flip=True)
batches = generator.flow_from_directory('flowers/train',
                                        batch_size=1,
                                        shuffle=False)
Found 1105 images belonging to 17 classes.
With this functionality enabled, the generator randomly decides whether to flip each image every time it is served, theoretically yielding two samples from the single data point we started with. We can see this by visualizing the same image served from multiple batches.
fig, ax = plt.subplots(1, 5, figsize=(15, 10))
for i in range(5):
    batches = generator.flow_from_directory('flowers/train',
                                            batch_size=1,
                                            shuffle=False)
    for X, y in batches:
        ax[i].imshow(X[0].astype(np.uint8))
        ax[i].set_title('Run {}'.format(i + 1))
        ax[i].set_xticks([])
        ax[i].set_yticks([])
        break
plt.show()
Found 1105 images belonging to 17 classes.
Found 1105 images belonging to 17 classes.
Found 1105 images belonging to 17 classes.
Found 1105 images belonging to 17 classes.
Found 1105 images belonging to 17 classes.
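Under the hood, a horizontal flip is nothing more than reversing the width axis of the image array. A toy NumPy sketch of the operation the generator performs (the tiny array here is made up purely for illustration):

```python
import numpy as np

# A toy 2x3 single-channel 'image': rows are height, columns are width
img = np.array([[[1], [2], [3]],
                [[4], [5], [6]]])
flipped = img[:, ::-1, :]  # reverse the width axis = mirror left-right
print(flipped[:, :, 0])    # [[3 2 1]
                           #  [6 5 4]]
```

The label is untouched by this operation, which is exactly why the flip is a safe augmentation for our flowers.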
Retraining the model with a set of augmentations should increase the accuracy even further. From here on, you can experiment with the various parameters of the ImageDataGenerator class and train the model yourself.
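As a starting point for such experiments, here is a sketch of a generator combining several common augmentations; the specific values are assumptions to tune for your own runs, not recommendations:

```python
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.vgg19 import preprocess_input

generator = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,       # random rotations of up to 20 degrees
    width_shift_range=0.1,   # random horizontal shifts of up to 10%
    height_shift_range=0.1,  # random vertical shifts of up to 10%
    zoom_range=0.1,          # random zoom in/out of up to 10%
    horizontal_flip=True)    # random left-right mirroring
```

Stronger augmentations generally call for more epochs, since each epoch now sees a noisier version of the dataset.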
The steps in this guide provide a good starting point for classifying species of flowers, or solving any other generic image classification problem. Retracing the steps while leaving more epochs for the model to train should provide a solid baseline with decent results (my best run achieved ~75%). Continued work on this problem would typically include trying different architectures as core models, experimenting with various designs for the custom problem-specific final layers, and testing a wide range of combinations of regularization and augmentations to combat overfitting. It should be a feasible goal to reach an accuracy in the high 90s, which seems to be how the state-of-the-art models perform. Happy hacking!
Do you want more? You can check out our guide on how to make a variational autoencoder. It's a bit more advanced, and requires you to study a little on your own.