Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka 

CPython 3.7.3
IPython 7.6.1

torch 1.2.0
  • Runs on CPU or GPU (if available)

Basic Graph Neural Network with Edge Prediction on MNIST

Implementing a very basic graph neural network (GNN) using a subnetwork for edge prediction.

Here, the 28x28 image of an MNIST digit represents the graph, where each pixel (i.e., each cell in the grid) is a node. The feature of a node is simply its pixel intensity in the range [0, 1].
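The pixel-to-node mapping can be made concrete with a minimal sketch (not part of the original notebook). It flattens a single image into an N x 1 node feature matrix. The random tensor is only a stand-in for a real MNIST digit as returned by `transforms.ToTensor()`.

```python
import torch

# Hypothetical stand-in for one MNIST image, shape [1, 28, 28] (channel, height, width)
img = torch.rand(1, 28, 28)

# Every pixel becomes one node whose only feature is its intensity -> X has shape [N, C] = [784, 1]
X = img.view(-1, 1)

# Flattening is row-major: node k corresponds to grid cell (row, col) = (k // 28, k % 28)
k = 30
row, col = divmod(k, 28)
assert X[k, 0] == img[0, row, col]
print(X.shape)  # torch.Size([784, 1])
```

The model's `forward` method uses the same row-major flattening (`x.view(B, -1, 1)`), just with an extra batch dimension.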

In the related notebook, [gnn-basic-1.ipynb], the adjacency matrix was determined by each pixel's neighborhood: using a Gaussian filter, pixels were connected based on their Euclidean distance in the grid. In this notebook, the edges are instead predicted by a separate neural network model:

self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32),
                                  nn.ReLU(),
                                  nn.Linear(32, 1),
                                  nn.Tanh())
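The snippet above shows only the edge predictor itself. The following minimal sketch shows how such a network turns coordinates into a dense adjacency matrix. It assumes a small 4x4 grid instead of 28x28 to stay fast and uses a hidden size of 32, as in the `GraphNet` class defined later. Every node pair (i, j) is described by the standardized coordinates of both nodes (4 numbers), and the MLP maps each pair to one edge weight, squashed into [-1, 1] by the final `Tanh`. Variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

img_size = 4                 # small illustrative grid; MNIST uses 28
N = img_size * img_size      # number of nodes

# (col, row) coordinates of every node in row-major order, shape [N, 2]
cols = torch.arange(img_size).repeat(img_size)             # 0,1,2,3,0,1,2,3,...
rows = torch.arange(img_size).repeat_interleave(img_size)  # 0,0,0,0,1,1,1,1,...
coord = torch.stack((cols, rows), dim=1).float()

# Standardize each axis (population std, like np.std in the notebook)
centered = coord - coord.mean(dim=0)
coord = centered / (centered.pow(2).mean(dim=0).sqrt() + 1e-5)

# Pairwise features: pair[i, j] = (coord of node j, coord of node i), shape [N, N, 4]
pair = torch.cat((coord.unsqueeze(0).expand(N, N, 2),
                  coord.unsqueeze(1).expand(N, N, 2)), dim=2)

# Edge predictor with the same layout as in the notebook
pred_edge_fc = nn.Sequential(nn.Linear(4, 32), nn.ReLU(),
                             nn.Linear(32, 1), nn.Tanh())

# nn.Linear acts on the last dimension, so every node pair is scored independently
A = pred_edge_fc(pair).squeeze(-1)   # [N, N], each entry in [-1, 1]
print(A.shape)  # torch.Size([16, 16])
```

A few consequences follow from this design. The input for pair (i, j) concatenates the two coordinates in a fixed order, so the predicted A is generally not symmetric. Because of `Tanh`, edges can also carry negative weights. Finally, A depends only on the coordinates, not on the pixel intensities, so the same learned adjacency matrix is used for every input image.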

Using the resulting adjacency matrix $A$, we can compute the output of a layer as

$$X^{(l+1)}=A X^{(l)} W^{(l)}.$$

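Here, $A$ is the $N \times N$ adjacency matrix and $X$ is the $N \times C$ node feature matrix, where $N$ is the total number of pixels ($28 \times 28 = 784$ in MNIST) and $C$ is the number of features per node ($C = 1$ here, the pixel intensity). The 2D pixel coordinates are not part of $X$; they only serve as the input of the edge predictor. $W$ is a weight matrix of shape $C \times P$, where $P$ would be the number of classes if we had only a single layer. The model below uses a slight variant of this: it flattens the $N \times 1$ product $AX$ into a vector of length $N$ and passes it to a fully connected layer, which corresponds to a weight matrix of shape $N \times P$.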
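Multiplying random stand-in matrices is a quick way to check the dimensions. With $W$ of shape $C \times P$, the product $AXW$ is $N \times P$. A flatten-then-fully-connected variant, like the model in this notebook, instead maps the flattened $AX$ to $P$ scores per image. The random tensors below are placeholders, not trained values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, P = 784, 1, 10     # nodes, features per node, classes

A = torch.rand(N, N)     # stand-in adjacency matrix
X = torch.rand(N, C)     # node features (pixel intensities)
W = torch.rand(C, P)     # layer weights in the formulation X^{(l+1)} = A X^{(l)} W^{(l)}

# Standard graph layer: [N, N] x [N, C] x [C, P] -> [N, P]
X_next = A @ X @ W
print(X_next.shape)  # torch.Size([784, 10])

# Notebook-style variant: flatten A X ([N, 1]) to a batch of one length-N vector and
# apply a fully connected layer (PyTorch stores its weight transposed, as [P, N])
fc = nn.Linear(N, P, bias=False)
logits = fc((A @ X).view(1, -1))
print(logits.shape)  # torch.Size([1, 10])
```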

Imports

In [2]:
import time
import numpy as np
from scipy.spatial.distance import cdist
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
In [3]:
%matplotlib inline
import matplotlib.pyplot as plt

Settings and Dataset

In [4]:
##########################
### SETTINGS
##########################

# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0005
NUM_EPOCHS = 50
BATCH_SIZE = 128
IMG_SIZE = 28

# Architecture
NUM_CLASSES = 10

MNIST Dataset

In [5]:
train_indices = torch.arange(0, 59000)
valid_indices = torch.arange(59000, 60000)

custom_transform = transforms.Compose([transforms.ToTensor()])


train_and_valid = datasets.MNIST(root='data', 
                                 train=True, 
                                 transform=custom_transform,
                                 download=True)

test_dataset = datasets.MNIST(root='data', 
                              train=False, 
                              transform=custom_transform,
                              download=True)

train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(train_and_valid, valid_indices)

train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True)

valid_loader = DataLoader(dataset=valid_dataset, 
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=False)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=BATCH_SIZE,
                         num_workers=4,
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])

Model

In [6]:
##########################
### MODEL
##########################


def make_coordinate_array(img_size, out_size=4):
    
    ### Make 2D coordinate array (for MNIST: 784x2)
    n_rows = img_size * img_size
    col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
    coord = np.stack((col, row), axis=2).reshape(-1, 2)
    coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)
    coord = torch.from_numpy(coord).float()
    
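    ### Pairwise features, shape [N, N, out_size]; for the default out_size=4,
    ### entry [i, j] = (coordinates of node j, coordinates of node i)
    coord = torch.cat((coord.unsqueeze(0).repeat(n_rows, 1, int(out_size/2-1)),
                       coord.unsqueeze(1).repeat(1, n_rows, 1)), dim=2)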
    
    
    return coord

        

class GraphNet(nn.Module):
    def __init__(self, img_size=28, coord_features=4, num_classes=10):
        super(GraphNet, self).__init__()
        
        n_rows = img_size**2
        self.fc = nn.Linear(n_rows, num_classes, bias=False)

        coord = make_coordinate_array(img_size, coord_features)
        self.register_buffer('coord', coord)
        
        ##########
        # Edge Predictor
        self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32), # coord -> hidden
                                          nn.ReLU(),
                                          nn.Linear(32, 1), # hidden -> edge
                                          nn.Tanh())
        

        

    def forward(self, x):
        B = x.size(0)
        
        ### Predict edges: [N, N, 4] pair features -> [N, N] adjacency matrix
        self.A = self.pred_edge_fc(self.coord).squeeze(-1)

        ### Reshape Adjacency Matrix
        # [N, N] Adj. matrix -> [1, N, N] -> [B, N, N] tensor where N = HxW
        # (expand shares the same memory for every batch element, no copy)
        A_tensor = self.A.unsqueeze(0).expand(B, -1, -1)
        
        ### Reshape inputs
        # [B, C, H, W] => [B, H*W, 1]
        x_reshape = x.view(B, -1, 1)
        
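        # bmm = batch matrix product; sums the neighbor features weighted by A
        # (A is neither normalized nor non-negative, so this is a weighted sum, not an average)
        # Input: [B, N, N] x [B, N, 1] -> Output: [B, N, 1], viewed as [B, N]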
        avg_neighbor_features = (torch.bmm(A_tensor, x_reshape).view(B, -1))

        logits = self.fc(avg_neighbor_features)
        probas = F.softmax(logits, dim=1)
        return logits, probas
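The two less obvious steps in this cell can be sanity-checked on a tiny 3x3 grid with random stand-in values (an assumption made only for speed). The first is the layout of the tensor returned by `make_coordinate_array`. The second is the batched product in `forward`, where one shared adjacency matrix is expanded over the batch. The function is copied from the cell above so that the snippet runs on its own.

```python
import numpy as np
import torch

# Copy of make_coordinate_array from the cell above
def make_coordinate_array(img_size, out_size=4):
    n_rows = img_size * img_size
    col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
    coord = np.stack((col, row), axis=2).reshape(-1, 2)
    coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)
    coord = torch.from_numpy(coord).float()
    return torch.cat((coord.unsqueeze(0).repeat(n_rows, 1, int(out_size / 2 - 1)),
                      coord.unsqueeze(1).repeat(1, n_rows, 1)), dim=2)

pairs = make_coordinate_array(img_size=3)   # tiny 3x3 grid -> N = 9
print(pairs.shape)  # torch.Size([9, 9, 4])

# Entry [i, j] holds (coordinates of node j, coordinates of node i)
i, j = 2, 7
assert torch.equal(pairs[i, j, :2], pairs[j, j, 2:])
assert torch.equal(pairs[i, j, 2:], pairs[i, i, 2:])

# Batched aggregation as in forward(): one [N, N] matrix shared by all B images
torch.manual_seed(0)
B, N = 2, 9
A = torch.rand(N, N)                          # stand-in for the predicted adjacency
x = torch.rand(B, 1, 3, 3)                    # stand-in batch of images
A_batch = A.unsqueeze(0).expand(B, -1, -1)    # [B, N, N] without copying
out = torch.bmm(A_batch, x.view(B, -1, 1)).view(B, -1)   # [B, N]

# Each row of the result equals A @ (flattened image) for that image
assert torch.allclose(out[1], A @ x[1].view(-1))
```

Using `expand` rather than `repeat` for the batch dimension avoids materializing B copies of the 784 x 784 matrix, which matters because the matrix is rebuilt on every forward pass.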
In [7]:
torch.manual_seed(RANDOM_SEED)
model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES)

model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  

Training

In [8]:
def compute_acc(model, data_loader, device):
    correct_pred, num_examples = 0, 0
    for features, targets in data_loader:
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100
    

start_time = time.time()

cost_list = []
train_acc_list, valid_acc_list = [], []


for epoch in range(NUM_EPOCHS):
    
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)
            
        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        
        cost.backward()
        
        ### UPDATE MODEL PARAMETERS
        optimizer.step()
        
        #################################################
        ### CODE ONLY FOR LOGGING BEYOND THIS POINT
        ################################################
        cost_list.append(cost.item())
        if not batch_idx % 150:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |'
                  f' Cost: {cost:.4f}')

        

    model.eval()
    with torch.set_grad_enabled(False): # save memory during inference
        
        train_acc = compute_acc(model, train_loader, device=DEVICE)
        valid_acc = compute_acc(model, valid_loader, device=DEVICE)
        
        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\n'
              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')
        
        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)
        
    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')
  
elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')
Epoch: 001/050 | Batch 000/461 | Cost: 24.2727
Epoch: 001/050 | Batch 150/461 | Cost: 2.2706
Epoch: 001/050 | Batch 300/461 | Cost: 1.8713
Epoch: 001/050 | Batch 450/461 | Cost: 1.5048
Epoch: 001/050
Train ACC: 50.39 | Validation ACC: 54.80
Time elapsed: 0.25 min
Epoch: 002/050 | Batch 000/461 | Cost: 1.4445
Epoch: 002/050 | Batch 150/461 | Cost: 1.3288
Epoch: 002/050 | Batch 300/461 | Cost: 1.1868
Epoch: 002/050 | Batch 450/461 | Cost: 1.2040
Epoch: 002/050
Train ACC: 67.68 | Validation ACC: 71.40
Time elapsed: 0.49 min
Epoch: 003/050 | Batch 000/461 | Cost: 1.2128
Epoch: 003/050 | Batch 150/461 | Cost: 0.9953
Epoch: 003/050 | Batch 300/461 | Cost: 0.9818
Epoch: 003/050 | Batch 450/461 | Cost: 1.0487
Epoch: 003/050
Train ACC: 68.09 | Validation ACC: 73.40
Time elapsed: 0.74 min
Epoch: 004/050 | Batch 000/461 | Cost: 1.0444
Epoch: 004/050 | Batch 150/461 | Cost: 0.9064
Epoch: 004/050 | Batch 300/461 | Cost: 0.9152
Epoch: 004/050 | Batch 450/461 | Cost: 0.7396
Epoch: 004/050
Train ACC: 76.10 | Validation ACC: 80.20
Time elapsed: 0.98 min
Epoch: 005/050 | Batch 000/461 | Cost: 0.7698
Epoch: 005/050 | Batch 150/461 | Cost: 0.8356
Epoch: 005/050 | Batch 300/461 | Cost: 0.6544
Epoch: 005/050 | Batch 450/461 | Cost: 0.8700
Epoch: 005/050
Train ACC: 80.30 | Validation ACC: 84.10
Time elapsed: 1.22 min
Epoch: 006/050 | Batch 000/461 | Cost: 0.6292
Epoch: 006/050 | Batch 150/461 | Cost: 0.7779
Epoch: 006/050 | Batch 300/461 | Cost: 0.5978
Epoch: 006/050 | Batch 450/461 | Cost: 0.6260
Epoch: 006/050
Train ACC: 82.00 | Validation ACC: 85.60
Time elapsed: 1.46 min
Epoch: 007/050 | Batch 000/461 | Cost: 0.7172
Epoch: 007/050 | Batch 150/461 | Cost: 0.6444
Epoch: 007/050 | Batch 300/461 | Cost: 0.5620
Epoch: 007/050 | Batch 450/461 | Cost: 0.5314
Epoch: 007/050
Train ACC: 82.90 | Validation ACC: 86.20
Time elapsed: 1.71 min
Epoch: 008/050 | Batch 000/461 | Cost: 0.6211
Epoch: 008/050 | Batch 150/461 | Cost: 0.5004
Epoch: 008/050 | Batch 300/461 | Cost: 0.5274
Epoch: 008/050 | Batch 450/461 | Cost: 0.6611
Epoch: 008/050
Train ACC: 84.86 | Validation ACC: 87.90
Time elapsed: 1.95 min
Epoch: 009/050 | Batch 000/461 | Cost: 0.4017
Epoch: 009/050 | Batch 150/461 | Cost: 0.7080
Epoch: 009/050 | Batch 300/461 | Cost: 0.4298
Epoch: 009/050 | Batch 450/461 | Cost: 0.4516
Epoch: 009/050
Train ACC: 85.89 | Validation ACC: 90.20
Time elapsed: 2.19 min
Epoch: 010/050 | Batch 000/461 | Cost: 0.4571
Epoch: 010/050 | Batch 150/461 | Cost: 0.4976
Epoch: 010/050 | Batch 300/461 | Cost: 0.6208
Epoch: 010/050 | Batch 450/461 | Cost: 0.3780
Epoch: 010/050
Train ACC: 85.84 | Validation ACC: 89.40
Time elapsed: 2.43 min
Epoch: 011/050 | Batch 000/461 | Cost: 0.5262
Epoch: 011/050 | Batch 150/461 | Cost: 0.4255
Epoch: 011/050 | Batch 300/461 | Cost: 0.3840
Epoch: 011/050 | Batch 450/461 | Cost: 0.4941
Epoch: 011/050
Train ACC: 85.73 | Validation ACC: 90.40
Time elapsed: 2.68 min
Epoch: 012/050 | Batch 000/461 | Cost: 0.3425
Epoch: 012/050 | Batch 150/461 | Cost: 0.5059
Epoch: 012/050 | Batch 300/461 | Cost: 0.6590
Epoch: 012/050 | Batch 450/461 | Cost: 0.5481
Epoch: 012/050
Train ACC: 87.50 | Validation ACC: 91.30
Time elapsed: 2.92 min
Epoch: 013/050 | Batch 000/461 | Cost: 0.6081
Epoch: 013/050 | Batch 150/461 | Cost: 0.4584
Epoch: 013/050 | Batch 300/461 | Cost: 0.2856
Epoch: 013/050 | Batch 450/461 | Cost: 0.4324
Epoch: 013/050
Train ACC: 87.35 | Validation ACC: 91.20
Time elapsed: 3.17 min
Epoch: 014/050 | Batch 000/461 | Cost: 0.4685
Epoch: 014/050 | Batch 150/461 | Cost: 0.4492
Epoch: 014/050 | Batch 300/461 | Cost: 0.3913
Epoch: 014/050 | Batch 450/461 | Cost: 0.5154
Epoch: 014/050
Train ACC: 86.71 | Validation ACC: 91.20
Time elapsed: 3.41 min
Epoch: 015/050 | Batch 000/461 | Cost: 0.4526
Epoch: 015/050 | Batch 150/461 | Cost: 0.4834
Epoch: 015/050 | Batch 300/461 | Cost: 0.5208
Epoch: 015/050 | Batch 450/461 | Cost: 0.3536
Epoch: 015/050
Train ACC: 85.21 | Validation ACC: 89.50
Time elapsed: 3.66 min
Epoch: 016/050 | Batch 000/461 | Cost: 0.6614
Epoch: 016/050 | Batch 150/461 | Cost: 0.3036
Epoch: 016/050 | Batch 300/461 | Cost: 0.3766
Epoch: 016/050 | Batch 450/461 | Cost: 0.4550
Epoch: 016/050
Train ACC: 86.97 | Validation ACC: 92.10
Time elapsed: 3.91 min
Epoch: 017/050 | Batch 000/461 | Cost: 0.6241
Epoch: 017/050 | Batch 150/461 | Cost: 0.3934
Epoch: 017/050 | Batch 300/461 | Cost: 0.4330
Epoch: 017/050 | Batch 450/461 | Cost: 0.5914
Epoch: 017/050
Train ACC: 88.12 | Validation ACC: 91.60
Time elapsed: 4.15 min
Epoch: 018/050 | Batch 000/461 | Cost: 0.3769
Epoch: 018/050 | Batch 150/461 | Cost: 0.4817
Epoch: 018/050 | Batch 300/461 | Cost: 0.4103
Epoch: 018/050 | Batch 450/461 | Cost: 0.3727
Epoch: 018/050
Train ACC: 86.58 | Validation ACC: 90.90
Time elapsed: 4.40 min
Epoch: 019/050 | Batch 000/461 | Cost: 0.4098
Epoch: 019/050 | Batch 150/461 | Cost: 0.4435
Epoch: 019/050 | Batch 300/461 | Cost: 0.2952
Epoch: 019/050 | Batch 450/461 | Cost: 0.3328
Epoch: 019/050
Train ACC: 88.65 | Validation ACC: 92.00
Time elapsed: 4.64 min
Epoch: 020/050 | Batch 000/461 | Cost: 0.5363
Epoch: 020/050 | Batch 150/461 | Cost: 0.3143
Epoch: 020/050 | Batch 300/461 | Cost: 0.5186
Epoch: 020/050 | Batch 450/461 | Cost: 0.3806
Epoch: 020/050
Train ACC: 88.95 | Validation ACC: 92.70
Time elapsed: 4.89 min
Epoch: 021/050 | Batch 000/461 | Cost: 0.3810
Epoch: 021/050 | Batch 150/461 | Cost: 0.2470
Epoch: 021/050 | Batch 300/461 | Cost: 0.6154
Epoch: 021/050 | Batch 450/461 | Cost: 0.3651
Epoch: 021/050
Train ACC: 88.31 | Validation ACC: 92.40
Time elapsed: 5.13 min
Epoch: 022/050 | Batch 000/461 | Cost: 0.3704
Epoch: 022/050 | Batch 150/461 | Cost: 0.4338
Epoch: 022/050 | Batch 300/461 | Cost: 0.4197
Epoch: 022/050 | Batch 450/461 | Cost: 0.3304
Epoch: 022/050
Train ACC: 88.62 | Validation ACC: 91.90
Time elapsed: 5.31 min
Epoch: 023/050 | Batch 000/461 | Cost: 0.2825
Epoch: 023/050 | Batch 150/461 | Cost: 0.4302
Epoch: 023/050 | Batch 300/461 | Cost: 0.4738
Epoch: 023/050 | Batch 450/461 | Cost: 0.4362
Epoch: 023/050
Train ACC: 89.02 | Validation ACC: 92.80
Time elapsed: 5.44 min
Epoch: 024/050 | Batch 000/461 | Cost: 0.2097
Epoch: 024/050 | Batch 150/461 | Cost: 0.4440
Epoch: 024/050 | Batch 300/461 | Cost: 0.4467
Epoch: 024/050 | Batch 450/461 | Cost: 0.2744
Epoch: 024/050
Train ACC: 88.82 | Validation ACC: 92.40
Time elapsed: 5.57 min
Epoch: 025/050 | Batch 000/461 | Cost: 0.2734
Epoch: 025/050 | Batch 150/461 | Cost: 0.3980
Epoch: 025/050 | Batch 300/461 | Cost: 0.4395
Epoch: 025/050 | Batch 450/461 | Cost: 0.2336
Epoch: 025/050
Train ACC: 89.59 | Validation ACC: 93.90
Time elapsed: 5.70 min
Epoch: 026/050 | Batch 000/461 | Cost: 0.3138
Epoch: 026/050 | Batch 150/461 | Cost: 0.3772
Epoch: 026/050 | Batch 300/461 | Cost: 0.2955
Epoch: 026/050 | Batch 450/461 | Cost: 0.3747
Epoch: 026/050
Train ACC: 88.71 | Validation ACC: 92.70
Time elapsed: 5.82 min
Epoch: 027/050 | Batch 000/461 | Cost: 0.4107
Epoch: 027/050 | Batch 150/461 | Cost: 0.4375
Epoch: 027/050 | Batch 300/461 | Cost: 0.3802
Epoch: 027/050 | Batch 450/461 | Cost: 0.3240
Epoch: 027/050
Train ACC: 87.90 | Validation ACC: 91.60
Time elapsed: 5.95 min
Epoch: 028/050 | Batch 000/461 | Cost: 0.5124
Epoch: 028/050 | Batch 150/461 | Cost: 0.4980
Epoch: 028/050 | Batch 300/461 | Cost: 0.3937
Epoch: 028/050 | Batch 450/461 | Cost: 0.2704
Epoch: 028/050
Train ACC: 89.08 | Validation ACC: 92.30
Time elapsed: 6.08 min
Epoch: 029/050 | Batch 000/461 | Cost: 0.3328
Epoch: 029/050 | Batch 150/461 | Cost: 0.3022
Epoch: 029/050 | Batch 300/461 | Cost: 0.3222
Epoch: 029/050 | Batch 450/461 | Cost: 0.3084
Epoch: 029/050
Train ACC: 89.30 | Validation ACC: 93.90
Time elapsed: 6.21 min
Epoch: 030/050 | Batch 000/461 | Cost: 0.4667
Epoch: 030/050 | Batch 150/461 | Cost: 0.3290
Epoch: 030/050 | Batch 300/461 | Cost: 0.3261
Epoch: 030/050 | Batch 450/461 | Cost: 0.3347
Epoch: 030/050
Train ACC: 89.17 | Validation ACC: 93.60
Time elapsed: 6.33 min
Epoch: 031/050 | Batch 000/461 | Cost: 0.3486
Epoch: 031/050 | Batch 150/461 | Cost: 0.2426
Epoch: 031/050 | Batch 300/461 | Cost: 0.2748
Epoch: 031/050 | Batch 450/461 | Cost: 0.2072
Epoch: 031/050
Train ACC: 89.17 | Validation ACC: 93.20
Time elapsed: 6.46 min
Epoch: 032/050 | Batch 000/461 | Cost: 0.3423
Epoch: 032/050 | Batch 150/461 | Cost: 0.4924
Epoch: 032/050 | Batch 300/461 | Cost: 0.4072
Epoch: 032/050 | Batch 450/461 | Cost: 0.3611
Epoch: 032/050
Train ACC: 89.83 | Validation ACC: 94.30
Time elapsed: 6.59 min
Epoch: 033/050 | Batch 000/461 | Cost: 0.2461
Epoch: 033/050 | Batch 150/461 | Cost: 0.2343
Epoch: 033/050 | Batch 300/461 | Cost: 0.2891
Epoch: 033/050 | Batch 450/461 | Cost: 0.3772
Epoch: 033/050
Train ACC: 88.81 | Validation ACC: 92.40
Time elapsed: 6.72 min
Epoch: 034/050 | Batch 000/461 | Cost: 0.3052
Epoch: 034/050 | Batch 150/461 | Cost: 0.5129
Epoch: 034/050 | Batch 300/461 | Cost: 0.3810
Epoch: 034/050 | Batch 450/461 | Cost: 0.2906
Epoch: 034/050
Train ACC: 89.34 | Validation ACC: 93.10
Time elapsed: 6.85 min
Epoch: 035/050 | Batch 000/461 | Cost: 0.3604
Epoch: 035/050 | Batch 150/461 | Cost: 0.3832
Epoch: 035/050 | Batch 300/461 | Cost: 0.3632
Epoch: 035/050 | Batch 450/461 | Cost: 0.3345
Epoch: 035/050
Train ACC: 89.74 | Validation ACC: 93.10
Time elapsed: 6.98 min
Epoch: 036/050 | Batch 000/461 | Cost: 0.3382
Epoch: 036/050 | Batch 150/461 | Cost: 0.3754
Epoch: 036/050 | Batch 300/461 | Cost: 0.4120
Epoch: 036/050 | Batch 450/461 | Cost: 0.4710
Epoch: 036/050
Train ACC: 89.10 | Validation ACC: 93.90
Time elapsed: 7.10 min
Epoch: 037/050 | Batch 000/461 | Cost: 0.4466
Epoch: 037/050 | Batch 150/461 | Cost: 0.3427
Epoch: 037/050 | Batch 300/461 | Cost: 0.3301
Epoch: 037/050 | Batch 450/461 | Cost: 0.4110
Epoch: 037/050
Train ACC: 89.95 | Validation ACC: 93.90
Time elapsed: 7.23 min
Epoch: 038/050 | Batch 000/461 | Cost: 0.2470
Epoch: 038/050 | Batch 150/461 | Cost: 0.4719
Epoch: 038/050 | Batch 300/461 | Cost: 0.3253
Epoch: 038/050 | Batch 450/461 | Cost: 0.4324
Epoch: 038/050
Train ACC: 89.35 | Validation ACC: 93.50
Time elapsed: 7.36 min
Epoch: 039/050 | Batch 000/461 | Cost: 0.3058
Epoch: 039/050 | Batch 150/461 | Cost: 0.4755
Epoch: 039/050 | Batch 300/461 | Cost: 0.2981
Epoch: 039/050 | Batch 450/461 | Cost: 0.4293
Epoch: 039/050
Train ACC: 89.51 | Validation ACC: 92.90
Time elapsed: 7.48 min
Epoch: 040/050 | Batch 000/461 | Cost: 0.3378
Epoch: 040/050 | Batch 150/461 | Cost: 0.5137
Epoch: 040/050 | Batch 300/461 | Cost: 0.2680
Epoch: 040/050 | Batch 450/461 | Cost: 0.3397
Epoch: 040/050
Train ACC: 90.01 | Validation ACC: 93.70
Time elapsed: 7.61 min
Epoch: 041/050 | Batch 000/461 | Cost: 0.2766
Epoch: 041/050 | Batch 150/461 | Cost: 0.2959
Epoch: 041/050 | Batch 300/461 | Cost: 0.1930
Epoch: 041/050 | Batch 450/461 | Cost: 0.3735
Epoch: 041/050
Train ACC: 89.45 | Validation ACC: 93.60
Time elapsed: 7.74 min
Epoch: 042/050 | Batch 000/461 | Cost: 0.2694
Epoch: 042/050 | Batch 150/461 | Cost: 0.3575
Epoch: 042/050 | Batch 300/461 | Cost: 0.4267
Epoch: 042/050 | Batch 450/461 | Cost: 0.3332
Epoch: 042/050
Train ACC: 89.96 | Validation ACC: 93.30
Time elapsed: 7.86 min
Epoch: 043/050 | Batch 000/461 | Cost: 0.2288
Epoch: 043/050 | Batch 150/461 | Cost: 0.4260
Epoch: 043/050 | Batch 300/461 | Cost: 0.2835
Epoch: 043/050 | Batch 450/461 | Cost: 0.2882
Epoch: 043/050
Train ACC: 89.91 | Validation ACC: 93.40
Time elapsed: 7.99 min
Epoch: 044/050 | Batch 000/461 | Cost: 0.3211
Epoch: 044/050 | Batch 150/461 | Cost: 0.3061
Epoch: 044/050 | Batch 300/461 | Cost: 0.3137
Epoch: 044/050 | Batch 450/461 | Cost: 0.2978
Epoch: 044/050
Train ACC: 89.63 | Validation ACC: 94.30
Time elapsed: 8.12 min
Epoch: 045/050 | Batch 000/461 | Cost: 0.2325
Epoch: 045/050 | Batch 150/461 | Cost: 0.3013
Epoch: 045/050 | Batch 300/461 | Cost: 0.3732
Epoch: 045/050 | Batch 450/461 | Cost: 0.3229
Epoch: 045/050
Train ACC: 90.00 | Validation ACC: 93.80
Time elapsed: 8.25 min
Epoch: 046/050 | Batch 000/461 | Cost: 0.2521
Epoch: 046/050 | Batch 150/461 | Cost: 0.4440
Epoch: 046/050 | Batch 300/461 | Cost: 0.3420
Epoch: 046/050 | Batch 450/461 | Cost: 0.4288
Epoch: 046/050
Train ACC: 89.97 | Validation ACC: 93.40
Time elapsed: 8.38 min
Epoch: 047/050 | Batch 000/461 | Cost: 0.4605
Epoch: 047/050 | Batch 150/461 | Cost: 0.3261
Epoch: 047/050 | Batch 300/461 | Cost: 0.4493
Epoch: 047/050 | Batch 450/461 | Cost: 0.4902
Epoch: 047/050
Train ACC: 89.60 | Validation ACC: 93.70
Time elapsed: 8.50 min
Epoch: 048/050 | Batch 000/461 | Cost: 0.4136
Epoch: 048/050 | Batch 150/461 | Cost: 0.2952
Epoch: 048/050 | Batch 300/461 | Cost: 0.4784
Epoch: 048/050 | Batch 450/461 | Cost: 0.3044
Epoch: 048/050
Train ACC: 90.15 | Validation ACC: 94.60
Time elapsed: 8.63 min
Epoch: 049/050 | Batch 000/461 | Cost: 0.3802
Epoch: 049/050 | Batch 150/461 | Cost: 0.4018
Epoch: 049/050 | Batch 300/461 | Cost: 0.3197
Epoch: 049/050 | Batch 450/461 | Cost: 0.4157
Epoch: 049/050
Train ACC: 89.91 | Validation ACC: 93.70
Time elapsed: 8.76 min
Epoch: 050/050 | Batch 000/461 | Cost: 0.4057
Epoch: 050/050 | Batch 150/461 | Cost: 0.3687
Epoch: 050/050 | Batch 300/461 | Cost: 0.3552
Epoch: 050/050 | Batch 450/461 | Cost: 0.2707
Epoch: 050/050
Train ACC: 89.91 | Validation ACC: 93.00
Time elapsed: 8.88 min
Total Training Time: 8.88 min
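`compute_acc` predicts the class with `torch.max` over `probas`. Softmax is strictly increasing within each row, so taking the argmax of the logits directly yields the same labels. The minimal check below uses random stand-in logits rather than real model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)           # stand-in for a batch of 5 model outputs
probas = F.softmax(logits, dim=1)

# Softmax never changes which class scores highest in a row
_, pred_from_probas = torch.max(probas, 1)
pred_from_logits = logits.argmax(dim=1)
assert torch.equal(pred_from_probas, pred_from_logits)
```

For the same reason, the training loop passes `logits` (not `probas`) to `F.cross_entropy`, which applies log-softmax internally.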

Evaluation

In [9]:
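# Adjacency matrix predicted in the last forward pass
# (detach() ensures A is not attached to the autograd graph before plotting)

plt.imshow(model.A.detach().cpu());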
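`model.A` is a 784 x 784 matrix, which is hard to read as a whole. One way to inspect it is to reshape a single row back to the 28 x 28 image grid. Because `forward` computes `A @ x`, row k holds the weights with which node k aggregates every pixel. The sketch below uses a random stand-in for the trained matrix, and the pixel (14, 14) is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import torch

IMG_SIZE = 28
# Stand-in for model.A (shape [784, 784]); with the trained model use model.A.detach().cpu()
A = torch.rand(IMG_SIZE * IMG_SIZE, IMG_SIZE * IMG_SIZE)

# Row k of A = weights with which node k sums up all pixels, laid out on the image grid
row_idx, col_idx = 14, 14                 # hypothetical pixel of interest
k = row_idx * IMG_SIZE + col_idx          # row-major node index
weights = A[k].view(IMG_SIZE, IMG_SIZE)

plt.imshow(weights.numpy(), cmap='coolwarm')
plt.colorbar()
plt.title(f'Edge weights used by node ({row_idx}, {col_idx})')
plt.show()
```

A diverging colormap such as `coolwarm` is a reasonable choice here, since the `Tanh` output allows both positive and negative edge weights.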
In [10]:
plt.plot(cost_list, label='Minibatch cost')
plt.plot(np.convolve(cost_list, 
                     np.ones(200,)/200, mode='valid'), 
         label='Running average')

plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()
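The running average in the cell above is a moving average. It is computed by convolving the costs with a uniform kernel of 200 entries of 1/200. The toy example below uses made-up costs and a window of 2 to show what `mode='valid'` returns.

```python
import numpy as np

costs = np.array([4.0, 2.0, 3.0, 1.0, 5.0])   # toy minibatch costs
window = 2                                     # the notebook uses a window of 200

# Uniform kernel of 1/window: each output is the mean of `window` consecutive costs;
# mode='valid' keeps only positions where the window fully overlaps the data
running_avg = np.convolve(costs, np.ones(window) / window, mode='valid')
# running_avg holds 3.0, 2.5, 2.0, 3.0, i.e., len(costs) - window + 1 = 4 values

# Iteration index at which each window ends, for aligning with the raw costs
x = np.arange(window - 1, len(costs))
```

`plt.plot` with a single argument uses 0, 1, 2, ... as x values. As a result, the running average in the plot above starts at iteration 0, although its first value summarizes iterations 0 to 199. Passing `x` (with `window=200`) as the first argument to `plt.plot` would align the two curves.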
In [11]:
plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')
plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
In [12]:
with torch.set_grad_enabled(False):
    test_acc = compute_acc(model=model,
                           data_loader=test_loader,
                           device=DEVICE)
    
    valid_acc = compute_acc(model=model,
                            data_loader=valid_loader,
                            device=DEVICE)
    

print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')
Validation ACC: 93.00%
Test ACC: 90.36%
In [13]:
%watermark -iv
torchvision 0.4.0a0+6b959ee
matplotlib  3.1.0
torch       1.2.0
numpy       1.16.4