Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.7.3 IPython 7.6.1 torch 1.2.0
Implementing a very basic graph neural network (GNN) using a subnetwork for edge prediction.
Here, the 28x28 image of a digit in MNIST represents the graph, where each pixel (i.e., cell in the grid) represents a particular node. The feature of that node is simply the pixel intensity in range [0, 1].
In the related notebook, [gnn-basic-1.ipynb], the adjacency matrix of the pixels was determined simply by the neighboring pixels: using a Gaussian filter, pixels were connected based on their Euclidean distance in the grid. In this notebook, the edges are instead predicted by a separate neural network model (the snippet below is illustrative; the model defined further down uses 32 rather than 64 hidden units):
self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Tanh())
Using the resulting adjacency matrix A, we can compute the output of a layer as
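$$X^{(l+1)} = A X^{(l)} W^{(l)}.$$ Here, $A$ is the $N \times N$ adjacency matrix, and $X$ is the $N \times C$ feature matrix, where $N$ is the total number of pixels ($28 \times 28 = 784$ in MNIST) and $C = 1$, because each node's only feature is its pixel intensity. (The 2D pixel coordinates are not part of $X$; they are only the input to the edge-prediction network.) $W$ is the weight matrix of shape $N \times P$, where $P$ would represent the number of classes if we have only a single hidden layer. Because $C = 1$, the code flattens $A X^{(l)}$ into a vector of length $N$ before multiplying by $W$.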
import time
import numpy as np
from scipy.spatial.distance import cdist
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
# On a GPU machine, ask cuDNN for deterministic algorithms so repeated runs
# give the same results (the flag is left untouched on CPU-only machines).
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
%matplotlib inline
import matplotlib.pyplot as plt
##########################
### SETTINGS
##########################
# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0005
NUM_EPOCHS = 50
BATCH_SIZE = 128
IMG_SIZE = 28
# Architecture
NUM_CLASSES = 10
# The last 1,000 images of the official training split are held out for
# validation; the remaining 59,000 are used for training.
train_indices = torch.arange(0, 59000)
valid_indices = torch.arange(59000, 60000)

custom_transform = transforms.Compose([transforms.ToTensor()])

train_and_valid = datasets.MNIST(root='data', train=True,
                                 transform=custom_transform, download=True)
test_dataset = datasets.MNIST(root='data', train=False,
                              transform=custom_transform, download=True)

train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(train_and_valid, valid_indices)

# Shared loader options; only the training loader reshuffles each epoch.
_loader_opts = {'batch_size': BATCH_SIZE, 'num_workers': 4}
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_opts)
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, **_loader_opts)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **_loader_opts)

# Sanity check: print the shapes of a single training minibatch.
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Image batch dimensions: torch.Size([128, 1, 28, 28]) Image label dimensions: torch.Size([128])
##########################
### MODEL
##########################
def make_coordinate_array(img_size, out_size=4):
    """Build pairwise coordinate features for every (node, node) pair.

    Nodes are the ``img_size * img_size`` pixels of a square image in
    row-major order (node k sits at row ``k // img_size``, column
    ``k % img_size``). Each pixel's (column, row) coordinate is standardized
    per axis (zero mean, unit std). Entry ``[i, j]`` of the result is node
    j's coordinate repeated ``out_size // 2 - 1`` times, followed by node
    i's coordinate.

    Args:
        img_size: Height and width of the square image.
        out_size: Length of each pair's feature vector; must be an even
            integer >= 2.

    Returns:
        float32 tensor of shape [N, N, out_size] with N = img_size ** 2.

    Raises:
        ValueError: If ``out_size`` is odd or smaller than 2, since the
            concatenation below can only produce even widths >= 2.
    """
    if out_size < 2 or out_size % 2:
        raise ValueError(
            f'out_size must be an even integer >= 2, got {out_size!r}')

    n_rows = img_size * img_size  # number of nodes N

    ### Make 2D coordinate array (for MNIST: 784x2), (column, row) per node
    col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
    coord = np.stack((col, row), axis=2).reshape(-1, 2)
    # Standardize each axis; the epsilon avoids 0/0 when img_size == 1
    coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)
    coord = torch.from_numpy(coord).float()

    ### Reshape to [N, N, out_size]: entry [i, j] = (coord[j] repeated, coord[i])
    n_repeats = out_size // 2 - 1
    coord = torch.cat((coord.unsqueeze(0).repeat(n_rows, 1, n_repeats),
                       coord.unsqueeze(1).repeat(1, n_rows, 1)), dim=2)
    return coord
class GraphNet(nn.Module):
    """Single-layer graph network on the pixel grid with learned edges.

    Every pixel is a node whose only feature is its intensity. A small MLP
    (``pred_edge_fc``) maps the fixed per-pair coordinate features from
    ``make_coordinate_array`` to an edge weight in (-1, 1) (Tanh output).
    The result is an N x N adjacency matrix ``A`` with N = img_size ** 2.
    The aggregated node features ``A @ x`` are then classified by a
    bias-free linear layer.

    Args:
        img_size: Height/width of the square input image.
        coord_features: Length of each pair's coordinate feature vector
            (passed to ``make_coordinate_array`` and the edge MLP).
        num_classes: Number of output classes.
    """

    def __init__(self, img_size=28, coord_features=4, num_classes=10):
        super(GraphNet, self).__init__()

        n_rows = img_size**2  # number of nodes N
        self.fc = nn.Linear(n_rows, num_classes, bias=False)

        # Fixed [N, N, coord_features] pair features. Registered as a buffer
        # so it moves with .to(device) but is not a trainable parameter.
        coord = make_coordinate_array(img_size, coord_features)
        self.register_buffer('coord', coord)

        ##########
        # Edge Predictor: pair features -> hidden -> scalar edge weight
        self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32),
                                          nn.ReLU(),
                                          nn.Linear(32, 1),
                                          nn.Tanh())

    def forward(self, x):
        """Return ``(logits, probas)`` for a batch of images.

        Args:
            x: Image batch. It is reshaped to [B, N, 1] below, so each
               example must contain N = img_size ** 2 values
               (e.g. [B, 1, H, W]).

        Side effect: the predicted [N, N] adjacency matrix is stored in
        ``self.A`` (used, e.g., to visualize it after training).
        """
        B = x.size(0)

        ### Predict edges: [N, N, coord_features] -> [N, N, 1] -> [N, N].
        # squeeze(-1) rather than squeeze() keeps [N, N] even when N == 1.
        self.A = self.pred_edge_fc(self.coord).squeeze(-1)

        ### Broadcast the adjacency matrix over the batch (a view, no copy):
        # [N, N] -> [1, N, N] -> [B, N, N]
        A_tensor = self.A.unsqueeze(0).expand(B, -1, -1)

        ### Reshape inputs: [B, C, H, W] -> [B, H*W, 1]
        x_reshape = x.view(B, -1, 1)

        # bmm = batch matrix product: edge-weighted sum of neighbor features
        # [B, N, N] x [B, N, 1] -> [B, N, 1] -> [B, N]
        avg_neighbor_features = torch.bmm(A_tensor, x_reshape).view(B, -1)

        logits = self.fc(avg_neighbor_features)
        probas = F.softmax(logits, dim=1)
        return logits, probas
# Seed first so the model's initial weights are reproducible, then build the
# model directly on the target device and attach the Adam optimizer.
torch.manual_seed(RANDOM_SEED)
model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
def compute_acc(model, data_loader, device):
    """Return the classification accuracy of ``model`` in percent.

    Args:
        model: Callable returning ``(logits, probas)`` for a feature batch
            (as ``GraphNet.forward`` does); the predicted class is the
            argmax of ``probas``.
        data_loader: Iterable of ``(features, targets)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        0-d float tensor equal to 100 * correct / total.

    Raises:
        ValueError: If ``data_loader`` yields no examples.
    """
    correct_pred, num_examples = 0, 0
    # Evaluation never needs gradients; disabling autograd here saves memory
    # even if the caller did not already do so.
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            logits, probas = model(features)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    if num_examples == 0:
        # Without this guard an empty loader fails with an obscure
        # AttributeError on ``int.float()``.
        raise ValueError('data_loader yielded no examples')
    return correct_pred.float()/num_examples * 100
# Train for NUM_EPOCHS epochs. The minibatch loss is printed every 150 steps,
# and train/validation accuracy plus elapsed time after every epoch.
start_time = time.time()
cost_list = []
train_acc_list = []
valid_acc_list = []
num_batches = len(train_loader)

for epoch in range(NUM_EPOCHS):

    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        features, targets = features.to(DEVICE), targets.to(DEVICE)

        # Forward pass, loss, backward pass
        logits, _ = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()

        # Parameter update
        optimizer.step()

        # --- logging only below this point ---
        batch_cost = cost.item()
        cost_list.append(batch_cost)
        if batch_idx % 150 == 0:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{num_batches:03d} |'
                  f' Cost: {batch_cost:.4f}')

    model.eval()
    # Inference only: skip building the autograd graph to save memory
    with torch.set_grad_enabled(False):
        train_acc = compute_acc(model, train_loader, device=DEVICE)
        valid_acc = compute_acc(model, valid_loader, device=DEVICE)
    print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\n'
          f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')
    train_acc_list.append(train_acc)
    valid_acc_list.append(valid_acc)

    elapsed = (time.time() - start_time) / 60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time) / 60
print(f'Total Training Time: {elapsed:.2f} min')
Epoch: 001/050 | Batch 000/461 | Cost: 24.2727 Epoch: 001/050 | Batch 150/461 | Cost: 2.2706 Epoch: 001/050 | Batch 300/461 | Cost: 1.8713 Epoch: 001/050 | Batch 450/461 | Cost: 1.5048 Epoch: 001/050 Train ACC: 50.39 | Validation ACC: 54.80 Time elapsed: 0.25 min Epoch: 002/050 | Batch 000/461 | Cost: 1.4445 Epoch: 002/050 | Batch 150/461 | Cost: 1.3288 Epoch: 002/050 | Batch 300/461 | Cost: 1.1868 Epoch: 002/050 | Batch 450/461 | Cost: 1.2040 Epoch: 002/050 Train ACC: 67.68 | Validation ACC: 71.40 Time elapsed: 0.49 min Epoch: 003/050 | Batch 000/461 | Cost: 1.2128 Epoch: 003/050 | Batch 150/461 | Cost: 0.9953 Epoch: 003/050 | Batch 300/461 | Cost: 0.9818 Epoch: 003/050 | Batch 450/461 | Cost: 1.0487 Epoch: 003/050 Train ACC: 68.09 | Validation ACC: 73.40 Time elapsed: 0.74 min Epoch: 004/050 | Batch 000/461 | Cost: 1.0444 Epoch: 004/050 | Batch 150/461 | Cost: 0.9064 Epoch: 004/050 | Batch 300/461 | Cost: 0.9152 Epoch: 004/050 | Batch 450/461 | Cost: 0.7396 Epoch: 004/050 Train ACC: 76.10 | Validation ACC: 80.20 Time elapsed: 0.98 min Epoch: 005/050 | Batch 000/461 | Cost: 0.7698 Epoch: 005/050 | Batch 150/461 | Cost: 0.8356 Epoch: 005/050 | Batch 300/461 | Cost: 0.6544 Epoch: 005/050 | Batch 450/461 | Cost: 0.8700 Epoch: 005/050 Train ACC: 80.30 | Validation ACC: 84.10 Time elapsed: 1.22 min Epoch: 006/050 | Batch 000/461 | Cost: 0.6292 Epoch: 006/050 | Batch 150/461 | Cost: 0.7779 Epoch: 006/050 | Batch 300/461 | Cost: 0.5978 Epoch: 006/050 | Batch 450/461 | Cost: 0.6260 Epoch: 006/050 Train ACC: 82.00 | Validation ACC: 85.60 Time elapsed: 1.46 min Epoch: 007/050 | Batch 000/461 | Cost: 0.7172 Epoch: 007/050 | Batch 150/461 | Cost: 0.6444 Epoch: 007/050 | Batch 300/461 | Cost: 0.5620 Epoch: 007/050 | Batch 450/461 | Cost: 0.5314 Epoch: 007/050 Train ACC: 82.90 | Validation ACC: 86.20 Time elapsed: 1.71 min Epoch: 008/050 | Batch 000/461 | Cost: 0.6211 Epoch: 008/050 | Batch 150/461 | Cost: 0.5004 Epoch: 008/050 | Batch 300/461 | Cost: 0.5274 Epoch: 008/050 | 
Batch 450/461 | Cost: 0.6611 Epoch: 008/050 Train ACC: 84.86 | Validation ACC: 87.90 Time elapsed: 1.95 min Epoch: 009/050 | Batch 000/461 | Cost: 0.4017 Epoch: 009/050 | Batch 150/461 | Cost: 0.7080 Epoch: 009/050 | Batch 300/461 | Cost: 0.4298 Epoch: 009/050 | Batch 450/461 | Cost: 0.4516 Epoch: 009/050 Train ACC: 85.89 | Validation ACC: 90.20 Time elapsed: 2.19 min Epoch: 010/050 | Batch 000/461 | Cost: 0.4571 Epoch: 010/050 | Batch 150/461 | Cost: 0.4976 Epoch: 010/050 | Batch 300/461 | Cost: 0.6208 Epoch: 010/050 | Batch 450/461 | Cost: 0.3780 Epoch: 010/050 Train ACC: 85.84 | Validation ACC: 89.40 Time elapsed: 2.43 min Epoch: 011/050 | Batch 000/461 | Cost: 0.5262 Epoch: 011/050 | Batch 150/461 | Cost: 0.4255 Epoch: 011/050 | Batch 300/461 | Cost: 0.3840 Epoch: 011/050 | Batch 450/461 | Cost: 0.4941 Epoch: 011/050 Train ACC: 85.73 | Validation ACC: 90.40 Time elapsed: 2.68 min Epoch: 012/050 | Batch 000/461 | Cost: 0.3425 Epoch: 012/050 | Batch 150/461 | Cost: 0.5059 Epoch: 012/050 | Batch 300/461 | Cost: 0.6590 Epoch: 012/050 | Batch 450/461 | Cost: 0.5481 Epoch: 012/050 Train ACC: 87.50 | Validation ACC: 91.30 Time elapsed: 2.92 min Epoch: 013/050 | Batch 000/461 | Cost: 0.6081 Epoch: 013/050 | Batch 150/461 | Cost: 0.4584 Epoch: 013/050 | Batch 300/461 | Cost: 0.2856 Epoch: 013/050 | Batch 450/461 | Cost: 0.4324 Epoch: 013/050 Train ACC: 87.35 | Validation ACC: 91.20 Time elapsed: 3.17 min Epoch: 014/050 | Batch 000/461 | Cost: 0.4685 Epoch: 014/050 | Batch 150/461 | Cost: 0.4492 Epoch: 014/050 | Batch 300/461 | Cost: 0.3913 Epoch: 014/050 | Batch 450/461 | Cost: 0.5154 Epoch: 014/050 Train ACC: 86.71 | Validation ACC: 91.20 Time elapsed: 3.41 min Epoch: 015/050 | Batch 000/461 | Cost: 0.4526 Epoch: 015/050 | Batch 150/461 | Cost: 0.4834 Epoch: 015/050 | Batch 300/461 | Cost: 0.5208 Epoch: 015/050 | Batch 450/461 | Cost: 0.3536 Epoch: 015/050 Train ACC: 85.21 | Validation ACC: 89.50 Time elapsed: 3.66 min Epoch: 016/050 | Batch 000/461 | Cost: 0.6614 
Epoch: 016/050 | Batch 150/461 | Cost: 0.3036 Epoch: 016/050 | Batch 300/461 | Cost: 0.3766 Epoch: 016/050 | Batch 450/461 | Cost: 0.4550 Epoch: 016/050 Train ACC: 86.97 | Validation ACC: 92.10 Time elapsed: 3.91 min Epoch: 017/050 | Batch 000/461 | Cost: 0.6241 Epoch: 017/050 | Batch 150/461 | Cost: 0.3934 Epoch: 017/050 | Batch 300/461 | Cost: 0.4330 Epoch: 017/050 | Batch 450/461 | Cost: 0.5914 Epoch: 017/050 Train ACC: 88.12 | Validation ACC: 91.60 Time elapsed: 4.15 min Epoch: 018/050 | Batch 000/461 | Cost: 0.3769 Epoch: 018/050 | Batch 150/461 | Cost: 0.4817 Epoch: 018/050 | Batch 300/461 | Cost: 0.4103 Epoch: 018/050 | Batch 450/461 | Cost: 0.3727 Epoch: 018/050 Train ACC: 86.58 | Validation ACC: 90.90 Time elapsed: 4.40 min Epoch: 019/050 | Batch 000/461 | Cost: 0.4098 Epoch: 019/050 | Batch 150/461 | Cost: 0.4435 Epoch: 019/050 | Batch 300/461 | Cost: 0.2952 Epoch: 019/050 | Batch 450/461 | Cost: 0.3328 Epoch: 019/050 Train ACC: 88.65 | Validation ACC: 92.00 Time elapsed: 4.64 min Epoch: 020/050 | Batch 000/461 | Cost: 0.5363 Epoch: 020/050 | Batch 150/461 | Cost: 0.3143 Epoch: 020/050 | Batch 300/461 | Cost: 0.5186 Epoch: 020/050 | Batch 450/461 | Cost: 0.3806 Epoch: 020/050 Train ACC: 88.95 | Validation ACC: 92.70 Time elapsed: 4.89 min Epoch: 021/050 | Batch 000/461 | Cost: 0.3810 Epoch: 021/050 | Batch 150/461 | Cost: 0.2470 Epoch: 021/050 | Batch 300/461 | Cost: 0.6154 Epoch: 021/050 | Batch 450/461 | Cost: 0.3651 Epoch: 021/050 Train ACC: 88.31 | Validation ACC: 92.40 Time elapsed: 5.13 min Epoch: 022/050 | Batch 000/461 | Cost: 0.3704 Epoch: 022/050 | Batch 150/461 | Cost: 0.4338 Epoch: 022/050 | Batch 300/461 | Cost: 0.4197 Epoch: 022/050 | Batch 450/461 | Cost: 0.3304 Epoch: 022/050 Train ACC: 88.62 | Validation ACC: 91.90 Time elapsed: 5.31 min Epoch: 023/050 | Batch 000/461 | Cost: 0.2825 Epoch: 023/050 | Batch 150/461 | Cost: 0.4302 Epoch: 023/050 | Batch 300/461 | Cost: 0.4738 Epoch: 023/050 | Batch 450/461 | Cost: 0.4362 Epoch: 023/050 Train 
ACC: 89.02 | Validation ACC: 92.80 Time elapsed: 5.44 min Epoch: 024/050 | Batch 000/461 | Cost: 0.2097 Epoch: 024/050 | Batch 150/461 | Cost: 0.4440 Epoch: 024/050 | Batch 300/461 | Cost: 0.4467 Epoch: 024/050 | Batch 450/461 | Cost: 0.2744 Epoch: 024/050 Train ACC: 88.82 | Validation ACC: 92.40 Time elapsed: 5.57 min Epoch: 025/050 | Batch 000/461 | Cost: 0.2734 Epoch: 025/050 | Batch 150/461 | Cost: 0.3980 Epoch: 025/050 | Batch 300/461 | Cost: 0.4395 Epoch: 025/050 | Batch 450/461 | Cost: 0.2336 Epoch: 025/050 Train ACC: 89.59 | Validation ACC: 93.90 Time elapsed: 5.70 min Epoch: 026/050 | Batch 000/461 | Cost: 0.3138 Epoch: 026/050 | Batch 150/461 | Cost: 0.3772 Epoch: 026/050 | Batch 300/461 | Cost: 0.2955 Epoch: 026/050 | Batch 450/461 | Cost: 0.3747 Epoch: 026/050 Train ACC: 88.71 | Validation ACC: 92.70 Time elapsed: 5.82 min Epoch: 027/050 | Batch 000/461 | Cost: 0.4107 Epoch: 027/050 | Batch 150/461 | Cost: 0.4375 Epoch: 027/050 | Batch 300/461 | Cost: 0.3802 Epoch: 027/050 | Batch 450/461 | Cost: 0.3240 Epoch: 027/050 Train ACC: 87.90 | Validation ACC: 91.60 Time elapsed: 5.95 min Epoch: 028/050 | Batch 000/461 | Cost: 0.5124 Epoch: 028/050 | Batch 150/461 | Cost: 0.4980 Epoch: 028/050 | Batch 300/461 | Cost: 0.3937 Epoch: 028/050 | Batch 450/461 | Cost: 0.2704 Epoch: 028/050 Train ACC: 89.08 | Validation ACC: 92.30 Time elapsed: 6.08 min Epoch: 029/050 | Batch 000/461 | Cost: 0.3328 Epoch: 029/050 | Batch 150/461 | Cost: 0.3022 Epoch: 029/050 | Batch 300/461 | Cost: 0.3222 Epoch: 029/050 | Batch 450/461 | Cost: 0.3084 Epoch: 029/050 Train ACC: 89.30 | Validation ACC: 93.90 Time elapsed: 6.21 min Epoch: 030/050 | Batch 000/461 | Cost: 0.4667 Epoch: 030/050 | Batch 150/461 | Cost: 0.3290 Epoch: 030/050 | Batch 300/461 | Cost: 0.3261 Epoch: 030/050 | Batch 450/461 | Cost: 0.3347 Epoch: 030/050 Train ACC: 89.17 | Validation ACC: 93.60 Time elapsed: 6.33 min Epoch: 031/050 | Batch 000/461 | Cost: 0.3486 Epoch: 031/050 | Batch 150/461 | Cost: 0.2426 Epoch: 
031/050 | Batch 300/461 | Cost: 0.2748 Epoch: 031/050 | Batch 450/461 | Cost: 0.2072 Epoch: 031/050 Train ACC: 89.17 | Validation ACC: 93.20 Time elapsed: 6.46 min Epoch: 032/050 | Batch 000/461 | Cost: 0.3423 Epoch: 032/050 | Batch 150/461 | Cost: 0.4924 Epoch: 032/050 | Batch 300/461 | Cost: 0.4072 Epoch: 032/050 | Batch 450/461 | Cost: 0.3611 Epoch: 032/050 Train ACC: 89.83 | Validation ACC: 94.30 Time elapsed: 6.59 min Epoch: 033/050 | Batch 000/461 | Cost: 0.2461 Epoch: 033/050 | Batch 150/461 | Cost: 0.2343 Epoch: 033/050 | Batch 300/461 | Cost: 0.2891 Epoch: 033/050 | Batch 450/461 | Cost: 0.3772 Epoch: 033/050 Train ACC: 88.81 | Validation ACC: 92.40 Time elapsed: 6.72 min Epoch: 034/050 | Batch 000/461 | Cost: 0.3052 Epoch: 034/050 | Batch 150/461 | Cost: 0.5129 Epoch: 034/050 | Batch 300/461 | Cost: 0.3810 Epoch: 034/050 | Batch 450/461 | Cost: 0.2906 Epoch: 034/050 Train ACC: 89.34 | Validation ACC: 93.10 Time elapsed: 6.85 min Epoch: 035/050 | Batch 000/461 | Cost: 0.3604 Epoch: 035/050 | Batch 150/461 | Cost: 0.3832 Epoch: 035/050 | Batch 300/461 | Cost: 0.3632 Epoch: 035/050 | Batch 450/461 | Cost: 0.3345 Epoch: 035/050 Train ACC: 89.74 | Validation ACC: 93.10 Time elapsed: 6.98 min Epoch: 036/050 | Batch 000/461 | Cost: 0.3382 Epoch: 036/050 | Batch 150/461 | Cost: 0.3754 Epoch: 036/050 | Batch 300/461 | Cost: 0.4120 Epoch: 036/050 | Batch 450/461 | Cost: 0.4710 Epoch: 036/050 Train ACC: 89.10 | Validation ACC: 93.90 Time elapsed: 7.10 min Epoch: 037/050 | Batch 000/461 | Cost: 0.4466 Epoch: 037/050 | Batch 150/461 | Cost: 0.3427 Epoch: 037/050 | Batch 300/461 | Cost: 0.3301 Epoch: 037/050 | Batch 450/461 | Cost: 0.4110 Epoch: 037/050 Train ACC: 89.95 | Validation ACC: 93.90 Time elapsed: 7.23 min Epoch: 038/050 | Batch 000/461 | Cost: 0.2470 Epoch: 038/050 | Batch 150/461 | Cost: 0.4719 Epoch: 038/050 | Batch 300/461 | Cost: 0.3253 Epoch: 038/050 | Batch 450/461 | Cost: 0.4324 Epoch: 038/050 Train ACC: 89.35 | Validation ACC: 93.50 Time elapsed: 
7.36 min Epoch: 039/050 | Batch 000/461 | Cost: 0.3058 Epoch: 039/050 | Batch 150/461 | Cost: 0.4755 Epoch: 039/050 | Batch 300/461 | Cost: 0.2981 Epoch: 039/050 | Batch 450/461 | Cost: 0.4293 Epoch: 039/050 Train ACC: 89.51 | Validation ACC: 92.90 Time elapsed: 7.48 min Epoch: 040/050 | Batch 000/461 | Cost: 0.3378 Epoch: 040/050 | Batch 150/461 | Cost: 0.5137 Epoch: 040/050 | Batch 300/461 | Cost: 0.2680 Epoch: 040/050 | Batch 450/461 | Cost: 0.3397 Epoch: 040/050 Train ACC: 90.01 | Validation ACC: 93.70 Time elapsed: 7.61 min Epoch: 041/050 | Batch 000/461 | Cost: 0.2766 Epoch: 041/050 | Batch 150/461 | Cost: 0.2959 Epoch: 041/050 | Batch 300/461 | Cost: 0.1930 Epoch: 041/050 | Batch 450/461 | Cost: 0.3735 Epoch: 041/050 Train ACC: 89.45 | Validation ACC: 93.60 Time elapsed: 7.74 min Epoch: 042/050 | Batch 000/461 | Cost: 0.2694 Epoch: 042/050 | Batch 150/461 | Cost: 0.3575 Epoch: 042/050 | Batch 300/461 | Cost: 0.4267 Epoch: 042/050 | Batch 450/461 | Cost: 0.3332 Epoch: 042/050 Train ACC: 89.96 | Validation ACC: 93.30 Time elapsed: 7.86 min Epoch: 043/050 | Batch 000/461 | Cost: 0.2288 Epoch: 043/050 | Batch 150/461 | Cost: 0.4260 Epoch: 043/050 | Batch 300/461 | Cost: 0.2835 Epoch: 043/050 | Batch 450/461 | Cost: 0.2882 Epoch: 043/050 Train ACC: 89.91 | Validation ACC: 93.40 Time elapsed: 7.99 min Epoch: 044/050 | Batch 000/461 | Cost: 0.3211 Epoch: 044/050 | Batch 150/461 | Cost: 0.3061 Epoch: 044/050 | Batch 300/461 | Cost: 0.3137 Epoch: 044/050 | Batch 450/461 | Cost: 0.2978 Epoch: 044/050 Train ACC: 89.63 | Validation ACC: 94.30 Time elapsed: 8.12 min Epoch: 045/050 | Batch 000/461 | Cost: 0.2325 Epoch: 045/050 | Batch 150/461 | Cost: 0.3013 Epoch: 045/050 | Batch 300/461 | Cost: 0.3732 Epoch: 045/050 | Batch 450/461 | Cost: 0.3229 Epoch: 045/050 Train ACC: 90.00 | Validation ACC: 93.80 Time elapsed: 8.25 min Epoch: 046/050 | Batch 000/461 | Cost: 0.2521 Epoch: 046/050 | Batch 150/461 | Cost: 0.4440 Epoch: 046/050 | Batch 300/461 | Cost: 0.3420 Epoch: 
046/050 | Batch 450/461 | Cost: 0.4288 Epoch: 046/050 Train ACC: 89.97 | Validation ACC: 93.40 Time elapsed: 8.38 min Epoch: 047/050 | Batch 000/461 | Cost: 0.4605 Epoch: 047/050 | Batch 150/461 | Cost: 0.3261 Epoch: 047/050 | Batch 300/461 | Cost: 0.4493 Epoch: 047/050 | Batch 450/461 | Cost: 0.4902 Epoch: 047/050 Train ACC: 89.60 | Validation ACC: 93.70 Time elapsed: 8.50 min Epoch: 048/050 | Batch 000/461 | Cost: 0.4136 Epoch: 048/050 | Batch 150/461 | Cost: 0.2952 Epoch: 048/050 | Batch 300/461 | Cost: 0.4784 Epoch: 048/050 | Batch 450/461 | Cost: 0.3044 Epoch: 048/050 Train ACC: 90.15 | Validation ACC: 94.60 Time elapsed: 8.63 min Epoch: 049/050 | Batch 000/461 | Cost: 0.3802 Epoch: 049/050 | Batch 150/461 | Cost: 0.4018 Epoch: 049/050 | Batch 300/461 | Cost: 0.3197 Epoch: 049/050 | Batch 450/461 | Cost: 0.4157 Epoch: 049/050 Train ACC: 89.91 | Validation ACC: 93.70 Time elapsed: 8.76 min Epoch: 050/050 | Batch 000/461 | Cost: 0.4057 Epoch: 050/050 | Batch 150/461 | Cost: 0.3687 Epoch: 050/050 | Batch 300/461 | Cost: 0.3552 Epoch: 050/050 | Batch 450/461 | Cost: 0.2707 Epoch: 050/050 Train ACC: 89.91 | Validation ACC: 93.00 Time elapsed: 8.88 min Total Training Time: 8.88 min
# Last adjacency matrix (from the most recent forward pass). detach() makes
# this work even if that forward pass ran with autograd enabled.
plt.imshow(model.A.detach().cpu());

# Minibatch cost plus a trailing running average. With mode='valid', point k
# of the convolution averages cost_list[k : k + window], so it is plotted at
# iteration k + window - 1 to line up with the raw curve.
window = 200
plt.plot(cost_list, label='Minibatch cost')
if len(cost_list) >= window:
    plt.plot(np.arange(window - 1, len(cost_list)),
             np.convolve(cost_list, np.ones(window,)/window, mode='valid'),
             label='Running average')
plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()

# compute_acc returns 0-d tensors (on DEVICE, possibly CUDA); float() makes
# them plottable. The x-axis comes from the list length so a partially
# completed run still plots.
train_acc_values = [float(acc) for acc in train_acc_list]
valid_acc_values = [float(acc) for acc in valid_acc_list]
plt.plot(np.arange(1, len(train_acc_values)+1), train_acc_values, label='Training')
plt.plot(np.arange(1, len(valid_acc_values)+1), valid_acc_values, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
# Final accuracy on the held-out test set and the validation split,
# computed without tracking gradients.
with torch.set_grad_enabled(False):
    test_acc = compute_acc(model=model, data_loader=test_loader, device=DEVICE)
    valid_acc = compute_acc(model=model, data_loader=valid_loader, device=DEVICE)

print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Validation ACC: 93.00% Test ACC: 90.36%
%watermark -iv
torchvision 0.4.0a0+6b959ee matplotlib 3.1.0 torch 1.2.0 numpy 1.16.4