Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka
CPython 3.7.3
IPython 7.6.1
torch 1.2.0
This notebook implements the classic LeNet-5 convolutional network [1] and applies it to the CIFAR10 object classification dataset. The basic architecture is shown in the figure below:
LeNet-5 is commonly regarded as a pioneering convolutional neural network, with a very simple architecture by modern standards. In total, LeNet-5 consists of only 7 layers. Three of these 7 layers are convolutional layers (C1, C3, C5), which are connected by two average pooling layers (S2 and S4). The penultimate layer is a fully connected layer (F6), which is followed by the final output layer. The additional details are summarized below:
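As a rough check on the layer sequence described above, the spatial size of the feature maps can be traced for a 32x32 input. This is a minimal sketch that assumes 5x5 convolutions without padding and 2x2 pooling with stride 2 (the kernel sizes used in this notebook's model code); the helper names `conv_out` and `lenet5_shapes` are illustrative and not part of the notebook.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_shapes(input_size=32):
    """Return (layer name, spatial size) pairs for C1 through C5."""
    shapes = []
    size = conv_out(input_size, kernel=5)       # C1: 5x5 convolution
    shapes.append(("C1", size))
    size = conv_out(size, kernel=2, stride=2)   # S2: 2x2 pooling
    shapes.append(("S2", size))
    size = conv_out(size, kernel=5)             # C3: 5x5 convolution
    shapes.append(("C3", size))
    size = conv_out(size, kernel=2, stride=2)   # S4: 2x2 pooling
    shapes.append(("S4", size))
    size = conv_out(size, kernel=5)             # C5: 5x5 convolution
    shapes.append(("C5", size))
    return shapes


for name, size in lenet5_shapes():
    print(f"{name}: {size}x{size}")
```

Under these assumptions the maps shrink from 32x32 to 5x5 after S4, and C5 reduces each map to 1x1. That is why C5 can equivalently be written as a fully connected layer acting on the flattened 5x5 maps.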
Please note that the original architecture was applied to MNIST-like grayscale images (1 color channel), whereas CIFAR10 has 3 color channels. I found that using the regular architecture results in very poor performance on CIFAR10 (approx. 50% accuracy). Hence, I am multiplying the number of kernels by a factor of 3 (matching the 3 color channels) in each layer, which improves it a little bit (approx. 60% accuracy).
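To get a feel for what the factor of 3 means for model size, the number of trainable parameters can be counted by hand. The sketch below assumes the layer sizes used in this notebook's model code (two 5x5 convolutions followed by three linear layers, 10 classes); `conv_params`, `linear_params`, and `lenet5_param_count` are hypothetical helper names introduced only for this illustration.

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights plus one bias per output channel of a 2D convolution."""
    return out_ch * (in_ch * kernel * kernel + 1)


def linear_params(in_features, out_features):
    """Weights plus one bias per output unit of a fully connected layer."""
    return out_features * (in_features + 1)


def lenet5_param_count(in_channels, width, num_classes=10):
    """Parameter count when every layer width is multiplied by `width`."""
    return (
        conv_params(in_channels, 6 * width, 5)
        + conv_params(6 * width, 16 * width, 5)
        + linear_params(16 * 5 * 5 * width, 120 * width)
        + linear_params(120 * width, 84 * width)
        + linear_params(84 * width, num_classes)
    )


regular = lenet5_param_count(in_channels=3, width=1)
widened = lenet5_param_count(in_channels=3, width=3)
print("regular widths:", regular)   # 62006
print("3x widths:     ", widened)   # 548878
```

Tripling the widths increases the parameter count by roughly a factor of 9, since most parameters sit in layers whose input and output sizes both scale with the width.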
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
##########################
### SETTINGS
##########################
# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.001
BATCH_SIZE = 128
NUM_EPOCHS = 10
# Architecture
NUM_FEATURES = 32*32
NUM_CLASSES = 10
# Other
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"  # fall back to CPU if no GPU is present
GRAYSCALE = False
##########################
### CIFAR-10 Dataset
##########################
# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.CIFAR10(root='data',
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 download=True)
test_dataset = datasets.CIFAR10(root='data',
                                train=False,
                                transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=8,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=8,
                         shuffle=False)
# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Files already downloaded and verified
Image batch dimensions: torch.Size([128, 3, 32, 32])
Image label dimensions: torch.Size([128])
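The batch shape `[128, 3, 32, 32]` and the 0-1 value range both follow from `transforms.ToTensor()`, which converts an H x W x C image with 0-255 integer values into a C x H x W float tensor. The following pure-Python sketch mimics that conversion on nested lists for illustration only; `to_tensor_like` is a hypothetical name and is not how torchvision implements it.

```python
def to_tensor_like(image_hwc):
    """Convert an H x W x C nested list of 0-255 ints to C x H x W floats in [0, 1]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[r][c][ch] / 255.0 for c in range(width)]
         for r in range(height)]
        for ch in range(channels)
    ]


# A tiny 1 x 2 RGB "image": one dark-to-bright pixel and one white pixel.
pixel_image = [[[0, 128, 255], [255, 255, 255]]]
tensor_like = to_tensor_like(pixel_image)

print(len(tensor_like), len(tensor_like[0]), len(tensor_like[0][0]))  # 3 1 2
print(tensor_like[0][0][0], tensor_like[2][0][0])                     # 0.0 1.0
```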
device = torch.device(DEVICE)
torch.manual_seed(0)
for epoch in range(2):
    for batch_idx, (x, y) in enumerate(train_loader):
        print('Epoch:', epoch+1, end='')
        print(' | Batch index:', batch_idx, end='')
        print(' | Batch size:', y.size()[0])
        x = x.to(device)
        y = y.to(device)
        break
Epoch: 1 | Batch index: 0 | Batch size: 128
Epoch: 2 | Batch index: 0 | Batch size: 128
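Only the first batch of each epoch is printed above, so the size of the final batch is not visible. A small sketch of the arithmetic, assuming the standard CIFAR-10 split sizes (50,000 training and 10,000 test images) and the `DataLoader` default of keeping the last incomplete batch; `batch_layout` is an illustrative helper name.

```python
import math


def batch_layout(num_examples, batch_size):
    """Return (number of batches, size of the final batch) when no batch is dropped."""
    num_batches = math.ceil(num_examples / batch_size)
    last_batch = num_examples - (num_batches - 1) * batch_size
    return num_batches, last_batch


print(batch_layout(50_000, 128))  # training split: (391, 80)
print(batch_layout(10_000, 128))  # test split:     (79, 16)
```

The 391 training batches match the `0391` that appears in the training log of this notebook.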
##########################
### MODEL
##########################
class LeNet5(nn.Module):

    def __init__(self, num_classes, grayscale=False):
        super(LeNet5, self).__init__()
        self.grayscale = grayscale
        self.num_classes = num_classes
        if self.grayscale:
            in_channels = 1
        else:
            in_channels = 3

        # All layer widths are multiplied by in_channels (3 for CIFAR-10).
        # Note: max pooling is used here, whereas the architecture
        # description above refers to average pooling (S2, S4).
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6*in_channels, kernel_size=5),     # 32x32 -> 28x28
            nn.MaxPool2d(kernel_size=2),                              # 28x28 -> 14x14
            nn.Conv2d(6*in_channels, 16*in_channels, kernel_size=5),  # 14x14 -> 10x10
            nn.MaxPool2d(kernel_size=2)                               # 10x10 -> 5x5
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5*in_channels, 120*in_channels),
            nn.Linear(120*in_channels, 84*in_channels),
            nn.Linear(84*in_channels, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        logits = self.classifier(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas
torch.manual_seed(RANDOM_SEED)
model = LeNet5(NUM_CLASSES, GRAYSCALE)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
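The model returns both `logits` and `probas`. `F.cross_entropy` expects the raw logits because it applies the log-softmax internally, while the probabilities are only needed for reading off predictions. The following pure-Python sketch illustrates the relationship on a single example; the function names are illustrative and the numbers are made up.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target class, computed from logits."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]


logits = [2.0, -1.0, 0.5]   # hypothetical logits for a 3-class example
probas = softmax(logits)

# Softmax preserves the ordering, so the predicted class is the same either way.
print(probas.index(max(probas)) == logits.index(max(logits)))  # True
# The loss from logits equals -log of the softmax probability of the target.
print(abs(cross_entropy(logits, 0) + math.log(probas[0])) < 1e-12)  # True
```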
def compute_accuracy(model, data_loader, device):
    correct_pred, num_examples = 0, 0
    for i, (features, targets) in enumerate(data_loader):
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100
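`compute_accuracy` accumulates the number of correct predictions and the number of examples over all batches and divides once at the end. A short sketch of why this matters, using made-up labels and illustrative function names: averaging per-batch accuracies instead would give a short final batch the same weight as a full one.

```python
def accuracy_from_counts(batches):
    """Accumulate correct/total over all batches, then divide once."""
    correct, total = 0, 0
    for predicted, targets in batches:
        correct += sum(p == t for p, t in zip(predicted, targets))
        total += len(targets)
    return correct / total * 100


def mean_of_batch_accuracies(batches):
    """Average the per-batch accuracies (biased if batch sizes differ)."""
    per_batch = [
        sum(p == t for p, t in zip(predicted, targets)) / len(targets) * 100
        for predicted, targets in batches
    ]
    return sum(per_batch) / len(per_batch)


batches = [
    ([0, 1, 2, 3], [0, 1, 2, 0]),  # 3 of 4 correct
    ([1], [0]),                    # 0 of 1 correct (short final batch)
]
print(accuracy_from_counts(batches))      # 60.0
print(mean_of_batch_accuracies(batches))  # 37.5
```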
start_time = time.time()
for epoch in range(NUM_EPOCHS):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)
        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()
        ### UPDATE MODEL PARAMETERS
        optimizer.step()
        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
                  % (epoch+1, NUM_EPOCHS, batch_idx,
                     len(train_loader), cost))
    model.eval()
    with torch.set_grad_enabled(False):  # save memory during inference
        print('Epoch: %03d/%03d | Train: %.3f%%' % (
              epoch+1, NUM_EPOCHS,
              compute_accuracy(model, train_loader, device=DEVICE)))
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/010 | Batch 0000/0391 | Cost: 2.3071
Epoch: 001/010 | Batch 0050/0391 | Cost: 1.8403
Epoch: 001/010 | Batch 0100/0391 | Cost: 1.6112
Epoch: 001/010 | Batch 0150/0391 | Cost: 1.5836
Epoch: 001/010 | Batch 0200/0391 | Cost: 1.4769
Epoch: 001/010 | Batch 0250/0391 | Cost: 1.3934
Epoch: 001/010 | Batch 0300/0391 | Cost: 1.3729
Epoch: 001/010 | Batch 0350/0391 | Cost: 1.2838
Epoch: 001/010 | Train: 52.630%
Time elapsed: 0.06 min
Epoch: 002/010 | Batch 0000/0391 | Cost: 1.4668
Epoch: 002/010 | Batch 0050/0391 | Cost: 1.2844
Epoch: 002/010 | Batch 0100/0391 | Cost: 1.2798
Epoch: 002/010 | Batch 0150/0391 | Cost: 1.3208
Epoch: 002/010 | Batch 0200/0391 | Cost: 1.2563
Epoch: 002/010 | Batch 0250/0391 | Cost: 1.3398
Epoch: 002/010 | Batch 0300/0391 | Cost: 1.2109
Epoch: 002/010 | Batch 0350/0391 | Cost: 1.3151
Epoch: 002/010 | Train: 56.752%
Time elapsed: 0.12 min
Epoch: 003/010 | Batch 0000/0391 | Cost: 1.2106
Epoch: 003/010 | Batch 0050/0391 | Cost: 1.1814
Epoch: 003/010 | Batch 0100/0391 | Cost: 1.3497
Epoch: 003/010 | Batch 0150/0391 | Cost: 1.0703
Epoch: 003/010 | Batch 0200/0391 | Cost: 1.1444
Epoch: 003/010 | Batch 0250/0391 | Cost: 0.8451
Epoch: 003/010 | Batch 0300/0391 | Cost: 1.2356
Epoch: 003/010 | Batch 0350/0391 | Cost: 1.2627
Epoch: 003/010 | Train: 62.042%
Time elapsed: 0.18 min
Epoch: 004/010 | Batch 0000/0391 | Cost: 1.2457
Epoch: 004/010 | Batch 0050/0391 | Cost: 1.1855
Epoch: 004/010 | Batch 0100/0391 | Cost: 1.1870
Epoch: 004/010 | Batch 0150/0391 | Cost: 1.1477
Epoch: 004/010 | Batch 0200/0391 | Cost: 0.9527
Epoch: 004/010 | Batch 0250/0391 | Cost: 1.3219
Epoch: 004/010 | Batch 0300/0391 | Cost: 0.9374
Epoch: 004/010 | Batch 0350/0391 | Cost: 1.0800
Epoch: 004/010 | Train: 62.358%
Time elapsed: 0.23 min
Epoch: 005/010 | Batch 0000/0391 | Cost: 1.1676
Epoch: 005/010 | Batch 0050/0391 | Cost: 1.0142
Epoch: 005/010 | Batch 0100/0391 | Cost: 1.1620
Epoch: 005/010 | Batch 0150/0391 | Cost: 1.0447
Epoch: 005/010 | Batch 0200/0391 | Cost: 1.0203
Epoch: 005/010 | Batch 0250/0391 | Cost: 1.1567
Epoch: 005/010 | Batch 0300/0391 | Cost: 1.2270
Epoch: 005/010 | Batch 0350/0391 | Cost: 1.2121
Epoch: 005/010 | Train: 65.738%
Time elapsed: 0.29 min
Epoch: 006/010 | Batch 0000/0391 | Cost: 0.8958
Epoch: 006/010 | Batch 0050/0391 | Cost: 0.8708
Epoch: 006/010 | Batch 0100/0391 | Cost: 0.8954
Epoch: 006/010 | Batch 0150/0391 | Cost: 1.0416
Epoch: 006/010 | Batch 0200/0391 | Cost: 0.9596
Epoch: 006/010 | Batch 0250/0391 | Cost: 1.1908
Epoch: 006/010 | Batch 0300/0391 | Cost: 1.0528
Epoch: 006/010 | Batch 0350/0391 | Cost: 1.1561
Epoch: 006/010 | Train: 65.396%
Time elapsed: 0.35 min
Epoch: 007/010 | Batch 0000/0391 | Cost: 1.0105
Epoch: 007/010 | Batch 0050/0391 | Cost: 1.0058
Epoch: 007/010 | Batch 0100/0391 | Cost: 1.0195
Epoch: 007/010 | Batch 0150/0391 | Cost: 0.9968
Epoch: 007/010 | Batch 0200/0391 | Cost: 0.9785
Epoch: 007/010 | Batch 0250/0391 | Cost: 0.9639
Epoch: 007/010 | Batch 0300/0391 | Cost: 0.9576
Epoch: 007/010 | Batch 0350/0391 | Cost: 1.1266
Epoch: 007/010 | Train: 67.214%
Time elapsed: 0.40 min
Epoch: 008/010 | Batch 0000/0391 | Cost: 0.8093
Epoch: 008/010 | Batch 0050/0391 | Cost: 0.9909
Epoch: 008/010 | Batch 0100/0391 | Cost: 0.9171
Epoch: 008/010 | Batch 0150/0391 | Cost: 1.0127
Epoch: 008/010 | Batch 0200/0391 | Cost: 0.8954
Epoch: 008/010 | Batch 0250/0391 | Cost: 1.0231
Epoch: 008/010 | Batch 0300/0391 | Cost: 0.8512
Epoch: 008/010 | Batch 0350/0391 | Cost: 1.2245
Epoch: 008/010 | Train: 66.308%
Time elapsed: 0.46 min
Epoch: 009/010 | Batch 0000/0391 | Cost: 0.8735
Epoch: 009/010 | Batch 0050/0391 | Cost: 0.8906
Epoch: 009/010 | Batch 0100/0391 | Cost: 0.8600
Epoch: 009/010 | Batch 0150/0391 | Cost: 1.0259
Epoch: 009/010 | Batch 0200/0391 | Cost: 0.9805
Epoch: 009/010 | Batch 0250/0391 | Cost: 0.9791
Epoch: 009/010 | Batch 0300/0391 | Cost: 0.9160
Epoch: 009/010 | Batch 0350/0391 | Cost: 0.9888
Epoch: 009/010 | Train: 69.292%
Time elapsed: 0.52 min
Epoch: 010/010 | Batch 0000/0391 | Cost: 0.8405
Epoch: 010/010 | Batch 0050/0391 | Cost: 0.9734
Epoch: 010/010 | Batch 0100/0391 | Cost: 1.0093
Epoch: 010/010 | Batch 0150/0391 | Cost: 0.8113
Epoch: 010/010 | Batch 0200/0391 | Cost: 0.9860
Epoch: 010/010 | Batch 0250/0391 | Cost: 0.9763
Epoch: 010/010 | Batch 0300/0391 | Cost: 0.9767
Epoch: 010/010 | Batch 0350/0391 | Cost: 0.9293
Epoch: 010/010 | Train: 69.502%
Time elapsed: 0.58 min
Total Training Time: 0.58 min
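A quick sanity check on the log: the very first cost value (2.3071) is close to what a classifier that assigns equal probability to all 10 classes would produce. The sketch below computes that reference value; `chance_level_loss` is an illustrative helper name.

```python
import math


def chance_level_loss(num_classes):
    """Cross-entropy of a uniform prediction over `num_classes` classes."""
    return -math.log(1.0 / num_classes)


print(f"{chance_level_loss(10):.4f}")  # 2.3026
```

A starting cost near this value suggests the freshly initialized network is predicting roughly uniformly, and the subsequent decrease indicates that the model is learning.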
with torch.set_grad_enabled(False):  # save memory during inference
    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))
Test accuracy: 61.70%
%watermark -iv
numpy 1.16.4
pandas 0.24.2
matplotlib 3.1.0
torchvision 0.4.0a0+6b959ee
PIL.Image 6.0.0
torch 1.2.0