In this tutorial, you will learn how to train a neural network using transfer learning with the skorch API. Transfer learning uses a pretrained model to initialize a network. This tutorial converts the pure PyTorch approach described in PyTorch's Transfer Learning Tutorial to skorch.
We will be using torchvision for this tutorial. Instructions on how to install torchvision for your platform can be found at https://pytorch.org.
import os
from urllib import request
from zipfile import ZipFile
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, Checkpoint
from skorch.helper import filtered_optimizer
from skorch.helper import filter_requires_grad
from skorch.helper import predefined_split
torch.manual_seed(360);
Before we begin, let's download the data needed for this tutorial:
def download_and_extract_data(dataset_dir='datasets'):
    data_zip = os.path.join(dataset_dir, 'hymenoptera_data.zip')
    data_path = os.path.join(dataset_dir, 'hymenoptera_data')
    url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"

    if not os.path.exists(data_path):
        if not os.path.exists(data_zip):
            os.makedirs(dataset_dir, exist_ok=True)
            print("Starting to download data...")
            data = request.urlopen(url, timeout=15).read()
            # write the downloaded bytes to the zip file, not the target directory
            with open(data_zip, 'wb') as f:
                f.write(data)

        print("Starting to extract data...")
        with ZipFile(data_zip, 'r') as zip_f:
            zip_f.extractall(dataset_dir)

    print("Data has been downloaded and extracted to {}.".format(dataset_dir))
download_and_extract_data()
Data has been downloaded and extracted to datasets.
We are going to train a neural network to classify ants and bees. The dataset consists of about 120 training images and 75 validation images for each class. First, we create the training and validation datasets:
data_dir = 'datasets/hymenoptera_data'
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'train'), train_transforms)
val_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'val'), val_transforms)
The train dataset includes data augmentation techniques such as random cropping to size 224 and random horizontal flips. The train and validation datasets are normalized with mean [0.485, 0.456, 0.406] and standard deviation [0.229, 0.224, 0.225]. These are the per-channel means and standard deviations of the ImageNet images. We use these values because the pretrained model was trained on ImageNet.
We use a pretrained ResNet18 neural network model with its final layer replaced by a fully connected layer:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
Since we are training a binary classifier, the output of the final fully connected layer has size 2. Next, we freeze all layers except the final one by setting requires_grad to False:
for name, param in model_ft.named_parameters():
    if not name.startswith('fc'):
        param.requires_grad_(False)
In this section, we will create a skorch.NeuralNetClassifier to solve our classification problem.
First, we create two callbacks:
lrscheduler = LRScheduler(
    policy='StepLR', step_size=7, gamma=0.1)

checkpoint = Checkpoint(
    f_params='best_model.pt', monitor='valid_acc_best')

callbacks = [lrscheduler, checkpoint]
The LRScheduler callback defines a learning rate scheduler that uses torch.optim.lr_scheduler.StepLR to scale the learning rate by gamma=0.1 every 7 epochs. The Checkpoint callback saves the best model by monitoring the validation accuracy.
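Concretely, with step_size=7 and gamma=0.1, an initial learning rate of 0.001 drops to 0.0001 at epoch 7 and to 0.00001 at epoch 14. The underlying scheduler can be sketched in plain PyTorch (using a dummy parameter just to build an optimizer):

```python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

params = [torch.zeros(1, requires_grad=True)]
optimizer = optim.SGD(params, lr=0.001)
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(15):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()  # called once per epoch

print(lrs[0], lrs[7], lrs[14])  # 0.001, then scaled by 0.1 every 7 epochs
```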
Since we froze some layers of our ResNet18 neural network, we need to configure the optimizer to update only the parameters of the final fully connected layer. Luckily, skorch provides two functions that make this simple:
optimizer = filtered_optimizer(
optim.SGD, filter_requires_grad
)
filtered_optimizer wraps optim.SGD so that filter_requires_grad drops the frozen parameters before they reach the optimizer; only parameters with requires_grad=True are updated during training.
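In pure PyTorch you would achieve the same effect by handing only the still-trainable parameters to the optimizer. A minimal sketch of that equivalent, using a small stand-in model:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad_(False)  # "freeze" the first layer

# Pass only the still-trainable parameters to SGD
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)

# The optimizer sees only the second layer's weight and bias
n_tensors = len(optimizer.param_groups[0]['params'])
print(n_tensors)  # 2
```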
With all the preparations out of the way, we can now define our NeuralNetClassifier:
net = NeuralNetClassifier(
    model_ft,
    criterion=nn.CrossEntropyLoss,
    lr=0.001,
    batch_size=4,
    max_epochs=25,
    optimizer=optimizer,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=4,
    iterator_valid__shuffle=True,
    iterator_valid__num_workers=4,
    train_split=predefined_split(val_ds),
    callbacks=callbacks,
    device='cuda',  # comment this out to train on the CPU
)
That is quite a few parameters! Let's walk through each one:
- model_ft: our ResNet18 neural network
- criterion=nn.CrossEntropyLoss: the loss function
- lr: the initial learning rate
- batch_size: the size of a batch
- max_epochs: the number of epochs to train
- optimizer: our filtered optimizer
- optimizer__momentum: the initial momentum
- iterator_{train,valid}__{shuffle,num_workers}: parameters that are passed to the DataLoader
- train_split: a wrapper around val_ds to use our validation dataset
- callbacks: our callbacks
- device: set to cuda to train on the GPU

Now we are ready to train our neural network:
net.fit(train_ds, y=None);
  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        0.8333       0.9346        0.2150     +  1.4656
      2        0.5203       0.9150        0.2439        1.2813
      3        0.4469       0.8693        0.3494        1.2942
      4        0.4665       0.9542        0.1949     +  1.2363
      5        0.3884       0.9412        0.1962        1.2834
      6        0.3807       0.9412        0.1923        1.3064
      7        0.3292       0.9412        0.1876        1.3428
      8        0.2864       0.9412        0.1961        1.3132
      9        0.4199       0.9346        0.1987        1.2935
     10        0.4462       0.9412        0.2054        1.2886
     11        0.2971       0.9412        0.1952        1.4172
     12        0.3474       0.9412        0.2092        1.3306
     13        0.2891       0.9412        0.2285        1.2747
     14        0.3648       0.8889        0.2870        1.3012
     15        0.3090       0.9346        0.2029        1.2845
     16        0.3348       0.9281        0.2345        1.3291
     17        0.2940       0.9412        0.1908        1.2621
     18        0.3961       0.9412        0.2181        1.3413
     19        0.3652       0.9281        0.2256        1.2868
     20        0.3038       0.9346        0.2178        1.3177
     21        0.3489       0.9477        0.2041        1.2754
     22        0.3161       0.9477        0.1848        1.2837
     23        0.2622       0.9477        0.2074        1.3112
     24        0.3823       0.9477        0.1923        1.2328
     25        0.2952       0.9412        0.2223        1.3896
The best model is stored at best_model.pt, with a validation accuracy of roughly 0.95.
Congratulations! You now know how to finetune a neural network using skorch. Feel free to explore the other tutorials to learn more about using skorch.