Fine-Tuning VGGNet with Keras

This notebook demonstrates how to build a single pipeline that resizes, preprocesses, and augments images in real time during training with Keras. This eliminates the need to preprocess images and write them back to disk before fine-tuning. It works by hooking keras.preprocessing.image.ImageDataGenerator up to keras.applications.vgg16.preprocess_input() through the former's (undocumented) preprocessing_function argument, using a small wrapper function.

This notebook modifies this tutorial (full code here).

This approach generalizes to all Keras pretrained models. For example, for keras.applications.resnet50.ResNet50, you would simply swap keras.applications.vgg16.preprocess_input() for keras.applications.resnet50.preprocess_input().
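
For instance, a ResNet50 version of the wrapper defined later in this notebook would look like this (a sketch; only the imported preprocess_input changes):

import numpy as np
from keras.applications.resnet50 import preprocess_input as preprocess_input_rn50

def preprocess_input_resnet50(x):
    # ImageDataGenerator supplies one 3D image at a time, while
    # preprocess_input() expects a 4D batch, so add and strip a batch axis.
    X = np.expand_dims(x, axis=0)
    X = preprocess_input_rn50(X)
    return X[0]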

Prerequisites

This code assumes you have downloaded the Kaggle Cats vs. Dogs dataset and arranged it into the following directory structure (one way to build this layout is sketched after the tree).

data/
    train/
        dogs/
            dog001.jpg
            dog002.jpg
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
    validation/
        dogs/
            dog001.jpg
            dog002.jpg
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
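
If your raw Kaggle download is a single flat folder of files named like cat.0.jpg and dog.0.jpg, a minimal sketch for building this layout might look like the following (the source path, file-name pattern, and split sizes are assumptions; 1,000 training and 150 validation images per class match the counts the generators report later in this notebook):

import os
import shutil

src = 'kaggle/train'  # hypothetical path to the unzipped Kaggle images
for split, lo, hi in [('train', 0, 1000), ('validation', 1000, 1150)]:
    for cls in ['cats', 'dogs']:
        dst = os.path.join('data', split, cls)
        if not os.path.exists(dst):
            os.makedirs(dst)
        for i in range(lo, hi):
            fname = '%s.%d.jpg' % (cls[:-1], i)  # e.g. cat.0.jpg, dog.1000.jpg
            shutil.copy(os.path.join(src, fname), os.path.join(dst, fname))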

Load VGGNet

In [1]:
from keras.applications.vgg16 import VGG16
from keras.models import Model
from keras.layers import Dense

# Load VGG16 with its ImageNet weights, including the fully connected top.
vgg16 = VGG16(weights='imagenet')

# Replace the 1000-way ImageNet classifier with a single sigmoid unit on top
# of fc2, turning the network into a binary cat-vs.-dog classifier.
# (This is the Keras 1 API; in Keras 2 the arguments are Dense(1, ...) and
# Model(inputs=..., outputs=...).)
fc2 = vgg16.get_layer('fc2').output
prediction = Dense(output_dim=1, activation='sigmoid', name='logit')(fc2)
model = Model(input=vgg16.input, output=prediction)
model.summary()
Using Theano backend.
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 3, 224, 224)   0                                            
____________________________________________________________________________________________________
block1_conv1 (Convolution2D)     (None, 64, 224, 224)  1792        input_1[0][0]                    
____________________________________________________________________________________________________
block1_conv2 (Convolution2D)     (None, 64, 224, 224)  36928       block1_conv1[0][0]               
____________________________________________________________________________________________________
block1_pool (MaxPooling2D)       (None, 64, 112, 112)  0           block1_conv2[0][0]               
____________________________________________________________________________________________________
block2_conv1 (Convolution2D)     (None, 128, 112, 112) 73856       block1_pool[0][0]                
____________________________________________________________________________________________________
block2_conv2 (Convolution2D)     (None, 128, 112, 112) 147584      block2_conv1[0][0]               
____________________________________________________________________________________________________
block2_pool (MaxPooling2D)       (None, 128, 56, 56)   0           block2_conv2[0][0]               
____________________________________________________________________________________________________
block3_conv1 (Convolution2D)     (None, 256, 56, 56)   295168      block2_pool[0][0]                
____________________________________________________________________________________________________
block3_conv2 (Convolution2D)     (None, 256, 56, 56)   590080      block3_conv1[0][0]               
____________________________________________________________________________________________________
block3_conv3 (Convolution2D)     (None, 256, 56, 56)   590080      block3_conv2[0][0]               
____________________________________________________________________________________________________
block3_pool (MaxPooling2D)       (None, 256, 28, 28)   0           block3_conv3[0][0]               
____________________________________________________________________________________________________
block4_conv1 (Convolution2D)     (None, 512, 28, 28)   1180160     block3_pool[0][0]                
____________________________________________________________________________________________________
block4_conv2 (Convolution2D)     (None, 512, 28, 28)   2359808     block4_conv1[0][0]               
____________________________________________________________________________________________________
block4_conv3 (Convolution2D)     (None, 512, 28, 28)   2359808     block4_conv2[0][0]               
____________________________________________________________________________________________________
block4_pool (MaxPooling2D)       (None, 512, 14, 14)   0           block4_conv3[0][0]               
____________________________________________________________________________________________________
block5_conv1 (Convolution2D)     (None, 512, 14, 14)   2359808     block4_pool[0][0]                
____________________________________________________________________________________________________
block5_conv2 (Convolution2D)     (None, 512, 14, 14)   2359808     block5_conv1[0][0]               
____________________________________________________________________________________________________
block5_conv3 (Convolution2D)     (None, 512, 14, 14)   2359808     block5_conv2[0][0]               
____________________________________________________________________________________________________
block5_pool (MaxPooling2D)       (None, 512, 7, 7)     0           block5_conv3[0][0]               
____________________________________________________________________________________________________
flatten (Flatten)                (None, 25088)         0           block5_pool[0][0]                
____________________________________________________________________________________________________
fc1 (Dense)                      (None, 4096)          102764544   flatten[0][0]                    
____________________________________________________________________________________________________
fc2 (Dense)                      (None, 4096)          16781312    fc1[0][0]                        
____________________________________________________________________________________________________
logit (Dense)                    (None, 1)             4097        fc2[0][0]                        
====================================================================================================
Total params: 134,264,641
Trainable params: 134,264,641
Non-trainable params: 0
____________________________________________________________________________________________________

Freeze All Layers Except the Fully Connected Top Layers for Fine-Tuning

In [2]:
import pandas as pd

# Freeze everything except the fully connected top, so that only fc1, fc2,
# and the new logit layer are updated during fine-tuning.
for layer in model.layers:
    if layer.name in ['fc1', 'fc2', 'logit']:
        continue
    layer.trainable = False

# Tabulate each layer's trainable flag as a quick visual check.
df = pd.DataFrame([[layer.name, layer.trainable] for layer in model.layers],
                  columns=['layer', 'trainable'])
df.style.applymap(lambda trainable: f'background-color: {"yellow" if trainable else "red"}',
                  subset=['trainable'])
Out[2]:
    layer         trainable
0   input_1       False
1   block1_conv1  False
2   block1_conv2  False
3   block1_pool   False
4   block2_conv1  False
5   block2_conv2  False
6   block2_pool   False
7   block3_conv1  False
8   block3_conv2  False
9   block3_conv3  False
10  block3_pool   False
11  block4_conv1  False
12  block4_conv2  False
13  block4_conv3  False
14  block4_pool   False
15  block5_conv1  False
16  block5_conv2  False
17  block5_conv3  False
18  block5_pool   False
19  flatten       False
20  fc1           True
21  fc2           True
22  logit         True
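
As an extra sanity check (a sketch, not part of the original notebook), the trainable parameter count should now be exactly the sum of the fc1, fc2, and logit parameter counts from the summary above:

n_trainable = sum(layer.count_params() for layer in model.layers if layer.trainable)
print(n_trainable)  # 102764544 + 16781312 + 4097 = 119549953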

Compile with SGD Optimizer and a Small Learning Rate

In [3]:
from keras.optimizers import SGD

# A small learning rate keeps large early updates from destroying the
# pretrained weights; binary cross-entropy matches the single sigmoid output.
sgd = SGD(lr=1e-4, momentum=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])

Set Up Data Generators

Note the use of the preprocessing_function argument in keras.preprocessing.image.ImageDataGenerator.

In [4]:
import numpy as np
from keras.preprocessing.image import ImageDataGenerator, array_to_img

def preprocess_input_vgg(x):
    """Wrapper around keras.applications.vgg16.preprocess_input()
    to make it compatible with keras.preprocessing.image.ImageDataGenerator's
    `preprocessing_function` argument.

    We cannot pass keras.applications.vgg16.preprocess_input()
    directly to keras.preprocessing.image.ImageDataGenerator's
    `preprocessing_function` argument because the former expects a
    4D tensor whereas the latter supplies a 3D tensor. Hence this
    wrapper.

    Parameters
    ----------
    x : a numpy 3darray (a single image to be preprocessed)

    Returns
    -------
    A numpy 3darray (the preprocessed image).
    """
    from keras.applications.vgg16 import preprocess_input
    X = np.expand_dims(x, axis=0)  # add a batch axis: 3D -> 4D
    X = preprocess_input(X)        # ImageNet mean subtraction (and RGB -> BGR)
    return X[0]                    # strip the batch axis: 4D -> 3D

train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input_vgg,
                                   rotation_range=40,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True,
                                   fill_mode='nearest')
train_generator = train_datagen.flow_from_directory(directory='data/train',
                                                    target_size=[224, 224],
                                                    batch_size=16,
                                                    class_mode='binary')

# No augmentation for validation; only the VGG preprocessing.
validation_datagen = ImageDataGenerator(preprocessing_function=preprocess_input_vgg)
validation_generator = validation_datagen.flow_from_directory(directory='data/validation',
                                                              target_size=[224, 224],
                                                              batch_size=16,
                                                              class_mode='binary')
Found 2000 images belonging to 2 classes.
Found 300 images belonging to 2 classes.
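
As a quick sanity check on the pipeline (a sketch, not in the original notebook), you can pull a single batch from the training generator and inspect it:

from IPython.display import display
X_batch, y_batch = next(train_generator)
print(X_batch.shape)               # (16, 3, 224, 224) under the Theano channels-first layout
display(array_to_img(X_batch[0]))  # augmented and mean-subtracted, so colors look off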

Do Fine-Tuning

Since fine-tuning VGGNet is slow, we perform only a single, abbreviated epoch of training: with samples_per_epoch=16, the "epoch" consists of just one 16-image batch. For a real run you would set samples_per_epoch to the training set size (2,000 here) and train for several epochs.

In [5]:
model.fit_generator(train_generator,
                    samples_per_epoch=16,  # one 16-image batch per epoch (demo only)
                    nb_epoch=1,
                    validation_data=validation_generator,
                    nb_val_samples=32);    # validate on 32 images
Epoch 1/1
16/16 [==============================] - 23s - loss: 1.5019 - acc: 0.4375 - val_loss: 0.9128 - val_acc: 0.6875
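
After a longer training run you would typically persist the fine-tuned weights; a one-line sketch (the filename is an arbitrary choice):

model.save_weights('vgg16_catsdogs_finetuned.h5')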

Show Some Predictions

Note that the displayed images look a little odd because the VGG preprocessing has already subtracted the ImageNet mean channel values from them (see the sketch after this cell for how to undo this for display).

In [8]:
from IPython.display import display
import matplotlib.pyplot as plt

# Grab one preprocessed validation batch and predict on it.
X_val_sample, _ = next(validation_generator)
y_pred = model.predict(X_val_sample)

# For each of the first few images, plot the predicted class probabilities
# (the sigmoid output y is P(dog), so 1 - y is P(cat)) alongside the image.
nb_sample = 4
for x, y in zip(X_val_sample[:nb_sample], y_pred.flatten()[:nb_sample]):
    s = pd.Series({'Cat': 1 - y, 'Dog': y})
    axes = s.plot(kind='bar')
    axes.set_xlabel('Class')
    axes.set_ylabel('Probability')
    axes.set_ylim([0, 1])
    plt.show()

    img = array_to_img(x)
    display(img)
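
If you would rather see the images with natural colors, you can approximately invert the VGG preprocessing. Below is a minimal sketch, assuming the channels-first (Theano) layout used throughout this notebook; the constants are the ImageNet BGR channel means that preprocess_input() subtracts:

import numpy as np
import matplotlib.pyplot as plt

def deprocess_vgg(x):
    # Invert keras.applications.vgg16.preprocess_input() for display:
    # add the ImageNet channel means back (BGR order), then flip BGR -> RGB.
    x = x.copy()
    x[0] += 103.939
    x[1] += 116.779
    x[2] += 123.68
    x = x[::-1]
    return np.clip(x, 0, 255).astype('uint8')

plt.imshow(deprocess_vgg(X_val_sample[0]).transpose(1, 2, 0))  # to (height, width, channels)
plt.axis('off')
plt.show()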