This assumes we have already set up our directories and found the learning rate using the learning rate finder, lr_find().
from fastai.conv_learner import *
PATH = "data/dogscats/"
size = 224
batch_size = 64
# Set up our transforms, data, and learner
# First train only the final layer (the pretrained layers start out frozen)
# Note we do not pass precompute=True. It is just an optimization that reuses activations computed earlier; we can always leave it out.
transforms = tfms_from_model(resnet50, size, aug_tfms=transforms_side_on, max_zoom=1.1)
data = ImageClassifierData.from_paths(PATH, tfms=transforms, bs=batch_size)
learn = ConvLearner.pretrained(resnet50, data)
%time learn.fit(0.01, 3, cycle_len=1)
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /home/paperspace/.torch/models/resnet50-19c8e357.pth 100%|██████████| 102502400/102502400 [00:02<00:00, 48204227.51it/s]
[0.      0.04423 0.02491 0.98877]
[1.      0.0434  0.02632 0.98975]
[2.      0.03641 0.02626 0.98877]
CPU times: user 15min 19s, sys: 2min 43s, total: 18min 2s
Wall time: 8min 17s
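The next cell unfreezes the whole network and passes three learning rates instead of one. As we understand fastai, the model is split into layer groups and each rate in the list applies to one group, from the earliest (most general) layers to the new final layer. Below is a minimal, framework-free sketch of that idea, assuming three hypothetical layer groups with made-up scalar parameters; it is not fastai's optimizer code, just one plain SGD step in which each group uses its own learning rate.

```python
# Conceptual sketch of differential learning rates: one SGD step where each
# layer group has its own learning rate. The groups and numbers are
# hypothetical; this is not fastai's implementation.


def sgd_step(param_groups, grads):
    """Apply p <- p - lr * g to every parameter, using its group's lr."""
    updated = []
    for group, group_grads in zip(param_groups, grads):
        lr = group["lr"]
        new_params = [p - lr * g for p, g in zip(group["params"], group_grads)]
        updated.append({"lr": lr, "params": new_params})
    return updated


# Hypothetical layer groups ordered from most general to most task-specific,
# mirroring the rates [1e-5, 1e-4, 1e-2] used in the notebook.
param_groups = [
    {"lr": 1e-5, "params": [1.0, 1.0]},  # early layers: barely move
    {"lr": 1e-4, "params": [1.0, 1.0]},  # middle layers
    {"lr": 1e-2, "params": [1.0, 1.0]},  # final layer: moves the most
]
grads = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]  # identical gradients everywhere

for group in sgd_step(param_groups, grads):
    print(group["lr"], group["params"])
```

With identical gradients, the only thing that differs between groups is the step size, so the general early layers stay close to their pretrained values while the final layer adapts fastest.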
# Unfreeze the other layers and learn again using differential learning rates
# (Lower rates for more general layers)
learn.unfreeze()
# This will be explained later. General guideline: if using a model deeper than 34 layers (e.g. resnet50) and data similar to ImageNet, use this.
# It stops the batch normalization moving averages from being updated
learn.bn_freeze(True)
# Do our learning again with differential learning rates
%time learn.fit([1e-5,1e-4,1e-2], 1, cycle_len=1)
[0.      0.02143 0.02226 0.98975]
CPU times: user 9min 35s, sys: 1min 41s, total: 11min 16s
Wall time: 7min 55s
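A batch normalization layer keeps moving averages of each feature's mean and variance, which it uses at inference time. In PyTorch these are updated on every training batch as `running = (1 - momentum) * running + momentum * batch_stat`, with momentum 0.1 by default. Our understanding is that `learn.bn_freeze(True)` keeps those averages at the values learned on ImageNet. The class below is a hypothetical NumPy model of only the moving-average part, with a `frozen` flag standing in for bn_freeze; it is a sketch, not fastai's or PyTorch's implementation.

```python
import numpy as np


class RunningBatchNormStats:
    """Hypothetical, minimal model of a batch norm layer's moving averages."""

    def __init__(self, num_features, momentum=0.1):
        self.momentum = momentum                  # PyTorch's default momentum
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.frozen = False                       # stand-in for learn.bn_freeze(True)

    def update(self, batch):
        """Fold one training batch (rows are samples) into the moving averages."""
        if self.frozen:
            return  # keep the statistics learned during pretraining
        m = self.momentum
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0, ddof=1)     # unbiased variance for the running estimate
        self.running_mean = (1 - m) * self.running_mean + m * batch_mean
        self.running_var = (1 - m) * self.running_var + m * batch_var


rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, size=(64, 3))        # data whose mean differs from the stored one

live = RunningBatchNormStats(num_features=3)
live.update(batch)                                # moves 10% of the way towards the batch mean

frozen = RunningBatchNormStats(num_features=3)
frozen.frozen = True
frozen.update(batch)                              # no change: running_mean stays all zeros

print(live.running_mean, frozen.running_mean)
```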
# Use test time augmentation (TTA) to get our predictions
# TTA predicts on each validation image plus several randomly augmented copies of it;
# averaging those predictions usually makes them more accurate
%time log_predictions,y = learn.TTA()
CPU times: user 1min 55s, sys: 20.8 s, total: 2min 16s
Wall time: 1min
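Test time augmentation can be sketched as: run the model on the original image and on a few augmented copies, then average the class probabilities. As far as we can tell, `learn.TTA()` applies the training augmentations (here `transforms_side_on`) to the validation images and returns all five sets of predictions stacked rather than averaging them itself, which is why we average later. The `toy_predict` model and the `brighten` augmentation below are hypothetical stand-ins used only to illustrate the averaging step.

```python
import numpy as np


def tta_predict(predict, image, augmentations):
    """Average class probabilities over an image and its augmented copies.

    `predict` maps an image to a probability vector and `augmentations` is a
    list of image -> image functions; both are hypothetical stand-ins.
    """
    versions = [image] + [augment(image) for augment in augmentations]
    probs = np.stack([predict(v) for v in versions])  # shape (1 + n_aug, n_classes)
    return probs.mean(axis=0)                          # shape (n_classes,)


def toy_predict(image):
    """Hypothetical 2-class 'model': brighter top rows push towards class 0."""
    top = image[: image.shape[0] // 2].mean()
    return np.array([top, 1.0 - top])


def brighten(image):
    """Hypothetical augmentation: increase brightness by 20%, clipped to [0, 1]."""
    return np.clip(image * 1.2, 0.0, 1.0)


image = np.linspace(0.0, 1.0, 16).reshape(4, 4)     # made-up 4x4 grayscale image
augmentations = [np.fliplr, brighten]                # horizontal flip and brightness change
print(tta_predict(toy_predict, image, augmentations))
```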
# y holds the labels of the 2000 validation images
y
array([0, 0, 0, ..., 1, 1, 1])
# Shape is (5, 2000, 2): 5 predictions per image (the original plus 4 augmented copies), 2000 images, 2 classes
log_predictions.shape
(5, 2000, 2)
np.exp(log_predictions)
array([[[0.99999, 0.00001],
        [0.9999 , 0.0001 ],
        [1.     , 0.     ],
        ...,
        [0.00016, 0.99984],
        [0.0002 , 0.9998 ],
        [0.00008, 0.99992]],
       [[0.99994, 0.00006],
        [0.99993, 0.00007],
        [0.99999, 0.00001],
        ...,
        [0.00003, 0.99997],
        [0.0001 , 0.9999 ],
        [0.00012, 0.99988]],
       [[0.99999, 0.00001],
        [0.99994, 0.00006],
        [0.99998, 0.00002],
        ...,
        [0.00025, 0.99975],
        [0.00003, 0.99997],
        [0.00005, 0.99995]],
       [[0.99999, 0.00001],
        [0.99962, 0.00038],
        [0.99999, 0.00001],
        ...,
        [0.00015, 0.99985],
        [0.00017, 0.99983],
        [0.00003, 0.99997]],
       [[0.99995, 0.00005],
        [0.99834, 0.00166],
        [1.     , 0.     ],
        ...,
        [0.00004, 0.99996],
        [0.0002 , 0.9998 ],
        [0.00007, 0.99993]]], dtype=float32)
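The values in `log_predictions` are log-probabilities: the classifier appears to end in a log-softmax layer, so `np.exp` is needed to recover probabilities that sum to 1 for each image. A small NumPy sketch of that relationship, using made-up raw scores for two images:

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


logits = np.array([[4.0, -3.0], [-2.0, 5.0]])  # hypothetical raw scores: 2 images, 2 classes
log_probs = log_softmax(logits)                # every entry is <= 0
probs = np.exp(log_probs)                      # back to ordinary probabilities
print(probs.sum(axis=-1))                      # each row sums to 1
```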
# Show the source code of TTA
??learn.TTA
# Finally, get our log loss and accuracy
# np.exp converts the log-probabilities back to probabilities; np.mean(..., 0) averages over the 5 TTA predictions
probabilities = np.mean(np.exp(log_predictions), 0)
metrics.log_loss(y, probabilities)
0.016303548492258166
accuracy(np.mean((log_predictions),0), y)
0.994
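For reference, both numbers can be reproduced by hand. The sketch below averages the exponentiated TTA predictions over axis 0, computes the log loss as the mean negative log-probability assigned to the true class (clipping probabilities away from 0 and 1, roughly as `sklearn.metrics.log_loss` does), and takes accuracy from the argmax. One difference from the cell above: here accuracy uses the averaged probabilities, whereas the notebook averaged the log-probabilities before calling `accuracy`. The toy arrays are invented for illustration.

```python
import numpy as np


def tta_log_loss_and_accuracy(log_preds, y, eps=1e-15):
    """log_preds: (n_versions, n_images, n_classes) log-probabilities; y: (n_images,) int labels."""
    probs = np.exp(log_preds).mean(axis=0)             # average probabilities over TTA versions
    probs = np.clip(probs, eps, 1 - eps)               # avoid log(0)
    probs = probs / probs.sum(axis=1, keepdims=True)   # renormalize after clipping
    log_loss = -np.mean(np.log(probs[np.arange(len(y)), y]))
    accuracy = np.mean(probs.argmax(axis=1) == y)
    return log_loss, accuracy


# Made-up example: 2 TTA versions, 3 images, 2 classes
toy_log_preds = np.log(np.array([
    [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]],
    [[0.7, 0.3], [0.4, 0.6], [0.8, 0.2]],
]))
toy_y = np.array([0, 1, 1])
print(tta_log_loss_and_accuracy(toy_log_preds, toy_y))
```

In the toy data the averaged probabilities are [0.8, 0.2], [0.3, 0.7] and [0.7, 0.3], so the third image is misclassified and accuracy is 2/3.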