Lesson 3: 'Dogs vs Cats' using Keras

fast.ai is a deep learning library built on top of PyTorch. Keras is also a deep learning library, but rather than being tied to one framework it can run on top of several different backends (TensorFlow, Theano, or CNTK).

In this notebook Keras runs on the TensorFlow backend (note the "Using TensorFlow backend." message printed after the imports).

In [1]:
# Notebook setup: reload edited Python modules automatically before each cell runs
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output
%matplotlib inline
In [2]:
# First we set up our data paths and sizes, like normal.
# Every later cell reads these, so they must be defined here for Restart & Run All to work.
PATH = "data/dogscats/"

# Side length (pixels) images are resized to: 224x224 is ResNet50's standard input size.
sz = 224

# Images per batch. With 23000 training images, 23000 // 64 = 359 steps per epoch,
# which matches the step counts in the training logs further down.
batch_size = 64
In [3]:
# Then we import keras
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
from keras.layers import Dropout, Flatten, Dense
from keras.applications import ResNet50
from keras.models import Model, Sequential
from keras.layers import Dense, GlobalAveragePooling2D
from keras import backend as K
from keras.applications.resnet50 import preprocess_input
/home/paperspace/anaconda3/envs/fastai/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
/home/paperspace/anaconda3/envs/fastai/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
  return f(*args, **kwds)
In [4]:
# Each data split lives in its own folder under PATH (PATH already ends with '/').
train_data_dir = PATH + 'train'
validation_data_dir = PATH + 'valid'
In [5]:
# Following that, we have to set up "data generators" for both our training and test (validation) data.
# These specify the augmentation and preprocessing applied to every image.

# Training images get random augmentation (shear, zoom, horizontal flip) on top of
# ResNet50's own `preprocess_input` normalization.
train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input,
    shear_range=0.2, zoom_range=0.2, horizontal_flip=True)

# Validation images get the same normalization but no augmentation.
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# We can now declare our data generators.
# class_mode='binary' means we are going to classify between 2 different classes.
# Otherwise we could use categorical.
# For training we shuffle the input to make the batches as random as possible.
train_generator = train_datagen.flow_from_directory(train_data_dir,
    target_size=(sz, sz),
    batch_size=batch_size, class_mode='binary', shuffle=True)

# For validation we keep the images in directory order so predictions line up with
# the labels. flow_from_directory shuffles by default, so shuffle=False must be explicit.
validation_generator = test_datagen.flow_from_directory(validation_data_dir,
    target_size=(sz, sz),
    batch_size=batch_size, class_mode='binary', shuffle=False)
Found 23000 images belonging to 2 classes.
Found 2000 images belonging to 2 classes.
In [6]:
# Start from ResNet50 with ImageNet weights; include_top=False drops its original
# classification head so we can attach our own.
base_model = ResNet50(include_top=False, weights='imagenet')

# New head: average each feature map spatially, one hidden ReLU layer, then a
# single sigmoid unit giving the probability for the binary dog/cat label.
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(1024, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)
In [7]:
# Wire the pretrained trunk and the new head into one end-to-end model.
model = Model(inputs=base_model.input, outputs=predictions)

# For the first round only the new head should learn, so freeze every
# pretrained ResNet50 layer - Keras needs the flag set on each layer individually.
for layer in base_model.layers:
    layer.trainable = False

# Compile after setting the trainable flags so they take effect.
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
In [8]:
# Train the new head (pretrained layers frozen) for up to 3 epochs.
# This run was stopped by hand (KeyboardInterrupt) partway through epoch 1 to save time.
model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.n // batch_size,
    epochs=3,
    workers=4,
    validation_data=validation_generator,
    validation_steps=validation_generator.n // batch_size,
)
Epoch 1/3
128/359 [=========>....................] - ETA: 35:29 - loss: 0.2981 - acc: 0.9271
KeyboardInterrupt                   Traceback (most recent call last)
<timed eval> in <module>()

~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2175                     outs = self.train_on_batch(x, y,
   2176                                                sample_weight=sample_weight,
-> 2177                                                class_weight=class_weight)
   2179                     if not isinstance(outs, list):

~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1847             ins = x + y + sample_weights
   1848         self._make_train_function()
-> 1849         outputs = self.train_function(ins)
   1850         if len(outputs) == 1:
   1851             return outputs[0]

~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2473         session = get_session()
   2474         updated = session.run(fetches=fetches, feed_dict=feed_dict,
-> 2475                               **self.session_kwargs)
   2476         return updated[:len(self.outputs)]

~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1119       results = self._do_run(handle, final_targets, final_fetches,
-> 1120                              feed_dict_tensor, options, run_metadata)
   1121     else:
   1122       results = []

~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1315     if handle is None:
   1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317                            options, run_metadata)
   1318     else:
   1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1321   def _do_call(self, fn, *args):
   1322     try:
-> 1323       return fn(*args)
   1324     except errors.OpError as e:
   1325       message = compat.as_text(e.message)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1300           return tf_session.TF_Run(session, options,
   1301                                    feed_dict, fetch_list, target_list,
-> 1302                                    status, run_metadata)
   1304     def _prun_fn(session, handle, feed_dict, fetch_list):

In [9]:
# Keras has no support for differential learning rates, so instead we fine-tune
# in two groups: the first `split_at` layers of the model stay frozen and every
# layer from index `split_at` onward is unfrozen. Then we train again.
split_at = 140
for idx, layer in enumerate(model.layers):
    layer.trainable = idx >= split_at

# Recompile so the new trainable flags take effect.
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# NOTE(review): the fine-tuning run that follows reports train acc 0.974 but
# val_acc 0.5055 / val_loss 7.94, i.e. chance level for 2 classes. Investigate before
# trusting these weights. Possible causes to check: rmsprop's default learning rate
# being too high for fine-tuning, or BatchNormalization behaving differently at
# training vs. inference time.
In [10]:
# One epoch of fine-tuning with the layers from index 140 onward trainable.
model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.n // batch_size,
    epochs=1,
    workers=3,
    validation_data=validation_generator,
    validation_steps=validation_generator.n // batch_size,
)
Epoch 1/1
359/359 [==============================] - 217s 603ms/step - loss: 0.0762 - acc: 0.9741 - val_loss: 7.9436 - val_acc: 0.5055
CPU times: user 4min 47s, sys: 14.1 s, total: 5min 1s
Wall time: 3min 38s
<keras.callbacks.History at 0x7f22d8428f28>