Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# Record the author and the Python / PyTorch versions used to run this
# notebook (IPython magics provided by the third-party `watermark` extension).
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.7.3 IPython 7.6.1 torch 1.1.0
This notebook experiments with increasing the batch size during training, inspired by the paper *Don't Decay the Learning Rate, Increase the Batch Size* (Smith, Kindermans, Ying, and Le, ICLR 2018).
To summarize the main points of the paper:
Stochastic gradient descent adds noise to the optimization problem; during the early training epochs, this noise helps with exploring the loss landscape, and in general, it helps with escaping sharp minima which are known to be bad for generalization.
However, as training progresses, one wants to decay the learning rate gradually (similar to simulated annealing) to fine-tune the model, i.e., to help with convergence.
Due to the relationship between learning rate, batch size, and momentum, one can also just increase the batch size instead of decreasing the learning rate to reduce the noise. This way, more training examples can be used in each update and fewer steps (parameter updates) overall may be required to converge.
The relationship between the noise scale $g$, the learning rate, and the batch size is as follows:

$$g = \epsilon \left(\frac{N}{B} - 1\right),$$

where $\epsilon$ is the learning rate, $B$ is the batch size, and $N$ is the number of training examples.
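With an added momentum coefficient $m$, this becomes:

$$g = \frac{\epsilon}{1 - m}\left(\frac{N}{B} - 1\right) \approx \frac{\epsilon N}{B(1 - m)},$$

where the approximation holds for $B \ll N$.

In this notebook, the CIFAR-10 dataset is used for training a classic AlexNet network [1] for classification: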
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
import matplotlib.pyplot as plt
%matplotlib inline
##########################
### SETTINGS
##########################

# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0001
BATCH_SIZE = 256
NUM_EPOCHS = 40

# Architecture
NUM_CLASSES = 10

# Other
# Fall back to the CPU so the notebook still runs on machines without CUDA
# (a hard-coded "cuda:0" makes every `.to(DEVICE)` call fail there).
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Split the 50,000 CIFAR-10 training images into 48,000 for training and
# the last 2,000 for validation.
train_indices = torch.arange(0, 48000)
valid_indices = torch.arange(48000, 50000)

# Training images get random-crop augmentation; evaluation images
# (validation and test) get a deterministic center crop of the same size.
train_transform = transforms.Compose([transforms.Resize((70, 70)),
                                      transforms.RandomCrop((64, 64)),
                                      transforms.ToTensor()])

test_transform = transforms.Compose([transforms.Resize((70, 70)),
                                     transforms.CenterCrop((64, 64)),
                                     transforms.ToTensor()])

train_and_valid = datasets.CIFAR10(root='data',
                                   train=True,
                                   transform=train_transform,
                                   download=True)

# Second view of the same training split, using the evaluation transform.
# Without it, the validation subset would inherit the RandomCrop augmentation
# and validation accuracy would be measured on randomly cropped images.
train_and_valid_eval = datasets.CIFAR10(root='data',
                                        train=True,
                                        transform=test_transform,
                                        download=False)

train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(train_and_valid_eval, valid_indices)
test_dataset = datasets.CIFAR10(root='data',
                                train=False,
                                transform=test_transform,
                                download=False)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True)

valid_loader = DataLoader(dataset=valid_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=False)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=4,
                         shuffle=False)
Files already downloaded and verified
# Checking the dataset: print the shapes of the first batch of each split.
# Each split is paired with its OWN loader (previously the "Testing Set"
# section iterated train_loader by mistake).
for header, loader in (('Training Set:\n', train_loader),
                       ('\nValidation Set:', valid_loader),
                       ('\nTesting Set:', test_loader)):
    print(header)
    for images, labels in loader:
        print('Image batch dimensions:', images.size())
        print('Image label dimensions:', labels.size())
        break
Training Set: Image batch dimensions: torch.Size([256, 3, 64, 64]) Image label dimensions: torch.Size([256]) Validation Set: Image batch dimensions: torch.Size([256, 3, 64, 64]) Image label dimensions: torch.Size([256]) Testing Set: Image batch dimensions: torch.Size([256, 3, 64, 64]) Image label dimensions: torch.Size([256])
##########################
### MODEL
##########################


class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    ``forward`` returns a ``(logits, probas)`` pair, where ``probas`` is the
    softmax of ``logits`` over the class dimension.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutional feature extractor. The layer order fixes both the
        # parameter-initialisation order and the state_dict keys
        # (e.g. ``features.0.weight``), so it must not change.
        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)

        # Bring the 256-channel feature map to a fixed 6x6 grid, so the
        # classifier input is always 256 * 6 * 6 wide.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        flat_dim = 256 * 6 * 6
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(flat_dim, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        flat = pooled.view(pooled.size(0), 256 * 6 * 6)
        logits = self.classifier(flat)
        return logits, F.softmax(logits, dim=1)
def compute_acc(model, data_loader, device):
    """Return the classification accuracy (in percent) of ``model`` on ``data_loader``.

    The model is switched to evaluation mode (and left there; callers switch
    back with ``model.train()``). Gradient tracking is disabled internally,
    so callers do not need to wrap the call in ``torch.no_grad()``.

    Args:
        model: module whose forward pass returns a ``(logits, probas)`` tuple.
        data_loader: iterable yielding ``(features, targets)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        0-dim float tensor (on ``device``) holding the accuracy in [0, 100].

    Raises:
        ValueError: if ``data_loader`` yields no examples.
    """
    correct_pred, num_examples = 0, 0
    model.eval()
    # Evaluation never needs the autograd graph; building it only wastes memory.
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            _, probas = model(features)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            assert predicted_labels.size() == targets.size()
            correct_pred += (predicted_labels == targets).sum()
    # Without this guard an empty loader fails with an obscure
    # AttributeError (``int.float``) or a division by zero.
    if num_examples == 0:
        raise ValueError('data_loader yielded no examples; accuracy is undefined')
    return correct_pred.float()/num_examples * 100
# Baseline run: train AlexNet with a constant batch size (BATCH_SIZE) for
# NUM_EPOCHS epochs, recording the minibatch cost and per-epoch accuracies.
torch.manual_seed(RANDOM_SEED)
model = AlexNet(NUM_CLASSES)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

start_time = time.time()
cost_list = []
train_acc_list, valid_acc_list = [], []

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch_idx, (features, targets) in enumerate(train_loader):
        features, targets = features.to(DEVICE), targets.to(DEVICE)

        # Forward pass, loss, backpropagation and parameter update.
        logits, _ = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Everything below is bookkeeping / logging only.
        cost_list.append(cost.item())
        if batch_idx % 150 == 0:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |'
                  f' Cost: {cost:.4f}')

    # End of epoch: measure accuracy with gradient tracking switched off.
    model.eval()
    with torch.no_grad():
        train_acc = compute_acc(model, train_loader, device=DEVICE)
        valid_acc = compute_acc(model, valid_loader, device=DEVICE)
        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\n'
              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')
        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')
Epoch: 001/040 | Batch 000/188 | Cost: 2.3029 Epoch: 001/040 | Batch 150/188 | Cost: 1.7122 Epoch: 001/040 Train ACC: 32.64 | Validation ACC: 31.45 Time elapsed: 0.21 min Epoch: 002/040 | Batch 000/188 | Cost: 1.7477 Epoch: 002/040 | Batch 150/188 | Cost: 1.6831 Epoch: 002/040 Train ACC: 44.09 | Validation ACC: 43.40 Time elapsed: 0.42 min Epoch: 003/040 | Batch 000/188 | Cost: 1.5064 Epoch: 003/040 | Batch 150/188 | Cost: 1.4504 Epoch: 003/040 Train ACC: 51.37 | Validation ACC: 51.00 Time elapsed: 0.63 min Epoch: 004/040 | Batch 000/188 | Cost: 1.4089 Epoch: 004/040 | Batch 150/188 | Cost: 1.2423 Epoch: 004/040 Train ACC: 58.56 | Validation ACC: 57.75 Time elapsed: 0.84 min Epoch: 005/040 | Batch 000/188 | Cost: 1.0506 Epoch: 005/040 | Batch 150/188 | Cost: 1.1601 Epoch: 005/040 Train ACC: 59.14 | Validation ACC: 58.55 Time elapsed: 1.06 min Epoch: 006/040 | Batch 000/188 | Cost: 1.0774 Epoch: 006/040 | Batch 150/188 | Cost: 1.1084 Epoch: 006/040 Train ACC: 62.25 | Validation ACC: 60.95 Time elapsed: 1.27 min Epoch: 007/040 | Batch 000/188 | Cost: 1.0387 Epoch: 007/040 | Batch 150/188 | Cost: 1.0570 Epoch: 007/040 Train ACC: 65.41 | Validation ACC: 63.20 Time elapsed: 1.49 min Epoch: 008/040 | Batch 000/188 | Cost: 1.0650 Epoch: 008/040 | Batch 150/188 | Cost: 0.9280 Epoch: 008/040 Train ACC: 64.37 | Validation ACC: 63.95 Time elapsed: 1.70 min Epoch: 009/040 | Batch 000/188 | Cost: 1.0195 Epoch: 009/040 | Batch 150/188 | Cost: 0.7793 Epoch: 009/040 Train ACC: 69.71 | Validation ACC: 67.30 Time elapsed: 1.91 min Epoch: 010/040 | Batch 000/188 | Cost: 0.7986 Epoch: 010/040 | Batch 150/188 | Cost: 0.7988 Epoch: 010/040 Train ACC: 69.41 | Validation ACC: 65.45 Time elapsed: 2.12 min Epoch: 011/040 | Batch 000/188 | Cost: 0.8688 Epoch: 011/040 | Batch 150/188 | Cost: 0.7943 Epoch: 011/040 Train ACC: 70.95 | Validation ACC: 67.35 Time elapsed: 2.34 min Epoch: 012/040 | Batch 000/188 | Cost: 0.7696 Epoch: 012/040 | Batch 150/188 | Cost: 0.8943 Epoch: 012/040 Train ACC: 
75.26 | Validation ACC: 67.95 Time elapsed: 2.55 min Epoch: 013/040 | Batch 000/188 | Cost: 0.6622 Epoch: 013/040 | Batch 150/188 | Cost: 0.7226 Epoch: 013/040 Train ACC: 77.99 | Validation ACC: 72.45 Time elapsed: 2.76 min Epoch: 014/040 | Batch 000/188 | Cost: 0.6180 Epoch: 014/040 | Batch 150/188 | Cost: 0.6502 Epoch: 014/040 Train ACC: 77.82 | Validation ACC: 70.85 Time elapsed: 2.97 min Epoch: 015/040 | Batch 000/188 | Cost: 0.6359 Epoch: 015/040 | Batch 150/188 | Cost: 0.8206 Epoch: 015/040 Train ACC: 79.41 | Validation ACC: 71.35 Time elapsed: 3.18 min Epoch: 016/040 | Batch 000/188 | Cost: 0.6694 Epoch: 016/040 | Batch 150/188 | Cost: 0.5700 Epoch: 016/040 Train ACC: 79.59 | Validation ACC: 70.75 Time elapsed: 3.39 min Epoch: 017/040 | Batch 000/188 | Cost: 0.6395 Epoch: 017/040 | Batch 150/188 | Cost: 0.5564 Epoch: 017/040 Train ACC: 82.24 | Validation ACC: 72.75 Time elapsed: 3.61 min Epoch: 018/040 | Batch 000/188 | Cost: 0.5724 Epoch: 018/040 | Batch 150/188 | Cost: 0.4650 Epoch: 018/040 Train ACC: 83.02 | Validation ACC: 71.55 Time elapsed: 3.82 min Epoch: 019/040 | Batch 000/188 | Cost: 0.4790 Epoch: 019/040 | Batch 150/188 | Cost: 0.4548 Epoch: 019/040 Train ACC: 84.87 | Validation ACC: 73.35 Time elapsed: 4.03 min Epoch: 020/040 | Batch 000/188 | Cost: 0.4254 Epoch: 020/040 | Batch 150/188 | Cost: 0.4183 Epoch: 020/040 Train ACC: 85.73 | Validation ACC: 72.55 Time elapsed: 4.25 min Epoch: 021/040 | Batch 000/188 | Cost: 0.5254 Epoch: 021/040 | Batch 150/188 | Cost: 0.4328 Epoch: 021/040 Train ACC: 85.22 | Validation ACC: 72.25 Time elapsed: 4.46 min Epoch: 022/040 | Batch 000/188 | Cost: 0.4798 Epoch: 022/040 | Batch 150/188 | Cost: 0.4075 Epoch: 022/040 Train ACC: 88.92 | Validation ACC: 73.90 Time elapsed: 4.68 min Epoch: 023/040 | Batch 000/188 | Cost: 0.2946 Epoch: 023/040 | Batch 150/188 | Cost: 0.3808 Epoch: 023/040 Train ACC: 89.33 | Validation ACC: 73.80 Time elapsed: 4.88 min Epoch: 024/040 | Batch 000/188 | Cost: 0.2511 Epoch: 024/040 | 
Batch 150/188 | Cost: 0.3758 Epoch: 024/040 Train ACC: 89.94 | Validation ACC: 74.20 Time elapsed: 5.10 min Epoch: 025/040 | Batch 000/188 | Cost: 0.2348 Epoch: 025/040 | Batch 150/188 | Cost: 0.4043 Epoch: 025/040 Train ACC: 90.37 | Validation ACC: 74.10 Time elapsed: 5.31 min Epoch: 026/040 | Batch 000/188 | Cost: 0.2663 Epoch: 026/040 | Batch 150/188 | Cost: 0.2651 Epoch: 026/040 Train ACC: 91.69 | Validation ACC: 72.55 Time elapsed: 5.52 min Epoch: 027/040 | Batch 000/188 | Cost: 0.2907 Epoch: 027/040 | Batch 150/188 | Cost: 0.2981 Epoch: 027/040 Train ACC: 92.33 | Validation ACC: 73.10 Time elapsed: 5.74 min Epoch: 028/040 | Batch 000/188 | Cost: 0.2318 Epoch: 028/040 | Batch 150/188 | Cost: 0.2904 Epoch: 028/040 Train ACC: 91.91 | Validation ACC: 72.10 Time elapsed: 5.95 min Epoch: 029/040 | Batch 000/188 | Cost: 0.1949 Epoch: 029/040 | Batch 150/188 | Cost: 0.1721 Epoch: 029/040 Train ACC: 93.64 | Validation ACC: 73.15 Time elapsed: 6.16 min Epoch: 030/040 | Batch 000/188 | Cost: 0.1504 Epoch: 030/040 | Batch 150/188 | Cost: 0.2986 Epoch: 030/040 Train ACC: 94.12 | Validation ACC: 73.50 Time elapsed: 6.37 min Epoch: 031/040 | Batch 000/188 | Cost: 0.1666 Epoch: 031/040 | Batch 150/188 | Cost: 0.1380 Epoch: 031/040 Train ACC: 92.82 | Validation ACC: 72.90 Time elapsed: 6.59 min Epoch: 032/040 | Batch 000/188 | Cost: 0.2123 Epoch: 032/040 | Batch 150/188 | Cost: 0.2601 Epoch: 032/040 Train ACC: 94.51 | Validation ACC: 72.80 Time elapsed: 6.80 min Epoch: 033/040 | Batch 000/188 | Cost: 0.1769 Epoch: 033/040 | Batch 150/188 | Cost: 0.1912 Epoch: 033/040 Train ACC: 94.81 | Validation ACC: 72.15 Time elapsed: 7.01 min Epoch: 034/040 | Batch 000/188 | Cost: 0.2098 Epoch: 034/040 | Batch 150/188 | Cost: 0.2454 Epoch: 034/040 Train ACC: 95.87 | Validation ACC: 73.25 Time elapsed: 7.22 min Epoch: 035/040 | Batch 000/188 | Cost: 0.1446 Epoch: 035/040 | Batch 150/188 | Cost: 0.1103 Epoch: 035/040 Train ACC: 94.59 | Validation ACC: 72.05 Time elapsed: 7.43 min Epoch: 
036/040 | Batch 000/188 | Cost: 0.1118 Epoch: 036/040 | Batch 150/188 | Cost: 0.1148 Epoch: 036/040 Train ACC: 96.36 | Validation ACC: 74.30 Time elapsed: 7.65 min Epoch: 037/040 | Batch 000/188 | Cost: 0.1138 Epoch: 037/040 | Batch 150/188 | Cost: 0.2091 Epoch: 037/040 Train ACC: 95.63 | Validation ACC: 73.35 Time elapsed: 7.85 min Epoch: 038/040 | Batch 000/188 | Cost: 0.1720 Epoch: 038/040 | Batch 150/188 | Cost: 0.0837 Epoch: 038/040 Train ACC: 95.77 | Validation ACC: 74.10 Time elapsed: 8.07 min Epoch: 039/040 | Batch 000/188 | Cost: 0.1058 Epoch: 039/040 | Batch 150/188 | Cost: 0.0731 Epoch: 039/040 Train ACC: 97.03 | Validation ACC: 73.55 Time elapsed: 8.28 min Epoch: 040/040 | Batch 000/188 | Cost: 0.1014 Epoch: 040/040 | Batch 150/188 | Cost: 0.1611 Epoch: 040/040 Train ACC: 96.68 | Validation ACC: 72.30 Time elapsed: 8.49 min Total Training Time: 8.49 min
# Minibatch training loss, plus a running average to smooth out minibatch noise.
running_window = 200
plt.plot(cost_list, label='Minibatch cost')
# np.convolve(mode='valid') swaps its inputs when the kernel is longer than
# the data, so only draw the running average when there is enough data.
if len(cost_list) >= running_window:
    plt.plot(np.convolve(cost_list,
                         np.ones(running_window,)/running_window, mode='valid'),
             label='Running average')
plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()

# compute_acc returns 0-dim tensors that live on DEVICE (possibly a GPU);
# convert them to Python floats so matplotlib never has to convert tensors.
# The x-axis follows the recorded epochs, so an interrupted run still plots.
train_acc_values = [float(acc) for acc in train_acc_list]
valid_acc_values = [float(acc) for acc in valid_acc_list]
epochs = np.arange(1, len(train_acc_values)+1)
plt.plot(epochs, train_acc_values, label='Training')
plt.plot(epochs, valid_acc_values, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
# Final accuracy of the constant-batch-size model on the held-out splits.
# Inference only, so gradient tracking is switched off. The test split is
# evaluated before the validation split, matching the original order.
with torch.no_grad():
    test_acc, valid_acc = (compute_acc(model=model, data_loader=split, device=DEVICE)
                           for split in (test_loader, valid_loader))

print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Validation ACC: 71.60% Test ACC: 72.37%
def _make_train_loader(batch_size):
    """Return a shuffled loader over ``train_dataset`` with ``batch_size``.

    The schedule stores sizes as NumPy integers; ``int`` turns them into a
    plain Python int before they reach the DataLoader.
    """
    return DataLoader(dataset=train_dataset,
                      batch_size=int(batch_size),
                      num_workers=4,
                      shuffle=True)


torch.manual_seed(RANDOM_SEED)
model = AlexNet(NUM_CLASSES)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Batch-size schedule: 256, 768, 1280, ..., 4864. Training starts at
# batch_sizes[0] (equal to BATCH_SIZE) and moves one entry up each time the
# schedule below fires.
batch_sizes = np.arange(256, 5121, 512)
batch_size_index = 0
# Rebuild the loader so that re-running this cell does not start from the
# enlarged loader left behind by a previous run.
train_loader = _make_train_loader(batch_sizes[batch_size_index])
# At least 1, so the modulo below cannot divide by zero when
# NUM_EPOCHS < len(batch_sizes).
increase_every = max(1, NUM_EPOCHS//len(batch_sizes))

start_time = time.time()
cost_list = []
train_acc_list, valid_acc_list = [], []
for epoch in range(NUM_EPOCHS):

    ### INCREASE BATCH SIZE
    # In the second half of training, move to the next larger batch size
    # every `increase_every` epochs while larger sizes remain. The index is
    # advanced BEFORE building the loader, so the new loader actually uses the
    # larger size; incrementing afterwards would rebuild the loader with the
    # current size and log a size that is not in use.
    if (epoch > (NUM_EPOCHS//2)
            and not epoch % increase_every
            and batch_size_index + 1 < len(batch_sizes)):
        batch_size_index += 1
        train_loader = _make_train_loader(batch_sizes[batch_size_index])

    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        #################################################
        ### CODE ONLY FOR LOGGING BEYOND THIS POINT
        ################################################
        cost_list.append(cost.item())
        if not batch_idx % 150:
            # Log the batch size of the loader actually being iterated.
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |'
                  f' Cost: {cost:.4f} | Batchsize: {train_loader.batch_size}')

    model.eval()
    with torch.set_grad_enabled(False):  # save memory during inference
        train_acc = compute_acc(model, train_loader, device=DEVICE)
        valid_acc = compute_acc(model, valid_loader, device=DEVICE)
        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\n'
              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')
        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')
Epoch: 001/040 | Batch 000/188 | Cost: 2.3029 | Batchsize: 256 Epoch: 001/040 | Batch 150/188 | Cost: 1.7115 | Batchsize: 256 Epoch: 001/040 Train ACC: 33.36 | Validation ACC: 33.00 Time elapsed: 0.21 min Epoch: 002/040 | Batch 000/188 | Cost: 1.7286 | Batchsize: 256 Epoch: 002/040 | Batch 150/188 | Cost: 1.6143 | Batchsize: 256 Epoch: 002/040 Train ACC: 44.84 | Validation ACC: 45.55 Time elapsed: 0.42 min Epoch: 003/040 | Batch 000/188 | Cost: 1.5018 | Batchsize: 256 Epoch: 003/040 | Batch 150/188 | Cost: 1.4893 | Batchsize: 256 Epoch: 003/040 Train ACC: 50.73 | Validation ACC: 50.55 Time elapsed: 0.64 min Epoch: 004/040 | Batch 000/188 | Cost: 1.4247 | Batchsize: 256 Epoch: 004/040 | Batch 150/188 | Cost: 1.2653 | Batchsize: 256 Epoch: 004/040 Train ACC: 56.70 | Validation ACC: 57.35 Time elapsed: 0.85 min Epoch: 005/040 | Batch 000/188 | Cost: 1.0885 | Batchsize: 256 Epoch: 005/040 | Batch 150/188 | Cost: 1.1472 | Batchsize: 256 Epoch: 005/040 Train ACC: 60.30 | Validation ACC: 58.45 Time elapsed: 1.06 min Epoch: 006/040 | Batch 000/188 | Cost: 1.0394 | Batchsize: 256 Epoch: 006/040 | Batch 150/188 | Cost: 1.0907 | Batchsize: 256 Epoch: 006/040 Train ACC: 62.62 | Validation ACC: 60.85 Time elapsed: 1.27 min Epoch: 007/040 | Batch 000/188 | Cost: 1.0348 | Batchsize: 256 Epoch: 007/040 | Batch 150/188 | Cost: 1.0401 | Batchsize: 256 Epoch: 007/040 Train ACC: 66.19 | Validation ACC: 65.60 Time elapsed: 1.48 min Epoch: 008/040 | Batch 000/188 | Cost: 1.0627 | Batchsize: 256 Epoch: 008/040 | Batch 150/188 | Cost: 0.9297 | Batchsize: 256 Epoch: 008/040 Train ACC: 64.11 | Validation ACC: 63.90 Time elapsed: 1.70 min Epoch: 009/040 | Batch 000/188 | Cost: 1.0361 | Batchsize: 256 Epoch: 009/040 | Batch 150/188 | Cost: 0.8127 | Batchsize: 256 Epoch: 009/040 Train ACC: 69.89 | Validation ACC: 65.45 Time elapsed: 1.91 min Epoch: 010/040 | Batch 000/188 | Cost: 0.7913 | Batchsize: 256 Epoch: 010/040 | Batch 150/188 | Cost: 0.7620 | Batchsize: 256 Epoch: 010/040 Train ACC: 
69.22 | Validation ACC: 66.50 Time elapsed: 2.12 min Epoch: 011/040 | Batch 000/188 | Cost: 0.8304 | Batchsize: 256 Epoch: 011/040 | Batch 150/188 | Cost: 0.8406 | Batchsize: 256 Epoch: 011/040 Train ACC: 71.92 | Validation ACC: 68.50 Time elapsed: 2.33 min Epoch: 012/040 | Batch 000/188 | Cost: 0.6939 | Batchsize: 256 Epoch: 012/040 | Batch 150/188 | Cost: 0.9586 | Batchsize: 256 Epoch: 012/040 Train ACC: 73.86 | Validation ACC: 67.45 Time elapsed: 2.54 min Epoch: 013/040 | Batch 000/188 | Cost: 0.7050 | Batchsize: 256 Epoch: 013/040 | Batch 150/188 | Cost: 0.6281 | Batchsize: 256 Epoch: 013/040 Train ACC: 77.54 | Validation ACC: 70.90 Time elapsed: 2.76 min Epoch: 014/040 | Batch 000/188 | Cost: 0.6453 | Batchsize: 256 Epoch: 014/040 | Batch 150/188 | Cost: 0.6312 | Batchsize: 256 Epoch: 014/040 Train ACC: 76.89 | Validation ACC: 69.80 Time elapsed: 2.97 min Epoch: 015/040 | Batch 000/188 | Cost: 0.6457 | Batchsize: 256 Epoch: 015/040 | Batch 150/188 | Cost: 0.7908 | Batchsize: 256 Epoch: 015/040 Train ACC: 78.62 | Validation ACC: 71.50 Time elapsed: 3.18 min Epoch: 016/040 | Batch 000/188 | Cost: 0.7273 | Batchsize: 256 Epoch: 016/040 | Batch 150/188 | Cost: 0.5583 | Batchsize: 256 Epoch: 016/040 Train ACC: 80.89 | Validation ACC: 70.75 Time elapsed: 3.39 min Epoch: 017/040 | Batch 000/188 | Cost: 0.5611 | Batchsize: 256 Epoch: 017/040 | Batch 150/188 | Cost: 0.5131 | Batchsize: 256 Epoch: 017/040 Train ACC: 83.01 | Validation ACC: 71.25 Time elapsed: 3.60 min Epoch: 018/040 | Batch 000/188 | Cost: 0.5365 | Batchsize: 256 Epoch: 018/040 | Batch 150/188 | Cost: 0.4436 | Batchsize: 256 Epoch: 018/040 Train ACC: 81.85 | Validation ACC: 71.55 Time elapsed: 3.81 min Epoch: 019/040 | Batch 000/188 | Cost: 0.4803 | Batchsize: 256 Epoch: 019/040 | Batch 150/188 | Cost: 0.4372 | Batchsize: 256 Epoch: 019/040 Train ACC: 84.76 | Validation ACC: 73.60 Time elapsed: 4.03 min Epoch: 020/040 | Batch 000/188 | Cost: 0.4905 | Batchsize: 256 Epoch: 020/040 | Batch 150/188 | Cost: 
0.4021 | Batchsize: 256 Epoch: 020/040 Train ACC: 85.10 | Validation ACC: 71.25 Time elapsed: 4.24 min Epoch: 021/040 | Batch 000/188 | Cost: 0.4978 | Batchsize: 256 Epoch: 021/040 | Batch 150/188 | Cost: 0.4828 | Batchsize: 256 Epoch: 021/040 Train ACC: 87.19 | Validation ACC: 72.75 Time elapsed: 4.45 min Epoch: 022/040 | Batch 000/188 | Cost: 0.3978 | Batchsize: 256 Epoch: 022/040 | Batch 150/188 | Cost: 0.4588 | Batchsize: 256 Epoch: 022/040 Train ACC: 87.93 | Validation ACC: 72.20 Time elapsed: 4.66 min Epoch: 023/040 | Batch 000/188 | Cost: 0.3476 | Batchsize: 256 Epoch: 023/040 | Batch 150/188 | Cost: 0.3774 | Batchsize: 256 Epoch: 023/040 Train ACC: 90.10 | Validation ACC: 72.35 Time elapsed: 4.87 min Epoch: 024/040 | Batch 000/188 | Cost: 0.3039 | Batchsize: 256 Epoch: 024/040 | Batch 150/188 | Cost: 0.4589 | Batchsize: 256 Epoch: 024/040 Train ACC: 89.20 | Validation ACC: 72.00 Time elapsed: 5.09 min Epoch: 025/040 | Batch 000/188 | Cost: 0.2648 | Batchsize: 768 Epoch: 025/040 | Batch 150/188 | Cost: 0.3186 | Batchsize: 768 Epoch: 025/040 Train ACC: 91.24 | Validation ACC: 72.55 Time elapsed: 5.30 min Epoch: 026/040 | Batch 000/188 | Cost: 0.2093 | Batchsize: 768 Epoch: 026/040 | Batch 150/188 | Cost: 0.3252 | Batchsize: 768 Epoch: 026/040 Train ACC: 90.77 | Validation ACC: 71.80 Time elapsed: 5.51 min Epoch: 027/040 | Batch 000/188 | Cost: 0.3375 | Batchsize: 768 Epoch: 027/040 | Batch 150/188 | Cost: 0.2307 | Batchsize: 768 Epoch: 027/040 Train ACC: 92.61 | Validation ACC: 73.15 Time elapsed: 5.72 min Epoch: 028/040 | Batch 000/188 | Cost: 0.2307 | Batchsize: 768 Epoch: 028/040 | Batch 150/188 | Cost: 0.2596 | Batchsize: 768 Epoch: 028/040 Train ACC: 90.78 | Validation ACC: 70.25 Time elapsed: 5.94 min Epoch: 029/040 | Batch 000/063 | Cost: 0.2773 | Batchsize: 1280 Epoch: 029/040 Train ACC: 96.33 | Validation ACC: 75.60 Time elapsed: 6.11 min Epoch: 030/040 | Batch 000/063 | Cost: 0.0958 | Batchsize: 1280 Epoch: 030/040 Train ACC: 96.87 | Validation ACC: 
74.95 Time elapsed: 6.28 min Epoch: 031/040 | Batch 000/063 | Cost: 0.1020 | Batchsize: 1280 Epoch: 031/040 Train ACC: 97.30 | Validation ACC: 74.40 Time elapsed: 6.44 min Epoch: 032/040 | Batch 000/063 | Cost: 0.0750 | Batchsize: 1280 Epoch: 032/040 Train ACC: 97.54 | Validation ACC: 75.00 Time elapsed: 6.61 min Epoch: 033/040 | Batch 000/038 | Cost: 0.0687 | Batchsize: 1792 Epoch: 033/040 Train ACC: 98.05 | Validation ACC: 76.20 Time elapsed: 6.79 min Epoch: 034/040 | Batch 000/038 | Cost: 0.0607 | Batchsize: 1792 Epoch: 034/040 Train ACC: 98.19 | Validation ACC: 75.25 Time elapsed: 6.96 min Epoch: 035/040 | Batch 000/038 | Cost: 0.0577 | Batchsize: 1792 Epoch: 035/040 Train ACC: 98.34 | Validation ACC: 75.00 Time elapsed: 7.13 min Epoch: 036/040 | Batch 000/038 | Cost: 0.0546 | Batchsize: 1792 Epoch: 036/040 Train ACC: 98.30 | Validation ACC: 75.35 Time elapsed: 7.30 min Epoch: 037/040 | Batch 000/027 | Cost: 0.0610 | Batchsize: 2304 Epoch: 037/040 Train ACC: 98.56 | Validation ACC: 75.15 Time elapsed: 7.47 min Epoch: 038/040 | Batch 000/027 | Cost: 0.0544 | Batchsize: 2304 Epoch: 038/040 Train ACC: 98.78 | Validation ACC: 75.30 Time elapsed: 7.64 min Epoch: 039/040 | Batch 000/027 | Cost: 0.0431 | Batchsize: 2304 Epoch: 039/040 Train ACC: 98.84 | Validation ACC: 76.75 Time elapsed: 7.81 min Epoch: 040/040 | Batch 000/027 | Cost: 0.0455 | Batchsize: 2304 Epoch: 040/040 Train ACC: 98.84 | Validation ACC: 74.80 Time elapsed: 7.98 min Total Training Time: 7.98 min
# Minibatch training loss of the increasing-batch-size run, plus a running
# average. Later iterations use larger minibatches.
running_window = 200
plt.plot(cost_list, label='Minibatch cost')
# np.convolve(mode='valid') swaps its inputs when the kernel is longer than
# the data, so only draw the running average when there is enough data.
if len(cost_list) >= running_window:
    plt.plot(np.convolve(cost_list,
                         np.ones(running_window,)/running_window, mode='valid'),
             label='Running average')
plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()

# compute_acc returns 0-dim tensors that live on DEVICE (possibly a GPU);
# convert them to Python floats before handing them to matplotlib.
# The x-axis follows the recorded epochs, so an interrupted run still plots.
train_acc_values = [float(acc) for acc in train_acc_list]
valid_acc_values = [float(acc) for acc in valid_acc_list]
epochs = np.arange(1, len(train_acc_values)+1)
plt.plot(epochs, train_acc_values, label='Training')
plt.plot(epochs, valid_acc_values, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
# Final accuracy of the increasing-batch-size model on the held-out splits,
# computed without gradient tracking. The test split is evaluated first,
# as before; the validation result is printed first.
with torch.no_grad():
    accuracies = [compute_acc(model=model, data_loader=split, device=DEVICE)
                  for split in (test_loader, valid_loader)]
test_acc, valid_acc = accuracies

print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Validation ACC: 75.25% Test ACC: 73.93%