Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.7.3 IPython 7.6.1 torch 1.2.0
The network in this notebook is an implementation of the DenseNet-121 [1] architecture on the MNIST digits dataset (http://yann.lecun.com/exdb/mnist/) to train a handwritten digit classifier.
The following figure illustrates the main concept of DenseNet: within each "dense" block, each layer is connected with each previous layer -- the feature maps are concatenated.
Note that this is related to, yet quite different from, ResNets. ResNets have skip connections approximately between every other layer (but don't connect all layers with each other). Also, ResNet skip connections work via addition:
$$\mathbf{x}_{\ell} = H_{\ell}(\mathbf{x}_{\ell-1}) + \mathbf{x}_{\ell-1},$$
where $H_{\ell}(\cdot)$ can be a composite function of operations such as Batch Normalization (BN), rectified linear units (ReLU), Pooling, or Convolution (Conv).
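In DenseNets, all the previous feature maps $\mathbf{x}_0, \dots, \mathbf{x}_{\ell-1}$ of a feature map $\mathbf{x}_{\ell}$ are concatenated:
$$\mathbf{x}_{\ell} = H_{\ell}([\mathbf{x}_0, \mathbf{x}_1, \dots, \mathbf{x}_{\ell-1}]).$$
Furthermore, in this particular notebook, we are considering the DenseNet-121, which is depicted below: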
References
[1] Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 4700-4708), http://openaccess.thecvf.com/content_cvpr_2017/html/Huang_Densely_Connected_Convolutional_CVPR_2017_paper.html
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
# Only touch the cuDNN setting when a CUDA device is actually visible.
cuda_is_present = torch.cuda.is_available()
if cuda_is_present:
    # Ask cuDNN for deterministic behavior.
    torch.backends.cudnn.deterministic = True
import matplotlib.pyplot as plt
%matplotlib inline
##########################
### SETTINGS
##########################
# Hyperparameters
RANDOM_SEED = 1          # passed to torch.manual_seed right before the model is built
LEARNING_RATE = 0.0001   # step size handed to the Adam optimizer
BATCH_SIZE = 128         # minibatch size shared by all three data loaders
NUM_EPOCHS = 10          # number of passes over the training subset

# Architecture
NUM_CLASSES = 10         # size of the classifier output

# Other
# Use the first GPU when one is visible; otherwise fall back to the CPU so
# that the notebook still runs on machines without CUDA instead of
# crashing at the first `.to(DEVICE)` call.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
GRAYSCALE = True         # build the network with a single input channel
# Index split of the MNIST training set: entries 0..58999 are used for
# training and entries 59000..59999 are held out for validation.
train_indices = torch.arange(0, 59000)
valid_indices = torch.arange(59000, 60000)

# Every image is resized to 32x32 and then converted to a tensor.
resize_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Both MNIST splits share the storage location, preprocessing and
# download flag; only the `train` switch differs.
mnist_kwargs = dict(root='data', transform=resize_transform, download=True)
train_and_valid = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(train_and_valid, valid_indices)

# All loaders use the same batch size and worker count; only the training
# loader reshuffles its data.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)
device = torch.device(DEVICE)
torch.manual_seed(0)

# Smoke test of the training loader: for two "epochs", fetch only the first
# minibatch, report its size, and check that it can be moved to the device.
for epoch in range(2):
    for batch_idx, (x, y) in enumerate(train_loader):
        print('Epoch:', epoch+1, end='')
        print(' | Batch index:', batch_idx, end='')
        print(' | Batch size:', y.size()[0])

        x, y = x.to(device), y.to(device)
        break  # one batch per epoch is enough for this check
Epoch: 1 | Batch index: 0 | Batch size: 128 Epoch: 2 | Batch index: 0 | Batch size: 128
# Sanity check for shuffling: run through the training loader twice and
# print the first ten labels of the final batch of each pass. The labels
# should appear in random order and should differ between the two passes.
for _ in range(2):
    for images, labels in train_loader:
        pass  # exhaust the loader; only the last batch is inspected
    print(labels[:10])
tensor([1, 9, 7, 6, 7, 9, 2, 0, 4, 3]) tensor([5, 4, 4, 7, 4, 6, 8, 7, 1, 3])
# Check that the validation and test sets are diverse, i.e. that they
# contain a mix of classes: print the first ten labels of the last batch
# of each loader.
for loader in (valid_loader, test_loader):
    for images, labels in loader:
        pass  # exhaust the loader; only the last batch is inspected
    print(labels[:10])
tensor([2, 2, 0, 9, 4, 0, 1, 2, 3, 4]) tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0])
##########################
### MODEL
##########################
# The following code cell that implements the DenseNet-121 architecture
# is a derivative of the code provided at
# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
def _bn_function_factory(norm, relu, conv):
    """Build the bottleneck function used inside a dense layer.

    Args:
        norm: callable applied first to the concatenated features.
        relu: callable applied to the output of ``norm``.
        conv: callable applied to the output of ``relu``.

    Returns:
        A function that accepts any number of tensors, concatenates them
        along dimension 1 and returns ``conv(relu(norm(...)))`` of the
        concatenation.
    """

    def bn_function(*inputs):
        stacked = torch.cat(inputs, 1)
        normalized = norm(stacked)
        activated = relu(normalized)
        return conv(activated)

    return bn_function
class _DenseLayer(nn.Sequential):
    """One layer of a dense block.

    The layer concatenates all feature maps it is given, applies a
    BatchNorm -> ReLU -> 1x1 convolution bottleneck that produces
    ``bn_size * growth_rate`` channels, and then a
    BatchNorm -> ReLU -> 3x3 convolution that produces ``growth_rate``
    channels. Dropout is applied to the result when ``drop_rate > 0``.

    Args:
        num_input_features: channel count of the concatenated input.
        growth_rate: number of channels this layer outputs.
        bn_size: multiplier for the bottleneck width.
        drop_rate: dropout probability; ``0`` disables dropout.
        memory_efficient: if True, the bottleneck is evaluated through
            ``torch.utils.checkpoint`` whenever an input requires grad.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
        super().__init__()
        bottleneck_width = bn_size * growth_rate

        # Sub-modules are registered in execution order; the names are
        # part of the state-dict layout and must not change.
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bottleneck_width,
                                           kernel_size=1, stride=1,
                                           bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bottleneck_width))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bottleneck_width, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))

        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        """Return the new feature maps computed from all ``prev_features``."""
        bottleneck = _bn_function_factory(self.norm1, self.relu1, self.conv1)

        # Checkpointing is only used when requested and when at least one
        # input takes part in autograd.
        use_checkpoint = self.memory_efficient and any(
            feature.requires_grad for feature in prev_features)
        if use_checkpoint:
            squeezed = cp.checkpoint(bottleneck, *prev_features)
        else:
            squeezed = bottleneck(*prev_features)

        out = self.norm2(squeezed)
        out = self.relu2(out)
        out = self.conv2(out)

        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate,
                            training=self.training)
        return out
class _DenseBlock(nn.Module):
    """A stack of densely connected ``_DenseLayer`` modules.

    Layer ``k`` (counting from zero) is built for
    ``num_input_features + k * growth_rate`` input channels, because it
    receives the block input plus the outputs of all earlier layers.

    Args:
        num_layers: number of dense layers in the block.
        num_input_features: channel count of the block input.
        bn_size: bottleneck width multiplier passed to every layer.
        growth_rate: channels added by every layer.
        drop_rate: dropout probability passed to every layer.
        memory_efficient: checkpointing flag passed to every layer.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super().__init__()
        for position in range(1, num_layers + 1):
            in_features = num_input_features + (position - 1) * growth_rate
            self.add_module(
                'denselayer%d' % position,
                _DenseLayer(
                    in_features,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient,
                ),
            )

    def forward(self, init_features):
        """Run all layers and return input and layer outputs concatenated on dim 1."""
        collected = [init_features]
        for layer in self.children():
            # Every layer sees everything produced so far.
            collected.append(layer(*collected))
        return torch.cat(collected, 1)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet121(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    The network is a convolutional stem, followed by one dense block per
    entry of ``block_config`` (with a channel-halving transition between
    consecutive blocks), a final batch norm, global average pooling and a
    linear classifier.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_featuremaps (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for the bottleneck width
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
        grayscale (bool) - If True, the first convolution expects 1 input
          channel; otherwise it expects 3. Default: *False*.
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_featuremaps=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False,
                 grayscale=False):
        super().__init__()

        # Stem: strided 7x7 convolution, batch norm, ReLU and max pooling.
        in_channels = 1 if grayscale else 3
        stem = OrderedDict()
        stem['conv0'] = nn.Conv2d(in_channels=in_channels, out_channels=num_init_featuremaps,
                                  kernel_size=7, stride=2,
                                  padding=3, bias=False)  # bias is redundant when using batchnorm
        stem['norm0'] = nn.BatchNorm2d(num_features=num_init_featuremaps)
        stem['relu0'] = nn.ReLU(inplace=True)
        stem['pool0'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.features = nn.Sequential(stem)

        # Dense blocks; every block except the last is followed by a
        # transition that halves the channel count.
        num_features = num_init_featuremaps
        num_blocks = len(block_config)
        for position, num_layers in enumerate(block_config, start=1):
            dense_block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module('denseblock%d' % position, dense_block)
            num_features += num_layers * growth_rate

            if position == num_blocks:
                continue
            halved = num_features // 2
            transition = _Transition(num_input_features=num_features,
                                     num_output_features=halved)
            self.features.add_module('transition%d' % position, transition)
            num_features = halved

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return ``(logits, probas)`` for the input batch ``x``.

        ``probas`` is the softmax of ``logits`` over dimension 1.
        """
        feature_maps = self.features(x)
        activated = F.relu(feature_maps, inplace=True)
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        flattened = torch.flatten(pooled, 1)
        logits = self.classifier(flattened)
        return logits, F.softmax(logits, dim=1)
# Seed the RNG immediately before construction so that the randomly
# initialized weights are the same on every run.
torch.manual_seed(RANDOM_SEED)

model = DenseNet121(num_classes=NUM_CLASSES, grayscale=GRAYSCALE)
model = model.to(DEVICE)  # nn.Module.to moves the module in place and returns it

# The optimizer is created after the move so it tracks the parameters
# that live on DEVICE.
trainable_params = model.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)
def compute_acc(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader`` in percent.

    The model is switched to evaluation mode (and left in that mode) and
    all forward passes run with autograd disabled, so the function is safe
    to call outside of a ``torch.no_grad()`` block.

    Args:
        model: module whose forward pass returns ``(logits, probas)``.
        data_loader: iterable yielding ``(features, targets)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        A 0-dim float tensor holding the accuracy as a percentage.

    Raises:
        ValueError: if ``data_loader`` yields no examples.
    """
    model.eval()
    correct_pred, num_examples = 0, 0

    with torch.no_grad():  # no graph is needed for evaluation
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)

            _, probas = model(features)
            _, predicted_labels = torch.max(probas, 1)

            num_examples += targets.size(0)
            # Guard against silent broadcasting in the comparison below.
            assert predicted_labels.size() == targets.size()
            correct_pred += (predicted_labels == targets).sum()

    if num_examples == 0:
        # Previously this surfaced as an AttributeError on the int counter.
        raise ValueError('data_loader did not yield any examples')

    return correct_pred.float()/num_examples * 100
start_time = time.time()

# Training history: one loss value per minibatch and one accuracy value
# per epoch. Everything is stored as plain Python floats so the lists can
# be plotted directly, regardless of the device the model runs on.
cost_list = []
train_acc_list, valid_acc_list = [], []

for epoch in range(NUM_EPOCHS):

    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        #################################################
        ### CODE ONLY FOR LOGGING BEYOND THIS POINT
        ################################################
        cost_list.append(cost.item())
        if not batch_idx % 150:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |'
                  f' Cost: {cost:.4f}')

    # Per-epoch evaluation on the training and validation subsets.
    model.eval()
    with torch.set_grad_enabled(False):  # save memory during inference

        # compute_acc returns a 0-dim tensor on DEVICE; float() copies the
        # value to the host so no device tensor ends up in the history.
        train_acc = float(compute_acc(model, train_loader, device=DEVICE))
        valid_acc = float(compute_acc(model, valid_loader, device=DEVICE))

        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\n'
              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')

        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')
Epoch: 001/010 | Batch 000/461 | Cost: 2.3666 Epoch: 001/010 | Batch 150/461 | Cost: 0.1477 Epoch: 001/010 | Batch 300/461 | Cost: 0.0662 Epoch: 001/010 | Batch 450/461 | Cost: 0.1357 Epoch: 001/010 Train ACC: 98.87 | Validation ACC: 98.10 Time elapsed: 0.93 min Epoch: 002/010 | Batch 000/461 | Cost: 0.0476 Epoch: 002/010 | Batch 150/461 | Cost: 0.0443 Epoch: 002/010 | Batch 300/461 | Cost: 0.0520 Epoch: 002/010 | Batch 450/461 | Cost: 0.0377 Epoch: 002/010 Train ACC: 99.58 | Validation ACC: 98.50 Time elapsed: 2.01 min Epoch: 003/010 | Batch 000/461 | Cost: 0.0080 Epoch: 003/010 | Batch 150/461 | Cost: 0.0085 Epoch: 003/010 | Batch 300/461 | Cost: 0.0125 Epoch: 003/010 | Batch 450/461 | Cost: 0.0178 Epoch: 003/010 Train ACC: 99.57 | Validation ACC: 98.90 Time elapsed: 3.10 min Epoch: 004/010 | Batch 000/461 | Cost: 0.0057 Epoch: 004/010 | Batch 150/461 | Cost: 0.0368 Epoch: 004/010 | Batch 300/461 | Cost: 0.0026 Epoch: 004/010 | Batch 450/461 | Cost: 0.1089 Epoch: 004/010 Train ACC: 99.79 | Validation ACC: 98.90 Time elapsed: 4.17 min Epoch: 005/010 | Batch 000/461 | Cost: 0.0026 Epoch: 005/010 | Batch 150/461 | Cost: 0.0098 Epoch: 005/010 | Batch 300/461 | Cost: 0.0161 Epoch: 005/010 | Batch 450/461 | Cost: 0.0116 Epoch: 005/010 Train ACC: 99.77 | Validation ACC: 99.10 Time elapsed: 5.23 min Epoch: 006/010 | Batch 000/461 | Cost: 0.0033 Epoch: 006/010 | Batch 150/461 | Cost: 0.0009 Epoch: 006/010 | Batch 300/461 | Cost: 0.0114 Epoch: 006/010 | Batch 450/461 | Cost: 0.0582 Epoch: 006/010 Train ACC: 99.59 | Validation ACC: 98.50 Time elapsed: 6.33 min Epoch: 007/010 | Batch 000/461 | Cost: 0.0034 Epoch: 007/010 | Batch 150/461 | Cost: 0.0059 Epoch: 007/010 | Batch 300/461 | Cost: 0.0064 Epoch: 007/010 | Batch 450/461 | Cost: 0.0239 Epoch: 007/010 Train ACC: 99.69 | Validation ACC: 97.90 Time elapsed: 7.42 min Epoch: 008/010 | Batch 000/461 | Cost: 0.0042 Epoch: 008/010 | Batch 150/461 | Cost: 0.0282 Epoch: 008/010 | Batch 300/461 | Cost: 0.0022 Epoch: 008/010 | 
Batch 450/461 | Cost: 0.0025 Epoch: 008/010 Train ACC: 99.85 | Validation ACC: 99.10 Time elapsed: 8.51 min Epoch: 009/010 | Batch 000/461 | Cost: 0.0090 Epoch: 009/010 | Batch 150/461 | Cost: 0.0034 Epoch: 009/010 | Batch 300/461 | Cost: 0.0026 Epoch: 009/010 | Batch 450/461 | Cost: 0.0042 Epoch: 009/010 Train ACC: 99.81 | Validation ACC: 98.30 Time elapsed: 9.53 min Epoch: 010/010 | Batch 000/461 | Cost: 0.0413 Epoch: 010/010 | Batch 150/461 | Cost: 0.0042 Epoch: 010/010 | Batch 300/461 | Cost: 0.0021 Epoch: 010/010 | Batch 450/461 | Cost: 0.0040 Epoch: 010/010 Train ACC: 99.88 | Validation ACC: 98.80 Time elapsed: 10.63 min Total Training Time: 10.63 min
# Training loss per minibatch together with a moving average over a
# window of 200 iterations.
window = 200
averaging_kernel = np.ones(window,)/window
running_average = np.convolve(cost_list, averaging_kernel, mode='valid')

plt.plot(cost_list, label='Minibatch cost')
plt.plot(running_average, label='Running average')

plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()
# Training and validation accuracy per epoch.
# The history entries may be 0-dim tensors that live on the training
# device; convert them to plain floats so matplotlib never has to turn a
# device tensor into a NumPy array.
train_curve = [float(acc) for acc in train_acc_list]
valid_curve = [float(acc) for acc in valid_acc_list]

# The x-axis follows the recorded history instead of NUM_EPOCHS, so the
# plot still works when training was stopped early.
plt.plot(np.arange(1, len(train_curve) + 1), train_curve, label='Training')
plt.plot(np.arange(1, len(valid_curve) + 1), valid_curve, label='Validation')

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
# Final evaluation on the held-out test set and on the validation subset,
# with autograd switched off.
with torch.no_grad():
    test_acc = compute_acc(model=model, data_loader=test_loader, device=DEVICE)
    valid_acc = compute_acc(model=model, data_loader=valid_loader, device=DEVICE)

print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Validation ACC: 98.80% Test ACC: 98.95%
%watermark -iv
pandas 0.24.2 PIL.Image 6.0.0 torch 1.2.0 numpy 1.16.4 torchvision 0.4.0a0+6b959ee matplotlib 3.1.0 re 2.2.1