In this tutorial, you will learn how to train a neural network using transfer learning with the skorch API. Transfer learning uses a pretrained model to initialize a network. This tutorial converts the pure PyTorch approach described in PyTorch's Transfer Learning Tutorial to skorch.
We will be using torchvision for this tutorial. Instructions on how to install torchvision for your platform can be found at https://pytorch.org.
Note: If you are running this in a colab notebook, we recommend you enable a free GPU by going:
Runtime → Change runtime type → Hardware Accelerator: GPU
If you are running in colab, you should install the dependencies and download the dataset by running the following cell:
import subprocess

# Installation on Google Colab
try:
    import google.colab
    subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'torchvision'])
    subprocess.run(['mkdir', '-p', 'datasets'])
    subprocess.run(['wget', '-nc', '--no-check-certificate',
                    'https://download.pytorch.org/tutorial/hymenoptera_data.zip',
                    '-P', 'datasets'])
    subprocess.run(['unzip', '-u', 'datasets/hymenoptera_data.zip',
                    '-d', 'datasets'])
except ImportError:
    pass
import os
from urllib import request
from zipfile import ZipFile
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms
from skorch import NeuralNetClassifier
from skorch.helper import predefined_split
torch.manual_seed(360);
Before we begin, let's download the data needed for this tutorial:
def download_and_extract_data(dataset_dir='datasets'):
    data_zip = os.path.join(dataset_dir, 'hymenoptera_data.zip')
    data_path = os.path.join(dataset_dir, 'hymenoptera_data')
    url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"

    if not os.path.exists(data_path):
        if not os.path.exists(data_zip):
            print("Starting to download data...")
            data = request.urlopen(url, timeout=15).read()
            with open(data_zip, 'wb') as f:
                f.write(data)

        print("Starting to extract data...")
        with ZipFile(data_zip, 'r') as zip_f:
            zip_f.extractall(dataset_dir)

    print("Data has been downloaded and extracted to {}.".format(dataset_dir))
download_and_extract_data()
Data has been downloaded and extracted to datasets.
We are going to train a neural network to classify ants and bees. The dataset consists of about 120 training images and 75 validation images for each class. First, we create the training and validation datasets:
data_dir = 'datasets/hymenoptera_data'

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'train'), train_transforms)
val_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'val'), val_transforms)
The train dataset includes data augmentation techniques such as cropping to size 224 and horizontal flips. The train and validation datasets are normalized with mean [0.485, 0.456, 0.406] and standard deviation [0.229, 0.224, 0.225]. These values are the channel-wise means and standard deviations of the ImageNet images. We use these values because the pretrained model was trained on ImageNet.
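As a quick illustration of what transforms.Normalize does, each channel of a ToTensor-scaled pixel is shifted and scaled independently. A minimal pure-Python sketch of that per-channel arithmetic:

```python
# Sketch of the per-channel arithmetic behind transforms.Normalize:
# output = (input - mean) / std, applied channel by channel.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize_pixel(pixel):
    """Normalize one RGB pixel (three floats in [0, 1], as produced by ToTensor)."""
    return [(p - m) / s for p, m, s in zip(pixel, mean, std)]

# A pixel exactly at the ImageNet channel means maps to zero in every channel:
print(normalize_pixel([0.485, 0.456, 0.406]))  # [0.0, 0.0, 0.0]
```

The real transform does the same thing over whole tensors; this sketch only shows the formula.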
We use a pretrained ResNet18 neural network model with its final layer replaced with a fully connected layer:
class PretrainedModel(nn.Module):
    def __init__(self, output_features):
        super().__init__()
        model = models.resnet18(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, output_features)
        self.model = model

    def forward(self, x):
        return self.model(x)
Since we are training a binary classifier, the output of the final fully connected layer has size 2.
In this section, we will create a skorch.NeuralNetClassifier to solve our classification problem.
First, we create an LRScheduler callback, a learning rate scheduler that uses torch.optim.lr_scheduler.StepLR to scale the learning rate by gamma=0.1 every 7 epochs:
from skorch.callbacks import LRScheduler
lrscheduler = LRScheduler(
    policy='StepLR', step_size=7, gamma=0.1)
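To see what this schedule does, here is a pure-Python sketch of the StepLR rule (the real scheduler lives in torch.optim.lr_scheduler): the learning rate is multiplied by gamma once every step_size epochs.

```python
# Sketch of the StepLR decay rule:
# lr(epoch) = base_lr * gamma ** (epoch // step_size)
def steplr(base_lr, epoch, step_size=7, gamma=0.1):
    return base_lr * gamma ** (epoch // step_size)

# With base_lr=0.001: epochs 0-6 train at 1e-3, epochs 7-13 at 1e-4, and so on.
print([steplr(0.001, e) for e in (0, 6, 7, 14)])
```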
Next, we create a Checkpoint callback which saves the best model by monitoring the validation accuracy:
from skorch.callbacks import Checkpoint
checkpoint = Checkpoint(
    f_params='best_model.pt', monitor='valid_acc_best')
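The 'valid_acc_best' monitor is True on every epoch where the validation accuracy reaches a new best, so the checkpoint file is only overwritten when the model improves. A small sketch of that logic (skorch computes this flag for you):

```python
# Sketch of the 'new best' flag that Checkpoint monitors: True whenever
# the tracked metric improves on everything seen so far.
def best_flags(valid_accs):
    best = float('-inf')
    flags = []
    for acc in valid_accs:
        improved = acc > best
        flags.append(improved)
        if improved:
            best = acc
    return flags

# Only the improving epochs would trigger a checkpoint save:
print(best_flags([0.9477, 0.9412, 0.9608, 0.9346]))  # [True, False, True, False]
```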
Lastly, we create a Freezer to freeze all weights besides those of the final layer, model.fc:
from skorch.callbacks import Freezer
freezer = Freezer(lambda x: not x.startswith('model.fc'))
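The predicate passed to Freezer receives each parameter's name. A sketch with a few hypothetical ResNet18 parameter names (as they appear once wrapped under self.model) shows which ones get frozen:

```python
# The Freezer freezes every parameter whose name matches the predicate;
# here only names outside 'model.fc' match, so only the final layer keeps
# training. (The parameter names below are illustrative.)
should_freeze = lambda name: not name.startswith('model.fc')

names = ['model.conv1.weight', 'model.layer4.1.conv2.weight',
         'model.fc.weight', 'model.fc.bias']
frozen = [n for n in names if should_freeze(n)]
trainable = [n for n in names if not should_freeze(n)]
print(frozen)     # ['model.conv1.weight', 'model.layer4.1.conv2.weight']
print(trainable)  # ['model.fc.weight', 'model.fc.bias']
```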
With all the preparations out of the way, we can now define our NeuralNetClassifier:
net = NeuralNetClassifier(
    PretrainedModel,
    criterion=nn.CrossEntropyLoss,
    lr=0.001,
    batch_size=4,
    max_epochs=25,
    module__output_features=2,
    optimizer=optim.SGD,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=2,
    iterator_valid__num_workers=2,
    train_split=predefined_split(val_ds),
    callbacks=[lrscheduler, checkpoint, freezer],
    device='cuda'  # comment this line out to train on CPU
)
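Several of these arguments use skorch's double-underscore convention to route values to sub-components (module__..., optimizer__..., iterator_train__...). A minimal sketch of the routing idea, assuming only a split at the first '__' (skorch's real implementation handles more cases):

```python
# Sketch: a key like 'module__output_features' is split at the first '__'
# into a target component ('module') and the keyword argument passed to it.
def split_param(key):
    target, _, name = key.partition('__')
    return target, name

print(split_param('module__output_features'))  # ('module', 'output_features')
print(split_param('optimizer__momentum'))      # ('optimizer', 'momentum')
print(split_param('iterator_train__shuffle'))  # ('iterator_train', 'shuffle')
```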
That is quite a few parameters! Let's walk through each one:
- PretrainedModel: our ResNet18 neural network module
- criterion=nn.CrossEntropyLoss: the loss function
- lr: the initial learning rate
- batch_size: the size of a batch
- max_epochs: the number of epochs to train
- module__output_features: used by __init__ in our PretrainedModel class to set the number of classes
- optimizer: our optimizer
- optimizer__momentum: the initial momentum
- iterator_{train,valid}__{shuffle,num_workers}: parameters that are passed to the data loaders
- train_split: a wrapper around val_ds to use our validation dataset
- callbacks: our callbacks
- device: set to 'cuda' to train on GPU

Now we are ready to train our neural network:
net.fit(train_ds, y=None);
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
  epoch    train_loss    valid_acc    valid_loss    cp      lr     dur
-------  ------------  -----------  ------------  ----  ------  ------
      1        0.6488       0.9477        0.1860     +  0.0010  9.3038
      2        0.4275       0.9412        0.1697        0.0010  3.0520
      3        0.4977       0.9346        0.1728        0.0010  3.1005
      4        0.5072       0.9346        0.1766        0.0010  3.1506
      5        0.5104       0.9608        0.1548     +  0.0010  3.4832
      6        0.3861       0.9216        0.1879        0.0010  3.2256
      7        0.4329       0.9346        0.1839        0.0010  3.0548
      8        0.3634       0.9477        0.1604        0.0001  3.3032
      9        0.3625       0.9477        0.1606        0.0001  3.0581
     10        0.3444       0.9412        0.1796        0.0001  3.0689
     11        0.3334       0.9346        0.1904        0.0001  3.1114
     12        0.3719       0.9477        0.1637        0.0001  3.0940
     13        0.4330       0.9412        0.1616        0.0001  3.0920
     14        0.2887       0.9477        0.1632        0.0001  3.0811
     15        0.2981       0.9477        0.1682        0.0000  3.1915
     16        0.3129       0.9477        0.1665        0.0000  3.0351
     17        0.3422       0.9412        0.1983        0.0000  3.0565
     18        0.4063       0.9412        0.1629        0.0000  4.8236
     19        0.3207       0.9412        0.1796        0.0000  3.6040
     20        0.3319       0.9477        0.1549        0.0000  3.0968
     21        0.2872       0.9412        0.1658        0.0000  3.1563
     22        0.3190       0.9346        0.1716        0.0000  3.1373
     23        0.3513       0.9477        0.1720        0.0000  3.1291
     24        0.3579       0.9608        0.1587        0.0000  3.1853
     25        0.3300       0.9542        0.1639        0.0000  3.0743
The best model is stored at best_model.pt, with a validation accuracy of roughly 0.96.
Congratulations! You now know how to finetune a neural network using skorch. Feel free to explore the other tutorials to learn more about using skorch.