fast.ai is a deep-learning library built on top of PyTorch. Keras is also a deep-learning framework, but it can run on top of several different backends (e.g. TensorFlow, Theano or CNTK).
In this example it runs on top of TensorFlow.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
# Data location and basic hyperparameters.
# PATH ends with "/" because sub-directory names are appended to it directly.
# sz: side length (in pixels) every image is resized to, i.e. sz x sz.
# batch_size: number of images per mini-batch.
PATH, sz, batch_size = "data/dogscats/", 224, 64
# Then we import keras
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
from keras.layers import Dropout, Flatten, Dense
from keras.applications import ResNet50
from keras.models import Model, Sequential
from keras.layers import Dense, GlobalAveragePooling2D
from keras import backend as K
from keras.applications.resnet50 import preprocess_input
/home/paperspace/anaconda3/envs/fastai/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. from ._conv import register_converters as _register_converters Using TensorFlow backend. /home/paperspace/anaconda3/envs/fastai/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6 return f(*args, **kwds)
# Training and validation image folders, both located directly under PATH.
train_data_dir = PATH + 'train'
validation_data_dir = PATH + 'valid'
# Image data generators describe how each image is transformed before it reaches
# the network. Both apply ResNet50's own `preprocess_input` normalization; only the
# training generator adds random augmentation (shear, zoom, horizontal flip).
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    horizontal_flip=True,
    shear_range=0.2,
    zoom_range=0.2,
)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Stream batches of images, resized to (sz, sz), from the class sub-directories.
# class_mode='binary' produces a single 0/1 label per image (two classes);
# 'categorical' would be used for one-hot labels instead.
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    class_mode='binary',
    batch_size=batch_size,
    target_size=(sz, sz),
)

# Training batches are shuffled (the default). Validation is kept in a fixed,
# unshuffled order. Each batch always carries its own correct labels, so this is not
# needed for fitting itself. It presumably exists so that per-image outputs produced
# in generator order (e.g. later predictions) line up with the generator's file order.
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    class_mode='binary',
    batch_size=batch_size,
    target_size=(sz, sz),
    shuffle=False,
)
Found 23000 images belonging to 2 classes. Found 2000 images belonging to 2 classes.
# Transfer-learning model: an ImageNet-pretrained ResNet50 convolutional base
# (include_top=False drops its original ImageNet classifier) plus a new head.
base_model = ResNet50(weights='imagenet', include_top=False)

# New head: global average pooling -> 1024-unit ReLU layer -> one sigmoid unit,
# giving a single probability per image to match class_mode='binary'.
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(1024, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)

# Wire the pretrained base and the new head together into one model.
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze every pretrained layer so that only the new head is trained at first.
# There is no single "freeze" call here, so the flag is set layer by layer.
for layer in base_model.layers:
    layer.trainable = False

# Compile after setting `trainable` so the frozen/unfrozen state takes effect.
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# Train the new classifier head (the ResNet50 base is frozen).
# (This run was stopped manually partway through the first epoch to save time.)
%%time
# Train only the new head (base frozen) for 3 epochs.
# Step counts use ceiling division so every image is used once per epoch: the
# directory iterators yield a final partial batch. With floor division the last
# 2000 % 64 = 16 validation images were never evaluated. Because the validation
# generator is unshuffled, these were always the same images. Training epochs also
# drifted out of step with the iterator's own epochs (359 vs 360 batches).
model.fit_generator(train_generator,
                    (train_generator.n + batch_size - 1) // batch_size,
                    epochs=3, workers=4,
                    validation_data=validation_generator,
                    validation_steps=(validation_generator.n + batch_size - 1) // batch_size)
Epoch 1/3 128/359 [=========>....................] - ETA: 35:29 - loss: 0.2981 - acc: 0.9271
--------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) <timed eval> in <module>() ~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs) 89 warnings.warn('Update your `' + object_name + 90 '` call to the Keras 2 API: ' + signature, stacklevel=2) ---> 91 return func(*args, **kwargs) 92 wrapper._original_function = func 93 return wrapper ~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch) 2175 outs = self.train_on_batch(x, y, 2176 sample_weight=sample_weight, -> 2177 class_weight=class_weight) 2178 2179 if not isinstance(outs, list): ~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight) 1847 ins = x + y + sample_weights 1848 self._make_train_function() -> 1849 outputs = self.train_function(ins) 1850 if len(outputs) == 1: 1851 return outputs[0] ~/anaconda3/envs/fastai/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs) 2473 session = get_session() 2474 updated = session.run(fetches=fetches, feed_dict=feed_dict, -> 2475 **self.session_kwargs) 2476 return updated[:len(self.outputs)] 2477 ~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 887 try: 888 result = self._run(None, fetches, feed_dict, options_ptr, --> 889 run_metadata_ptr) 890 if run_metadata: 891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) ~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 1118 if final_fetches or final_targets or (handle and 
feed_dict_tensor): 1119 results = self._do_run(handle, final_targets, final_fetches, -> 1120 feed_dict_tensor, options, run_metadata) 1121 else: 1122 results = [] ~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata) 1315 if handle is None: 1316 return self._do_call(_run_fn, self._session, feeds, fetches, targets, -> 1317 options, run_metadata) 1318 else: 1319 return self._do_call(_prun_fn, self._session, handle, feeds, fetches) ~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1321 def _do_call(self, fn, *args): 1322 try: -> 1323 return fn(*args) 1324 except errors.OpError as e: 1325 message = compat.as_text(e.message) ~/anaconda3/envs/fastai/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata) 1300 return tf_session.TF_Run(session, options, 1301 feed_dict, fetch_list, target_list, -> 1302 status, run_metadata) 1303 1304 def _prun_fn(session, handle, feed_dict, fetch_list): KeyboardInterrupt:
# Keras offers no built-in differential (per-layer-group) learning rates, so
# fine-tuning is approximated instead. Layers model.layers[:split_at] stay frozen,
# everything from index split_at onward becomes trainable, and the model is then
# recompiled and trained again.
# NOTE(review): 140 is a hand-picked layer index. Also, the fine-tuning run's output
# shows train acc ~0.97 but val_acc ~0.5. A commonly reported cause with ResNet50 in
# older Keras versions is BatchNormalization behaving differently in training vs.
# inference, and a high learning rate on pretrained layers may also contribute.
# Worth investigating before trusting these results.
split_at = 140
for idx, layer in enumerate(model.layers):
    layer.trainable = idx >= split_at

# Recompile so the new trainable flags take effect.
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
%%time
# Fine-tune the unfrozen top layers (and the head) for one epoch.
# Step counts use ceiling division so every image is used once per epoch: the
# directory iterators yield a final partial batch. With floor division the last
# 2000 % 64 = 16 validation images, always the same ones because the validation
# generator is unshuffled, were never evaluated.
model.fit_generator(train_generator,
                    (train_generator.n + batch_size - 1) // batch_size,
                    epochs=1, workers=3,
                    validation_data=validation_generator,
                    validation_steps=(validation_generator.n + batch_size - 1) // batch_size)
Epoch 1/1 359/359 [==============================] - 217s 603ms/step - loss: 0.0762 - acc: 0.9741 - val_loss: 7.9436 - val_acc: 0.5055 CPU times: user 4min 47s, sys: 14.1 s, total: 5min 1s Wall time: 3min 38s
<keras.callbacks.History at 0x7f22d8428f28>