Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Author: Sebastian Raschka Python implementation: CPython Python version : 3.8.12 IPython version : 8.0.1 torch: 1.10.1
References
import os
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
I recommend using a function like the following one prior to using dataset loaders and initializing a model if you want to ensure the data is shuffled in the same manner if you rerun this notebook and the model gets the same initial random weights:
def set_all_seeds(seed):
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
Similar to the set_all_seeds
function above, I recommend setting the behavior of PyTorch and cuDNN to deterministic (this is particulary relevant when using GPUs). We can also define a function for that:
def set_deterministic():
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_deterministic(True)
##########################
### SETTINGS
##########################
# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0001
BATCH_SIZE = 256
NUM_EPOCHS = 40
# Architecture
NUM_CLASSES = 10
# Other
DEVICE = "cuda:0"
set_all_seeds(RANDOM_SEED)
# Deterministic behavior not yet supported by AdaptiveAvgPool2d
#set_deterministic()
import sys
sys.path.insert(0, "..") # to include ../helper_evaluate.py etc.
from helper_evaluate import compute_accuracy
from helper_data import get_dataloaders_cifar10
from helper_train import train_classifier_simple_v1
### Set random seed ###
set_all_seeds(RANDOM_SEED)
##########################
### Dataset
##########################
train_transforms = transforms.Compose([transforms.Resize((70, 70)),
transforms.RandomCrop((64, 64)),
transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.Resize((70, 70)),
transforms.CenterCrop((64, 64)),
transforms.ToTensor()])
train_loader, valid_loader, test_loader = get_dataloaders_cifar10(
batch_size=BATCH_SIZE,
num_workers=2,
train_transforms=train_transforms,
test_transforms=test_transforms,
validation_fraction=0.1)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
0%| | 0/170498071 [00:00<?, ?it/s]
Extracting data/cifar-10-python.tar.gz to data
# Checking the dataset
print('Training Set:\n')
for images, labels in train_loader:
print('Image batch dimensions:', images.size())
print('Image label dimensions:', labels.size())
print(labels[:10])
break
# Checking the dataset
print('\nValidation Set:')
for images, labels in valid_loader:
print('Image batch dimensions:', images.size())
print('Image label dimensions:', labels.size())
print(labels[:10])
break
# Checking the dataset
print('\nTesting Set:')
for images, labels in train_loader:
print('Image batch dimensions:', images.size())
print('Image label dimensions:', labels.size())
print(labels[:10])
break
Training Set: Image batch dimensions: torch.Size([256, 3, 64, 64]) Image label dimensions: torch.Size([256]) tensor([0, 2, 3, 5, 4, 8, 9, 6, 9, 7]) Validation Set: Image batch dimensions: torch.Size([256, 3, 64, 64]) Image label dimensions: torch.Size([256]) tensor([6, 9, 3, 5, 7, 3, 4, 1, 8, 0]) Testing Set: Image batch dimensions: torch.Size([256, 3, 64, 64]) Image label dimensions: torch.Size([256]) tensor([2, 6, 3, 1, 1, 1, 1, 2, 4, 8])
##########################
### MODEL
##########################
class AlexNet(nn.Module):
def __init__(self, num_classes):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
self.classifier = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes)
)
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), 256 * 6 * 6)
logits = self.classifier(x)
probas = F.softmax(logits, dim=1)
return logits
torch.manual_seed(RANDOM_SEED)
model = AlexNet(NUM_CLASSES)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
log_dict = train_classifier_simple_v1(num_epochs=NUM_EPOCHS, model=model,
optimizer=optimizer, device=DEVICE,
train_loader=train_loader, valid_loader=valid_loader,
logging_interval=50)
Epoch: 001/040 | Batch 0000/0175 | Loss: 2.3033 Epoch: 001/040 | Batch 0050/0175 | Loss: 2.0240 Epoch: 001/040 | Batch 0100/0175 | Loss: 1.9445 Epoch: 001/040 | Batch 0150/0175 | Loss: 1.8135 ***Epoch: 001/040 | Train. Acc.: 33.674% | Loss: 1.703 ***Epoch: 001/040 | Valid. Acc.: 34.880% | Loss: 1.670 Time elapsed: 1.05 min Epoch: 002/040 | Batch 0000/0175 | Loss: 1.7606 Epoch: 002/040 | Batch 0050/0175 | Loss: 1.5473 Epoch: 002/040 | Batch 0100/0175 | Loss: 1.5496 Epoch: 002/040 | Batch 0150/0175 | Loss: 1.5093 ***Epoch: 002/040 | Train. Acc.: 42.819% | Loss: 1.505 ***Epoch: 002/040 | Valid. Acc.: 43.840% | Loss: 1.491 Time elapsed: 2.09 min Epoch: 003/040 | Batch 0000/0175 | Loss: 1.5411 Epoch: 003/040 | Batch 0050/0175 | Loss: 1.5485 Epoch: 003/040 | Batch 0100/0175 | Loss: 1.3723 Epoch: 003/040 | Batch 0150/0175 | Loss: 1.3084 ***Epoch: 003/040 | Train. Acc.: 49.712% | Loss: 1.336 ***Epoch: 003/040 | Valid. Acc.: 50.300% | Loss: 1.327 Time elapsed: 3.12 min Epoch: 004/040 | Batch 0000/0175 | Loss: 1.4301 Epoch: 004/040 | Batch 0050/0175 | Loss: 1.4117 Epoch: 004/040 | Batch 0100/0175 | Loss: 1.2894 Epoch: 004/040 | Batch 0150/0175 | Loss: 1.1508 ***Epoch: 004/040 | Train. Acc.: 54.138% | Loss: 1.231 ***Epoch: 004/040 | Valid. Acc.: 54.300% | Loss: 1.226 Time elapsed: 4.16 min Epoch: 005/040 | Batch 0000/0175 | Loss: 1.1781 Epoch: 005/040 | Batch 0050/0175 | Loss: 1.2942 Epoch: 005/040 | Batch 0100/0175 | Loss: 1.3343 Epoch: 005/040 | Batch 0150/0175 | Loss: 1.1216 ***Epoch: 005/040 | Train. Acc.: 58.536% | Loss: 1.139 ***Epoch: 005/040 | Valid. Acc.: 58.220% | Loss: 1.152 Time elapsed: 5.23 min Epoch: 006/040 | Batch 0000/0175 | Loss: 1.1030 Epoch: 006/040 | Batch 0050/0175 | Loss: 1.1732 Epoch: 006/040 | Batch 0100/0175 | Loss: 1.1508 Epoch: 006/040 | Batch 0150/0175 | Loss: 1.0059 ***Epoch: 006/040 | Train. Acc.: 58.882% | Loss: 1.132 ***Epoch: 006/040 | Valid. Acc.: 58.600% | Loss: 1.158 Time elapsed: 6.28 min Epoch: 007/040 | Batch 0000/0175 | Loss: 1.0091 Epoch: 007/040 | Batch 0050/0175 | Loss: 1.2888 Epoch: 007/040 | Batch 0100/0175 | Loss: 1.0148 Epoch: 007/040 | Batch 0150/0175 | Loss: 1.0491 ***Epoch: 007/040 | Train. Acc.: 65.203% | Loss: 0.966 ***Epoch: 007/040 | Valid. Acc.: 63.880% | Loss: 1.007 Time elapsed: 7.31 min Epoch: 008/040 | Batch 0000/0175 | Loss: 0.8920 Epoch: 008/040 | Batch 0050/0175 | Loss: 0.9769 Epoch: 008/040 | Batch 0100/0175 | Loss: 1.0159 Epoch: 008/040 | Batch 0150/0175 | Loss: 1.0733 ***Epoch: 008/040 | Train. Acc.: 67.181% | Loss: 0.920 ***Epoch: 008/040 | Valid. Acc.: 65.020% | Loss: 0.974 Time elapsed: 8.35 min Epoch: 009/040 | Batch 0000/0175 | Loss: 0.9276 Epoch: 009/040 | Batch 0050/0175 | Loss: 0.8630 Epoch: 009/040 | Batch 0100/0175 | Loss: 1.1130 Epoch: 009/040 | Batch 0150/0175 | Loss: 0.9105 ***Epoch: 009/040 | Train. Acc.: 66.795% | Loss: 0.920 ***Epoch: 009/040 | Valid. Acc.: 64.980% | Loss: 0.984 Time elapsed: 9.38 min Epoch: 010/040 | Batch 0000/0175 | Loss: 0.8506 Epoch: 010/040 | Batch 0050/0175 | Loss: 0.7531 Epoch: 010/040 | Batch 0100/0175 | Loss: 0.9312 Epoch: 010/040 | Batch 0150/0175 | Loss: 0.9103 ***Epoch: 010/040 | Train. Acc.: 70.491% | Loss: 0.832 ***Epoch: 010/040 | Valid. Acc.: 67.560% | Loss: 0.934 Time elapsed: 10.42 min Epoch: 011/040 | Batch 0000/0175 | Loss: 0.8196 Epoch: 011/040 | Batch 0050/0175 | Loss: 0.7955 Epoch: 011/040 | Batch 0100/0175 | Loss: 0.9367 Epoch: 011/040 | Batch 0150/0175 | Loss: 0.7501 ***Epoch: 011/040 | Train. Acc.: 70.826% | Loss: 0.819 ***Epoch: 011/040 | Valid. Acc.: 66.220% | Loss: 0.950 Time elapsed: 11.50 min Epoch: 012/040 | Batch 0000/0175 | Loss: 0.7863 Epoch: 012/040 | Batch 0050/0175 | Loss: 0.8496 Epoch: 012/040 | Batch 0100/0175 | Loss: 0.7997 Epoch: 012/040 | Batch 0150/0175 | Loss: 0.9733 ***Epoch: 012/040 | Train. Acc.: 73.600% | Loss: 0.757 ***Epoch: 012/040 | Valid. Acc.: 68.640% | Loss: 0.901 Time elapsed: 12.72 min Epoch: 013/040 | Batch 0000/0175 | Loss: 0.8286 Epoch: 013/040 | Batch 0050/0175 | Loss: 0.8397 Epoch: 013/040 | Batch 0100/0175 | Loss: 0.7478 Epoch: 013/040 | Batch 0150/0175 | Loss: 0.8451 ***Epoch: 013/040 | Train. Acc.: 76.750% | Loss: 0.672 ***Epoch: 013/040 | Valid. Acc.: 70.260% | Loss: 0.848 Time elapsed: 13.96 min Epoch: 014/040 | Batch 0000/0175 | Loss: 0.6818 Epoch: 014/040 | Batch 0050/0175 | Loss: 0.7883 Epoch: 014/040 | Batch 0100/0175 | Loss: 0.7845 Epoch: 014/040 | Batch 0150/0175 | Loss: 0.6714 ***Epoch: 014/040 | Train. Acc.: 76.462% | Loss: 0.669 ***Epoch: 014/040 | Valid. Acc.: 69.560% | Loss: 0.876 Time elapsed: 15.17 min Epoch: 015/040 | Batch 0000/0175 | Loss: 0.7720 Epoch: 015/040 | Batch 0050/0175 | Loss: 0.7569 Epoch: 015/040 | Batch 0100/0175 | Loss: 0.6428 Epoch: 015/040 | Batch 0150/0175 | Loss: 0.7415 ***Epoch: 015/040 | Train. Acc.: 78.196% | Loss: 0.622 ***Epoch: 015/040 | Valid. Acc.: 70.460% | Loss: 0.852 Time elapsed: 16.39 min Epoch: 016/040 | Batch 0000/0175 | Loss: 0.6150 Epoch: 016/040 | Batch 0050/0175 | Loss: 0.7300 Epoch: 016/040 | Batch 0100/0175 | Loss: 0.4870 Epoch: 016/040 | Batch 0150/0175 | Loss: 0.6177 ***Epoch: 016/040 | Train. Acc.: 80.033% | Loss: 0.571 ***Epoch: 016/040 | Valid. Acc.: 71.500% | Loss: 0.832 Time elapsed: 17.62 min Epoch: 017/040 | Batch 0000/0175 | Loss: 0.6556 Epoch: 017/040 | Batch 0050/0175 | Loss: 0.6564 Epoch: 017/040 | Batch 0100/0175 | Loss: 0.5505 Epoch: 017/040 | Batch 0150/0175 | Loss: 0.6272 ***Epoch: 017/040 | Train. Acc.: 81.415% | Loss: 0.532 ***Epoch: 017/040 | Valid. Acc.: 71.980% | Loss: 0.836 Time elapsed: 18.80 min Epoch: 018/040 | Batch 0000/0175 | Loss: 0.5772 Epoch: 018/040 | Batch 0050/0175 | Loss: 0.4951 Epoch: 018/040 | Batch 0100/0175 | Loss: 0.4850 Epoch: 018/040 | Batch 0150/0175 | Loss: 0.6942 ***Epoch: 018/040 | Train. Acc.: 82.944% | Loss: 0.486 ***Epoch: 018/040 | Valid. Acc.: 71.520% | Loss: 0.839 Time elapsed: 19.84 min Epoch: 019/040 | Batch 0000/0175 | Loss: 0.4757 Epoch: 019/040 | Batch 0050/0175 | Loss: 0.4909 Epoch: 019/040 | Batch 0100/0175 | Loss: 0.5568 Epoch: 019/040 | Batch 0150/0175 | Loss: 0.5895 ***Epoch: 019/040 | Train. Acc.: 81.592% | Loss: 0.515 ***Epoch: 019/040 | Valid. Acc.: 70.840% | Loss: 0.911 Time elapsed: 20.87 min Epoch: 020/040 | Batch 0000/0175 | Loss: 0.5108 Epoch: 020/040 | Batch 0050/0175 | Loss: 0.5133 Epoch: 020/040 | Batch 0100/0175 | Loss: 0.4775 Epoch: 020/040 | Batch 0150/0175 | Loss: 0.5364 ***Epoch: 020/040 | Train. Acc.: 85.272% | Loss: 0.431 ***Epoch: 020/040 | Valid. Acc.: 72.240% | Loss: 0.850 Time elapsed: 21.89 min Epoch: 021/040 | Batch 0000/0175 | Loss: 0.4184 Epoch: 021/040 | Batch 0050/0175 | Loss: 0.5490 Epoch: 021/040 | Batch 0100/0175 | Loss: 0.4124 Epoch: 021/040 | Batch 0150/0175 | Loss: 0.3877 ***Epoch: 021/040 | Train. Acc.: 86.616% | Loss: 0.384 ***Epoch: 021/040 | Valid. Acc.: 72.900% | Loss: 0.850 Time elapsed: 22.93 min Epoch: 022/040 | Batch 0000/0175 | Loss: 0.3587 Epoch: 022/040 | Batch 0050/0175 | Loss: 0.4164 Epoch: 022/040 | Batch 0100/0175 | Loss: 0.4908 Epoch: 022/040 | Batch 0150/0175 | Loss: 0.5300 ***Epoch: 022/040 | Train. Acc.: 87.763% | Loss: 0.353 ***Epoch: 022/040 | Valid. Acc.: 73.160% | Loss: 0.867 Time elapsed: 23.97 min Epoch: 023/040 | Batch 0000/0175 | Loss: 0.3409 Epoch: 023/040 | Batch 0050/0175 | Loss: 0.3932 Epoch: 023/040 | Batch 0100/0175 | Loss: 0.4906 Epoch: 023/040 | Batch 0150/0175 | Loss: 0.3842 ***Epoch: 023/040 | Train. Acc.: 88.516% | Loss: 0.325 ***Epoch: 023/040 | Valid. Acc.: 72.940% | Loss: 0.867 Time elapsed: 25.01 min Epoch: 024/040 | Batch 0000/0175 | Loss: 0.3903 Epoch: 024/040 | Batch 0050/0175 | Loss: 0.4127 Epoch: 024/040 | Batch 0100/0175 | Loss: 0.3478 Epoch: 024/040 | Batch 0150/0175 | Loss: 0.4306 ***Epoch: 024/040 | Train. Acc.: 90.315% | Loss: 0.284 ***Epoch: 024/040 | Valid. Acc.: 73.220% | Loss: 0.911 Time elapsed: 26.04 min Epoch: 025/040 | Batch 0000/0175 | Loss: 0.2716 Epoch: 025/040 | Batch 0050/0175 | Loss: 0.3371 Epoch: 025/040 | Batch 0100/0175 | Loss: 0.4309 Epoch: 025/040 | Batch 0150/0175 | Loss: 0.4343 ***Epoch: 025/040 | Train. Acc.: 88.908% | Loss: 0.311 ***Epoch: 025/040 | Valid. Acc.: 73.000% | Loss: 0.909 Time elapsed: 27.07 min Epoch: 026/040 | Batch 0000/0175 | Loss: 0.2467 Epoch: 026/040 | Batch 0050/0175 | Loss: 0.2832 Epoch: 026/040 | Batch 0100/0175 | Loss: 0.3431 Epoch: 026/040 | Batch 0150/0175 | Loss: 0.3218 ***Epoch: 026/040 | Train. Acc.: 90.547% | Loss: 0.272 ***Epoch: 026/040 | Valid. Acc.: 72.900% | Loss: 0.925 Time elapsed: 28.10 min Epoch: 027/040 | Batch 0000/0175 | Loss: 0.3064 Epoch: 027/040 | Batch 0050/0175 | Loss: 0.2874 Epoch: 027/040 | Batch 0100/0175 | Loss: 0.3545 Epoch: 027/040 | Batch 0150/0175 | Loss: 0.3866 ***Epoch: 027/040 | Train. Acc.: 92.277% | Loss: 0.230 ***Epoch: 027/040 | Valid. Acc.: 73.760% | Loss: 0.935 Time elapsed: 29.13 min Epoch: 028/040 | Batch 0000/0175 | Loss: 0.1964 Epoch: 028/040 | Batch 0050/0175 | Loss: 0.2317 Epoch: 028/040 | Batch 0100/0175 | Loss: 0.2595 Epoch: 028/040 | Batch 0150/0175 | Loss: 0.3056 ***Epoch: 028/040 | Train. Acc.: 92.049% | Loss: 0.225 ***Epoch: 028/040 | Valid. Acc.: 73.340% | Loss: 0.994 Time elapsed: 30.16 min Epoch: 029/040 | Batch 0000/0175 | Loss: 0.2118 Epoch: 029/040 | Batch 0050/0175 | Loss: 0.2198 Epoch: 029/040 | Batch 0100/0175 | Loss: 0.2389 Epoch: 029/040 | Batch 0150/0175 | Loss: 0.3052 ***Epoch: 029/040 | Train. Acc.: 93.170% | Loss: 0.198 ***Epoch: 029/040 | Valid. Acc.: 73.520% | Loss: 1.004 Time elapsed: 31.20 min Epoch: 030/040 | Batch 0000/0175 | Loss: 0.1664 Epoch: 030/040 | Batch 0050/0175 | Loss: 0.1880 Epoch: 030/040 | Batch 0100/0175 | Loss: 0.1938 Epoch: 030/040 | Batch 0150/0175 | Loss: 0.2032 ***Epoch: 030/040 | Train. Acc.: 93.333% | Loss: 0.188 ***Epoch: 030/040 | Valid. Acc.: 72.820% | Loss: 1.061 Time elapsed: 32.23 min Epoch: 031/040 | Batch 0000/0175 | Loss: 0.2679 Epoch: 031/040 | Batch 0050/0175 | Loss: 0.2778 Epoch: 031/040 | Batch 0100/0175 | Loss: 0.2026 Epoch: 031/040 | Batch 0150/0175 | Loss: 0.2144 ***Epoch: 031/040 | Train. Acc.: 94.058% | Loss: 0.170 ***Epoch: 031/040 | Valid. Acc.: 73.500% | Loss: 1.044 Time elapsed: 33.27 min Epoch: 032/040 | Batch 0000/0175 | Loss: 0.1634 Epoch: 032/040 | Batch 0050/0175 | Loss: 0.2475 Epoch: 032/040 | Batch 0100/0175 | Loss: 0.1528 Epoch: 032/040 | Batch 0150/0175 | Loss: 0.2810 ***Epoch: 032/040 | Train. Acc.: 94.471% | Loss: 0.161 ***Epoch: 032/040 | Valid. Acc.: 73.000% | Loss: 1.065 Time elapsed: 34.32 min Epoch: 033/040 | Batch 0000/0175 | Loss: 0.2095 Epoch: 033/040 | Batch 0050/0175 | Loss: 0.1590 Epoch: 033/040 | Batch 0100/0175 | Loss: 0.1752 Epoch: 033/040 | Batch 0150/0175 | Loss: 0.2319 ***Epoch: 033/040 | Train. Acc.: 95.118% | Loss: 0.141 ***Epoch: 033/040 | Valid. Acc.: 73.440% | Loss: 1.075 Time elapsed: 35.35 min Epoch: 034/040 | Batch 0000/0175 | Loss: 0.1156 Epoch: 034/040 | Batch 0050/0175 | Loss: 0.1456 Epoch: 034/040 | Batch 0100/0175 | Loss: 0.1519 Epoch: 034/040 | Batch 0150/0175 | Loss: 0.1831 ***Epoch: 034/040 | Train. Acc.: 95.286% | Loss: 0.142 ***Epoch: 034/040 | Valid. Acc.: 73.820% | Loss: 1.064 Time elapsed: 36.40 min Epoch: 035/040 | Batch 0000/0175 | Loss: 0.1532 Epoch: 035/040 | Batch 0050/0175 | Loss: 0.1267 Epoch: 035/040 | Batch 0100/0175 | Loss: 0.1633 Epoch: 035/040 | Batch 0150/0175 | Loss: 0.1263 ***Epoch: 035/040 | Train. Acc.: 94.638% | Loss: 0.154 ***Epoch: 035/040 | Valid. Acc.: 73.440% | Loss: 1.093 Time elapsed: 37.43 min Epoch: 036/040 | Batch 0000/0175 | Loss: 0.1787 Epoch: 036/040 | Batch 0050/0175 | Loss: 0.1622 Epoch: 036/040 | Batch 0100/0175 | Loss: 0.1840 Epoch: 036/040 | Batch 0150/0175 | Loss: 0.1143 ***Epoch: 036/040 | Train. Acc.: 95.480% | Loss: 0.132 ***Epoch: 036/040 | Valid. Acc.: 73.020% | Loss: 1.159 Time elapsed: 38.46 min Epoch: 037/040 | Batch 0000/0175 | Loss: 0.1282 Epoch: 037/040 | Batch 0050/0175 | Loss: 0.1299 Epoch: 037/040 | Batch 0100/0175 | Loss: 0.1869 Epoch: 037/040 | Batch 0150/0175 | Loss: 0.1387 ***Epoch: 037/040 | Train. Acc.: 95.129% | Loss: 0.138 ***Epoch: 037/040 | Valid. Acc.: 72.640% | Loss: 1.174 Time elapsed: 39.50 min Epoch: 038/040 | Batch 0000/0175 | Loss: 0.1137 Epoch: 038/040 | Batch 0050/0175 | Loss: 0.1053 Epoch: 038/040 | Batch 0100/0175 | Loss: 0.1298 Epoch: 038/040 | Batch 0150/0175 | Loss: 0.1280 ***Epoch: 038/040 | Train. Acc.: 95.429% | Loss: 0.134 ***Epoch: 038/040 | Valid. Acc.: 73.040% | Loss: 1.230 Time elapsed: 40.53 min Epoch: 039/040 | Batch 0000/0175 | Loss: 0.1410 Epoch: 039/040 | Batch 0050/0175 | Loss: 0.1084 Epoch: 039/040 | Batch 0100/0175 | Loss: 0.1578 Epoch: 039/040 | Batch 0150/0175 | Loss: 0.1516 ***Epoch: 039/040 | Train. Acc.: 97.002% | Loss: 0.090 ***Epoch: 039/040 | Valid. Acc.: 73.420% | Loss: 1.159 Time elapsed: 41.57 min Epoch: 040/040 | Batch 0000/0175 | Loss: 0.1143 Epoch: 040/040 | Batch 0050/0175 | Loss: 0.1153 Epoch: 040/040 | Batch 0100/0175 | Loss: 0.1493 Epoch: 040/040 | Batch 0150/0175 | Loss: 0.2771 ***Epoch: 040/040 | Train. Acc.: 96.071% | Loss: 0.111 ***Epoch: 040/040 | Valid. Acc.: 73.520% | Loss: 1.218 Time elapsed: 42.61 min Total Training Time: 42.61 min
import matplotlib.pyplot as plt
%matplotlib inline
loss_list = log_dict['train_loss_per_batch']
plt.plot(loss_list, label='Minibatch loss')
plt.plot(np.convolve(loss_list,
np.ones(200,)/200, mode='valid'),
label='Running average')
plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()
plt.plot(np.arange(1, NUM_EPOCHS+1), log_dict['train_acc_per_epoch'], label='Training')
plt.plot(np.arange(1, NUM_EPOCHS+1), log_dict['valid_acc_per_epoch'], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
with torch.set_grad_enabled(False):
train_acc = compute_accuracy(model=model,
data_loader=test_loader,
device=DEVICE)
test_acc = compute_accuracy(model=model,
data_loader=test_loader,
device=DEVICE)
valid_acc = compute_accuracy(model=model,
data_loader=valid_loader,
device=DEVICE)
print(f'Train ACC: {valid_acc:.2f}%')
print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Train ACC: 73.52% Validation ACC: 73.52% Test ACC: 72.11%
%watermark -iv
sys : 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51) [GCC 9.4.0] matplotlib : 3.3.4 PIL : 9.0.1 torchvision: 0.11.2 numpy : 1.22.0 torch : 1.10.1 pandas : 1.4.1