Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.7.3 IPython 7.6.1 torch 1.1.0
The network in this notebook is an implementation of the DenseNet-121 [1] architecture on the CIFAR-10 dataset (https://www.cs.toronto.edu/~kriz/cifar.html) to train a 10-class image classifier.
The following figure illustrates the main concept of DenseNet: within each "dense" block, each layer is connected to every previous layer -- the feature maps are concatenated.
Note that this is related to, yet very different from, ResNets. ResNets place skip connections approximately between every other layer (they do not connect all layers with each other), and these skip connections combine feature maps via addition,
$$x_{\ell} = H_{\ell}(x_{\ell-1}) + x_{\ell-1},$$
where $H_{\ell}(\cdot)$ can be a composite function of operations such as batch normalization (BN), rectified linear units (ReLU), pooling, or convolution (Conv).
In DenseNets, by contrast, all previous feature maps $x_0, \ldots, x_{\ell-1}$ of a feature map $x_{\ell}$ are concatenated:
$$x_{\ell} = H_{\ell}([x_0, x_1, \ldots, x_{\ell-1}]).$$
Furthermore, in this particular notebook, we are considering DenseNet-121, which is depicted below:
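To make the addition-vs-concatenation distinction concrete, here is a minimal, hypothetical sketch (not part of the original notebook): a residual connection adds the transformed output back to its input, while a dense connection concatenates the new feature maps with all previous ones along the channel dimension.
import torch
import torch.nn as nn
# Toy composite function H(.): BN -> ReLU -> 3x3 Conv (hypothetical sizes)
def make_h(in_channels, out_channels):
    return nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
    )
x = torch.randn(1, 16, 8, 8)  # (batch, channels, height, width)
# ResNet-style: output is added to the input (channel count stays at 16)
h_res = make_h(16, 16)
x_res = h_res(x) + x
print(x_res.shape)  # torch.Size([1, 16, 8, 8])
# DenseNet-style: new feature maps are concatenated with all previous ones
growth_rate = 12
h_dense = make_h(16, growth_rate)
x_dense = torch.cat([x, h_dense(x)], dim=1)  # channels: 16 + 12
print(x_dense.shape)  # torch.Size([1, 28, 8, 8])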
References
[1] Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2017). Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 4700-4708). http://openaccess.thecvf.com/content_cvpr_2017/html/Huang_Densely_Connected_Convolutional_CVPR_2017_paper.html
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
import matplotlib.pyplot as plt
%matplotlib inline
##########################
### SETTINGS
##########################
# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.001
BATCH_SIZE = 128
NUM_EPOCHS = 20
# Architecture
NUM_CLASSES = 10
# Other
DEVICE = "cuda:0"
GRAYSCALE = False
train_indices = torch.arange(0, 48000)
valid_indices = torch.arange(48000, 50000)
train_and_valid = datasets.CIFAR10(root='data',
train=True,
transform=transforms.ToTensor(),
download=True)
train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(train_and_valid, valid_indices)
test_dataset = datasets.CIFAR10(root='data',
train=False,
transform=transforms.ToTensor(),
download=False)
train_loader = DataLoader(dataset=train_dataset,
batch_size=BATCH_SIZE,
num_workers=4,
shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset,
batch_size=BATCH_SIZE,
num_workers=4,
shuffle=False)
test_loader = DataLoader(dataset=test_dataset,
batch_size=BATCH_SIZE,
num_workers=4,
shuffle=False)
Files already downloaded and verified
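As a quick sanity check (a hypothetical addition, not in the original notebook), the three splits defined above should contain 48,000, 2,000, and 10,000 images, respectively:
# Sanity check: sizes of the train/validation/test splits defined above
print(len(train_dataset))  # 48000
print(len(valid_dataset))  # 2000
print(len(test_dataset))   # 10000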
device = torch.device(DEVICE)
torch.manual_seed(0)
for epoch in range(2):
for batch_idx, (x, y) in enumerate(train_loader):
print('Epoch:', epoch+1, end='')
print(' | Batch index:', batch_idx, end='')
print(' | Batch size:', y.size()[0])
x = x.to(device)
y = y.to(device)
break
Epoch: 1 | Batch index: 0 | Batch size: 128
Epoch: 2 | Batch index: 0 | Batch size: 128
# Check that shuffling works properly
# i.e., label indices should be in random order.
# Also, the label order should be different in the second
# epoch.
for images, labels in train_loader:
pass
print(labels[:10])
for images, labels in train_loader:
pass
print(labels[:10])
tensor([3, 0, 1, 3, 3, 5, 0, 4, 9, 4])
tensor([1, 0, 4, 1, 8, 2, 0, 3, 5, 3])
# Check that the validation set and test set are diverse
# i.e., that they contain examples from all classes
for images, labels in valid_loader:
pass
print(labels[:10])
for images, labels in test_loader:
pass
print(labels[:10])
tensor([5, 0, 3, 6, 8, 7, 9, 5, 6, 6])
tensor([7, 5, 8, 0, 8, 2, 7, 0, 3, 5])
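Looking only at the first ten labels of the last batch is a fairly weak check; a more direct way (a hypothetical addition, not in the original notebook) is to count how often each of the ten classes appears in the 2,000-image validation split:
# Count class frequencies in the validation split (assumes valid_loader from above)
from collections import Counter
label_counts = Counter()
for _, labels in valid_loader:
    label_counts.update(labels.tolist())
# roughly 200 examples per class are expected if the split is balanced
print(sorted(label_counts.items()))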
##########################
### MODEL
##########################
# The following code cell, which implements the DenseNet-121 architecture,
# is a derivative of the code provided at
# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
def _bn_function_factory(norm, relu, conv):
def bn_function(*inputs):
concated_features = torch.cat(inputs, 1)
bottleneck_output = conv(relu(norm(concated_features)))
return bottleneck_output
return bn_function
class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
super(_DenseLayer, self).__init__()
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1,
bias=False)),
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1,
bias=False)),
self.drop_rate = drop_rate
self.memory_efficient = memory_efficient
def forward(self, *prev_features):
bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
bottleneck_output = cp.checkpoint(bn_function, *prev_features)
else:
bottleneck_output = bn_function(*prev_features)
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate,
training=self.training)
return new_features
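As a small usage sketch (hypothetical, not part of the original code): a single _DenseLayer always emits growth_rate new feature maps regardless of how many input feature maps it receives, and with memory_efficient=True the bottleneck computation is wrapped in torch.utils.checkpoint so its intermediate activations are recomputed during the backward pass instead of being stored.
# Hypothetical check: one dense layer maps any number of inputs to `growth_rate` maps
layer = _DenseLayer(num_input_features=64, growth_rate=32, bn_size=4,
                    drop_rate=0., memory_efficient=False)
prev_1 = torch.randn(2, 24, 8, 8)
prev_2 = torch.randn(2, 40, 8, 8)  # 24 + 40 = 64 input feature maps in total
out = layer(prev_1, prev_2)
print(out.shape)  # torch.Size([2, 32, 8, 8]) -- always growth_rate channels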
class _DenseBlock(nn.Module):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
self.add_module('denselayer%d' % (i + 1), layer)
def forward(self, init_features):
features = [init_features]
for name, layer in self.named_children():
new_features = layer(*features)
features.append(new_features)
return torch.cat(features, 1)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
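Putting the pieces together, the feature-map bookkeeping for DenseNet-121 (growth_rate=32, block_config=(6, 12, 24, 16), 64 initial feature maps) can be traced with a few lines of arithmetic. The sketch below (a hypothetical addition) mirrors the loop in the constructor that follows and shows why the final classifier sees 1024 features:
# Trace the number of feature maps through DenseNet-121 (hypothetical sketch)
growth_rate = 32
num_features = 64  # after the initial 7x7 convolution
for i, num_layers in enumerate((6, 12, 24, 16)):
    num_features += num_layers * growth_rate  # each dense block adds k maps per layer
    print(f'after denseblock{i+1}: {num_features}')
    if i != 3:
        num_features = num_features // 2      # each transition layer halves the channels
        print(f'after transition{i+1}: {num_features}')
# 64 -> 256 -> 128 -> 512 -> 256 -> 1024 -> 512 -> 1024 features into the classifier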
class DenseNet121(nn.Module):
r"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 4 ints) - how many layers in each pooling block
num_init_featuremaps (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for number of bottleneck layers
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_featuremaps=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False,
grayscale=False):
super(DenseNet121, self).__init__()
# First convolution
if grayscale:
in_channels=1
else:
in_channels=3
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(in_channels=in_channels, out_channels=num_init_featuremaps,
kernel_size=7, stride=2,
padding=3, bias=False)), # bias is redundant when using batchnorm
('norm0', nn.BatchNorm2d(num_features=num_init_featuremaps)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
# Each denseblock
num_features = num_init_featuremaps
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate,
memory_efficient=memory_efficient
)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features,
num_output_features=num_features // 2)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
# Linear layer
self.classifier = nn.Linear(num_features, num_classes)
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
def forward(self, x):
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)
logits = self.classifier(out)
probas = F.softmax(logits, dim=1)
return logits, probas
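Before training, a quick shape check (a hypothetical addition, not in the original notebook) confirms that the model accepts 32x32 RGB CIFAR-10 images and produces one logit and one probability per class:
# Hypothetical shape check with a random CIFAR-10-sized batch
_model = DenseNet121(num_classes=10, grayscale=False)
_x = torch.randn(4, 3, 32, 32)  # (batch, channels, height, width)
with torch.no_grad():
    _logits, _probas = _model(_x)
print(_logits.shape, _probas.shape)  # torch.Size([4, 10]) torch.Size([4, 10])
print(_probas.sum(dim=1))            # each row of class probabilities sums to 1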
torch.manual_seed(RANDOM_SEED)
model = DenseNet121(num_classes=NUM_CLASSES, grayscale=GRAYSCALE)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
def compute_acc(model, data_loader, device):
correct_pred, num_examples = 0, 0
model.eval()
for i, (features, targets) in enumerate(data_loader):
features = features.to(device)
targets = targets.to(device)
logits, probas = model(features)
_, predicted_labels = torch.max(probas, 1)
num_examples += targets.size(0)
assert predicted_labels.size() == targets.size()
correct_pred += (predicted_labels == targets).sum()
return correct_pred.float()/num_examples * 100
start_time = time.time()
cost_list = []
train_acc_list, valid_acc_list = [], []
for epoch in range(NUM_EPOCHS):
model.train()
for batch_idx, (features, targets) in enumerate(train_loader):
features = features.to(DEVICE)
targets = targets.to(DEVICE)
### FORWARD AND BACK PROP
logits, probas = model(features)
cost = F.cross_entropy(logits, targets)
optimizer.zero_grad()
cost.backward()
### UPDATE MODEL PARAMETERS
optimizer.step()
#################################################
### CODE ONLY FOR LOGGING BEYOND THIS POINT
################################################
cost_list.append(cost.item())
if not batch_idx % 150:
print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
f'Batch {batch_idx:03d}/{len(train_loader):03d} |'
f' Cost: {cost:.4f}')
model.eval()
with torch.set_grad_enabled(False): # save memory during inference
train_acc = compute_acc(model, train_loader, device=DEVICE)
valid_acc = compute_acc(model, valid_loader, device=DEVICE)
print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\n'
f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')
train_acc_list.append(train_acc)
valid_acc_list.append(valid_acc)
elapsed = (time.time() - start_time)/60
print(f'Time elapsed: {elapsed:.2f} min')
elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')
Epoch: 001/020 | Batch 000/375 | Cost: 2.4002 Epoch: 001/020 | Batch 150/375 | Cost: 1.3511 Epoch: 001/020 | Batch 300/375 | Cost: 1.3326 Epoch: 001/020 Train ACC: 61.52 | Validation ACC: 60.25 Time elapsed: 0.99 min
Epoch: 002/020 | Batch 000/375 | Cost: 1.1005 Epoch: 002/020 | Batch 150/375 | Cost: 1.0096 Epoch: 002/020 | Batch 300/375 | Cost: 1.2628 Epoch: 002/020 Train ACC: 66.19 | Validation ACC: 63.55 Time elapsed: 1.90 min
Epoch: 003/020 | Batch 000/375 | Cost: 0.6471 Epoch: 003/020 | Batch 150/375 | Cost: 0.8107 Epoch: 003/020 | Batch 300/375 | Cost: 0.8848 Epoch: 003/020 Train ACC: 73.32 | Validation ACC: 67.40 Time elapsed: 2.79 min
Epoch: 004/020 | Batch 000/375 | Cost: 0.6951 Epoch: 004/020 | Batch 150/375 | Cost: 0.6852 Epoch: 004/020 | Batch 300/375 | Cost: 0.8775 Epoch: 004/020 Train ACC: 76.38 | Validation ACC: 67.90 Time elapsed: 3.65 min
Epoch: 005/020 | Batch 000/375 | Cost: 0.4929 Epoch: 005/020 | Batch 150/375 | Cost: 0.5942 Epoch: 005/020 | Batch 300/375 | Cost: 0.6016 Epoch: 005/020 Train ACC: 76.70 | Validation ACC: 67.95 Time elapsed: 4.53 min
Epoch: 006/020 | Batch 000/375 | Cost: 0.5888 Epoch: 006/020 | Batch 150/375 | Cost: 0.5500 Epoch: 006/020 | Batch 300/375 | Cost: 0.4939 Epoch: 006/020 Train ACC: 84.21 | Validation ACC: 74.85 Time elapsed: 5.51 min
Epoch: 007/020 | Batch 000/375 | Cost: 0.3982 Epoch: 007/020 | Batch 150/375 | Cost: 0.4272 Epoch: 007/020 | Batch 300/375 | Cost: 0.3908 Epoch: 007/020 Train ACC: 87.29 | Validation ACC: 74.65 Time elapsed: 6.43 min
Epoch: 008/020 | Batch 000/375 | Cost: 0.4332 Epoch: 008/020 | Batch 150/375 | Cost: 0.2335 Epoch: 008/020 | Batch 300/375 | Cost: 0.3678 Epoch: 008/020 Train ACC: 84.10 | Validation ACC: 71.00 Time elapsed: 7.35 min
Epoch: 009/020 | Batch 000/375 | Cost: 0.2343 Epoch: 009/020 | Batch 150/375 | Cost: 0.2704 Epoch: 009/020 | Batch 300/375 | Cost: 0.3429 Epoch: 009/020 Train ACC: 88.51 | Validation ACC: 74.35 Time elapsed: 8.26 min
Epoch: 010/020 | Batch 000/375 | Cost: 0.1757 Epoch: 010/020 | Batch 150/375 | Cost: 0.1748 Epoch: 010/020 | Batch 300/375 | Cost: 0.4755 Epoch: 010/020 Train ACC: 92.77 | Validation ACC: 74.20 Time elapsed: 9.27 min
Epoch: 011/020 | Batch 000/375 | Cost: 0.2347 Epoch: 011/020 | Batch 150/375 | Cost: 0.1618 Epoch: 011/020 | Batch 300/375 | Cost: 0.3075 Epoch: 011/020 Train ACC: 90.78 | Validation ACC: 74.10 Time elapsed: 10.24 min
Epoch: 012/020 | Batch 000/375 | Cost: 0.1233 Epoch: 012/020 | Batch 150/375 | Cost: 0.1385 Epoch: 012/020 | Batch 300/375 | Cost: 0.2406 Epoch: 012/020 Train ACC: 93.26 | Validation ACC: 76.00 Time elapsed: 11.15 min
Epoch: 013/020 | Batch 000/375 | Cost: 0.0757 Epoch: 013/020 | Batch 150/375 | Cost: 0.0658 Epoch: 013/020 | Batch 300/375 | Cost: 0.2070 Epoch: 013/020 Train ACC: 91.64 | Validation ACC: 75.80 Time elapsed: 12.13 min
Epoch: 014/020 | Batch 000/375 | Cost: 0.1285 Epoch: 014/020 | Batch 150/375 | Cost: 0.1468 Epoch: 014/020 | Batch 300/375 | Cost: 0.1070 Epoch: 014/020 Train ACC: 95.18 | Validation ACC: 77.80 Time elapsed: 13.10 min
Epoch: 015/020 | Batch 000/375 | Cost: 0.1204 Epoch: 015/020 | Batch 150/375 | Cost: 0.0461 Epoch: 015/020 | Batch 300/375 | Cost: 0.1005 Epoch: 015/020 Train ACC: 93.26 | Validation ACC: 74.80 Time elapsed: 14.10 min
Epoch: 016/020 | Batch 000/375 | Cost: 0.0591 Epoch: 016/020 | Batch 150/375 | Cost: 0.1024 Epoch: 016/020 | Batch 300/375 | Cost: 0.0679 Epoch: 016/020 Train ACC: 95.66 | Validation ACC: 76.80 Time elapsed: 15.02 min
Epoch: 017/020 | Batch 000/375 | Cost: 0.0359 Epoch: 017/020 | Batch 150/375 | Cost: 0.0615 Epoch: 017/020 | Batch 300/375 | Cost: 0.0725 Epoch: 017/020 Train ACC: 92.06 | Validation ACC: 75.85 Time elapsed: 16.01 min
Epoch: 018/020 | Batch 000/375 | Cost: 0.0831 Epoch: 018/020 | Batch 150/375 | Cost: 0.1023 Epoch: 018/020 | Batch 300/375 | Cost: 0.0666 Epoch: 018/020 Train ACC: 94.32 | Validation ACC: 75.10 Time elapsed: 17.01 min
Epoch: 019/020 | Batch 000/375 | Cost: 0.0499 Epoch: 019/020 | Batch 150/375 | Cost: 0.1068 Epoch: 019/020 | Batch 300/375 | Cost: 0.0657 Epoch: 019/020 Train ACC: 96.89 | Validation ACC: 77.75 Time elapsed: 17.99 min
Epoch: 020/020 | Batch 000/375 | Cost: 0.0751 Epoch: 020/020 | Batch 150/375 | Cost: 0.0423 Epoch: 020/020 | Batch 300/375 | Cost: 0.0398 Epoch: 020/020 Train ACC: 94.66 | Validation ACC: 76.70 Time elapsed: 18.90 min
Total Training Time: 18.90 min
plt.plot(cost_list, label='Minibatch cost')
plt.plot(np.convolve(cost_list,
np.ones(200,)/200, mode='valid'),
label='Running average')
plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()
plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')
plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
with torch.set_grad_enabled(False):
test_acc = compute_acc(model=model,
data_loader=test_loader,
device=DEVICE)
valid_acc = compute_acc(model=model,
data_loader=valid_loader,
device=DEVICE)
print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Validation ACC: 76.70%
Test ACC: 74.97%
%watermark -iv
torch 1.1.0 matplotlib 3.1.0 pandas 0.24.2 torchvision 0.3.0 numpy 1.16.4 re 2.2.1 PIL.Image 6.0.0