Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka 

CPython 3.6.8
IPython 7.2.0

torch 1.0.0
  • Runs on CPU or GPU (if available)

Model Zoo -- Convolutional Variational Autoencoder

A simple convolutional variational autoencoder that compresses 784-pixel MNIST images down to a 15-dimensional latent vector representation.
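
Training minimizes the usual VAE cost, i.e., a pixel-wise binary cross-entropy reconstruction term plus the closed-form Kullback-Leibler divergence between the approximate posterior and a standard normal prior; this is the `kl_divergence` term computed in the training loop below:

$$\mathrm{KL}\bigl(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\bigr) = \tfrac{1}{2}\sum_{j}\bigl(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\bigr)$$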

Imports

In [2]:
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
In [3]:
##########################
### SETTINGS
##########################

# Device
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print('Device:', device)

# Hyperparameters
random_seed = 0
learning_rate = 0.001
num_epochs = 50
batch_size = 128

# Architecture
num_features = 784
num_latent = 15


##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data', 
                              train=False, 
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size, 
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Device: cuda:1
Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])

Model

In [4]:
##########################
### MODEL
##########################

class ConvVariationalAutoencoder(torch.nn.Module):

    def __init__(self, num_features, num_latent):
        super(ConvVariationalAutoencoder, self).__init__()
        
        ###############
        # ENCODER
        ##############
        
        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2
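        # Shape trace for 28x28 MNIST inputs, out = floor((w - k)/s) + 1 with p=0:
        # enc_conv_1: (28 - 6)/2 + 1 = 12 -> 16 x 12 x 12
        # enc_conv_2: (12 - 4)/2 + 1 = 5  -> 32 x 5 x 5
        # enc_conv_3: (5 - 2)/2 + 1 = 2   -> 64 x 2 x 2 (flattened to 64*2*2 below)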

        self.enc_conv_1 = torch.nn.Conv2d(in_channels=1,
                                          out_channels=16,
                                          kernel_size=(6, 6),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_2 = torch.nn.Conv2d(in_channels=16,
                                          out_channels=32,
                                          kernel_size=(4, 4),
                                          stride=(2, 2),
                                          padding=0)                 
        
        self.enc_conv_3 = torch.nn.Conv2d(in_channels=32,
                                          out_channels=64,
                                          kernel_size=(2, 2),
                                          stride=(2, 2),
                                          padding=0)                     
        
        self.z_mean = torch.nn.Linear(64*2*2, num_latent)
        # In the original paper (Kingma & Welling, 2014), the latent
        # vector has a z_mean and a z_var component. The problem is that
        # z_var can be negative, which would cause issues with the log
        # later. Hence, we parameterize the latent vector via z_mean and
        # z_log_var instead, and whenever we need the regular variance
        # or std_dev, we simply apply the exponential function.
        self.z_log_var = torch.nn.Linear(64*2*2, num_latent)
        
        
        
        ###############
        # DECODER
        ##############
        
        self.dec_linear_1 = torch.nn.Linear(num_latent, 64*2*2)
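        # Deconvolution shape trace, out = (in - 1)*s - 2*p + k, starting from
        # the 64 x 2 x 2 map produced by dec_linear_1 (reshaped in decoder()):
        # dec_deconv_1: (2 - 1)*2 - 0 + 2 = 4   -> 32 x 4 x 4
        # dec_deconv_2: (4 - 1)*3 - 2 + 4 = 11  -> 16 x 11 x 11
        # dec_deconv_3: (11 - 1)*3 - 8 + 6 = 28 -> 1 x 28 x 28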
               
        self.dec_deconv_1 = torch.nn.ConvTranspose2d(in_channels=64,
                                                     out_channels=32,
                                                     kernel_size=(2, 2),
                                                     stride=(2, 2),
                                                     padding=0)
                                 
        self.dec_deconv_2 = torch.nn.ConvTranspose2d(in_channels=32,
                                                     out_channels=16,
                                                     kernel_size=(4, 4),
                                                     stride=(3, 3),
                                                     padding=1)
        
        self.dec_deconv_3 = torch.nn.ConvTranspose2d(in_channels=16,
                                                     out_channels=1,
                                                     kernel_size=(6, 6),
                                                     stride=(3, 3),
                                                     padding=4)


    def reparameterize(self, z_mu, z_log_var):
        # Sample epsilon from standard normal distribution
        eps = torch.randn(z_mu.size(0), z_mu.size(1)).to(device)
        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev
        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)
        z = z_mu + eps * torch.exp(z_log_var/2.) 
        return z
        
    def encoder(self, features):
        x = self.enc_conv_1(features)
        x = F.leaky_relu(x)
        #print('conv1 out:', x.size())
        
        x = self.enc_conv_2(x)
        x = F.leaky_relu(x)
        #print('conv2 out:', x.size())
        
        x = self.enc_conv_3(x)
        x = F.leaky_relu(x)
        #print('conv3 out:', x.size())
        
        z_mean = self.z_mean(x.view(-1, 64*2*2))
        z_log_var = self.z_log_var(x.view(-1, 64*2*2))
        encoded = self.reparameterize(z_mean, z_log_var)
        
        return z_mean, z_log_var, encoded
    
    def decoder(self, encoded):
        x = self.dec_linear_1(encoded)
        x = x.view(-1, 64, 2, 2)
        
        x = self.dec_deconv_1(x)
        x = F.leaky_relu(x)
        #print('deconv1 out:', x.size())
        
        x = self.dec_deconv_2(x)
        x = F.leaky_relu(x)
        #print('deconv2 out:', x.size())
        
        x = self.dec_deconv_3(x)
        x = F.leaky_relu(x)
        #print('deconv3 out:', x.size())
        
        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features):
        
        z_mean, z_log_var, encoded = self.encoder(features)
        decoded = self.decoder(encoded)
        
        return z_mean, z_log_var, encoded, decoded

    
torch.manual_seed(random_seed)
model = ConvVariationalAutoencoder(num_features,
                                   num_latent)
model = model.to(device)
    

##########################
### COST AND OPTIMIZER
##########################

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

Training

In [5]:
start_time = time.time()

for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        # don't need labels, only the images (features)
        features = features.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features)

        # cost = reconstruction loss + Kullback-Leibler divergence
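        # (the KL term is the closed-form divergence between the approximate
        # posterior N(z_mean, exp(z_log_var)) and the standard normal prior N(0, I))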
        kl_divergence = (0.5 * (z_mean**2 + 
                                torch.exp(z_log_var) - z_log_var - 1)).sum()
        pixelwise_bce = F.binary_cross_entropy(decoded, features, reduction='sum')
        cost = kl_divergence + pixelwise_bce
        
        optimizer.zero_grad()
        cost.backward()
        
        ### UPDATE MODEL PARAMETERS
        optimizer.step()
        
        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' 
                   %(epoch+1, num_epochs, batch_idx, 
                     len(train_loader), cost))
            
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/469 | Cost: 70508.8125
Epoch: 001/050 | Batch 050/469 | Cost: 69360.6953
Epoch: 001/050 | Batch 100/469 | Cost: 41214.9258
Epoch: 001/050 | Batch 150/469 | Cost: 34571.9688
Epoch: 001/050 | Batch 200/469 | Cost: 28148.7500
Epoch: 001/050 | Batch 250/469 | Cost: 29157.5918
Epoch: 001/050 | Batch 300/469 | Cost: 27405.8359
Epoch: 001/050 | Batch 350/469 | Cost: 26417.7324
Epoch: 001/050 | Batch 400/469 | Cost: 25685.3320
Epoch: 001/050 | Batch 450/469 | Cost: 25538.5293
Time elapsed: 0.14 min
Epoch: 002/050 | Batch 000/469 | Cost: 25301.1914
Epoch: 002/050 | Batch 050/469 | Cost: 24180.0254
Epoch: 002/050 | Batch 100/469 | Cost: 24124.1309
Epoch: 002/050 | Batch 150/469 | Cost: 24285.9629
Epoch: 002/050 | Batch 200/469 | Cost: 24396.1914
Epoch: 002/050 | Batch 250/469 | Cost: 23214.3711
Epoch: 002/050 | Batch 300/469 | Cost: 23690.0957
Epoch: 002/050 | Batch 350/469 | Cost: 23472.6250
Epoch: 002/050 | Batch 400/469 | Cost: 22830.4551
Epoch: 002/050 | Batch 450/469 | Cost: 23235.6602
Time elapsed: 0.28 min
Epoch: 003/050 | Batch 000/469 | Cost: 21816.1934
Epoch: 003/050 | Batch 050/469 | Cost: 22230.9316
Epoch: 003/050 | Batch 100/469 | Cost: 21450.3535
Epoch: 003/050 | Batch 150/469 | Cost: 22165.5781
Epoch: 003/050 | Batch 200/469 | Cost: 20737.0566
Epoch: 003/050 | Batch 250/469 | Cost: 22005.0625
Epoch: 003/050 | Batch 300/469 | Cost: 21199.4375
Epoch: 003/050 | Batch 350/469 | Cost: 20788.6172
Epoch: 003/050 | Batch 400/469 | Cost: 21091.8125
Epoch: 003/050 | Batch 450/469 | Cost: 20739.7637
Time elapsed: 0.43 min
Epoch: 004/050 | Batch 000/469 | Cost: 20847.3320
Epoch: 004/050 | Batch 050/469 | Cost: 21085.3945
Epoch: 004/050 | Batch 100/469 | Cost: 19874.4863
Epoch: 004/050 | Batch 150/469 | Cost: 20014.3594
Epoch: 004/050 | Batch 200/469 | Cost: 19658.4062
Epoch: 004/050 | Batch 250/469 | Cost: 19592.5801
Epoch: 004/050 | Batch 300/469 | Cost: 20177.1680
Epoch: 004/050 | Batch 350/469 | Cost: 19534.2344
Epoch: 004/050 | Batch 400/469 | Cost: 19567.2852
Epoch: 004/050 | Batch 450/469 | Cost: 19308.6367
Time elapsed: 0.57 min
Epoch: 005/050 | Batch 000/469 | Cost: 18299.0762
Epoch: 005/050 | Batch 050/469 | Cost: 17929.8359
Epoch: 005/050 | Batch 100/469 | Cost: 19014.4316
Epoch: 005/050 | Batch 150/469 | Cost: 18907.9668
Epoch: 005/050 | Batch 200/469 | Cost: 18992.6836
Epoch: 005/050 | Batch 250/469 | Cost: 18611.7383
Epoch: 005/050 | Batch 300/469 | Cost: 18453.7012
Epoch: 005/050 | Batch 350/469 | Cost: 18959.7227
Epoch: 005/050 | Batch 400/469 | Cost: 18798.1758
Epoch: 005/050 | Batch 450/469 | Cost: 18019.3672
Time elapsed: 0.71 min
Epoch: 006/050 | Batch 000/469 | Cost: 18124.0820
Epoch: 006/050 | Batch 050/469 | Cost: 18439.6680
Epoch: 006/050 | Batch 100/469 | Cost: 17569.6094
Epoch: 006/050 | Batch 150/469 | Cost: 18261.6934
Epoch: 006/050 | Batch 200/469 | Cost: 17973.9492
Epoch: 006/050 | Batch 250/469 | Cost: 16992.7305
Epoch: 006/050 | Batch 300/469 | Cost: 18452.1992
Epoch: 006/050 | Batch 350/469 | Cost: 17165.4297
Epoch: 006/050 | Batch 400/469 | Cost: 18000.2754
Epoch: 006/050 | Batch 450/469 | Cost: 16839.3262
Time elapsed: 0.86 min
Epoch: 007/050 | Batch 000/469 | Cost: 17863.0645
Epoch: 007/050 | Batch 050/469 | Cost: 17572.0059
Epoch: 007/050 | Batch 100/469 | Cost: 17348.5625
Epoch: 007/050 | Batch 150/469 | Cost: 17124.4922
Epoch: 007/050 | Batch 200/469 | Cost: 17443.2051
Epoch: 007/050 | Batch 250/469 | Cost: 17221.6523
Epoch: 007/050 | Batch 300/469 | Cost: 17059.4297
Epoch: 007/050 | Batch 350/469 | Cost: 17353.8359
Epoch: 007/050 | Batch 400/469 | Cost: 18116.3086
Epoch: 007/050 | Batch 450/469 | Cost: 17090.7910
Time elapsed: 1.00 min
Epoch: 008/050 | Batch 000/469 | Cost: 17174.0098
Epoch: 008/050 | Batch 050/469 | Cost: 16741.3477
Epoch: 008/050 | Batch 100/469 | Cost: 16833.8691
Epoch: 008/050 | Batch 150/469 | Cost: 17041.8145
Epoch: 008/050 | Batch 200/469 | Cost: 16583.4785
Epoch: 008/050 | Batch 250/469 | Cost: 17148.7363
Epoch: 008/050 | Batch 300/469 | Cost: 16401.9492
Epoch: 008/050 | Batch 350/469 | Cost: 16366.9717
Epoch: 008/050 | Batch 400/469 | Cost: 16309.4883
Epoch: 008/050 | Batch 450/469 | Cost: 16813.7383
Time elapsed: 1.14 min
Epoch: 009/050 | Batch 000/469 | Cost: 16475.1348
Epoch: 009/050 | Batch 050/469 | Cost: 16717.6797
Epoch: 009/050 | Batch 100/469 | Cost: 16681.8125
Epoch: 009/050 | Batch 150/469 | Cost: 16367.4902
Epoch: 009/050 | Batch 200/469 | Cost: 16425.5449
Epoch: 009/050 | Batch 250/469 | Cost: 16841.6738
Epoch: 009/050 | Batch 300/469 | Cost: 16003.4609
Epoch: 009/050 | Batch 350/469 | Cost: 15953.0732
Epoch: 009/050 | Batch 400/469 | Cost: 15981.5557
Epoch: 009/050 | Batch 450/469 | Cost: 15866.3105
Time elapsed: 1.28 min
Epoch: 010/050 | Batch 000/469 | Cost: 16785.4121
Epoch: 010/050 | Batch 050/469 | Cost: 16397.5430
Epoch: 010/050 | Batch 100/469 | Cost: 16289.5566
Epoch: 010/050 | Batch 150/469 | Cost: 16549.7812
Epoch: 010/050 | Batch 200/469 | Cost: 16190.5586
Epoch: 010/050 | Batch 250/469 | Cost: 15208.5176
Epoch: 010/050 | Batch 300/469 | Cost: 15649.6221
Epoch: 010/050 | Batch 350/469 | Cost: 15850.7285
Epoch: 010/050 | Batch 400/469 | Cost: 15607.8145
Epoch: 010/050 | Batch 450/469 | Cost: 16352.9707
Time elapsed: 1.42 min
Epoch: 011/050 | Batch 000/469 | Cost: 14833.6748
Epoch: 011/050 | Batch 050/469 | Cost: 14793.8174
Epoch: 011/050 | Batch 100/469 | Cost: 16031.2539
Epoch: 011/050 | Batch 150/469 | Cost: 16403.2148
Epoch: 011/050 | Batch 200/469 | Cost: 16180.4619
Epoch: 011/050 | Batch 250/469 | Cost: 15964.9424
Epoch: 011/050 | Batch 300/469 | Cost: 16027.1377
Epoch: 011/050 | Batch 350/469 | Cost: 16350.3730
Epoch: 011/050 | Batch 400/469 | Cost: 15546.7812
Epoch: 011/050 | Batch 450/469 | Cost: 15494.3408
Time elapsed: 1.56 min
Epoch: 012/050 | Batch 000/469 | Cost: 15366.8662
Epoch: 012/050 | Batch 050/469 | Cost: 15567.0410
Epoch: 012/050 | Batch 100/469 | Cost: 15825.4131
Epoch: 012/050 | Batch 150/469 | Cost: 15356.7363
Epoch: 012/050 | Batch 200/469 | Cost: 16218.4111
Epoch: 012/050 | Batch 250/469 | Cost: 15840.4629
Epoch: 012/050 | Batch 300/469 | Cost: 15789.5957
Epoch: 012/050 | Batch 350/469 | Cost: 16290.2920
Epoch: 012/050 | Batch 400/469 | Cost: 16000.1152
Epoch: 012/050 | Batch 450/469 | Cost: 15458.4883
Time elapsed: 1.71 min
Epoch: 013/050 | Batch 000/469 | Cost: 14845.1387
Epoch: 013/050 | Batch 050/469 | Cost: 14813.1328
Epoch: 013/050 | Batch 100/469 | Cost: 15130.9199
Epoch: 013/050 | Batch 150/469 | Cost: 15422.9141
Epoch: 013/050 | Batch 200/469 | Cost: 15566.4805
Epoch: 013/050 | Batch 250/469 | Cost: 15794.4580
Epoch: 013/050 | Batch 300/469 | Cost: 15083.1582
Epoch: 013/050 | Batch 350/469 | Cost: 15447.7637
Epoch: 013/050 | Batch 400/469 | Cost: 15675.3779
Epoch: 013/050 | Batch 450/469 | Cost: 15165.6543
Time elapsed: 1.85 min
Epoch: 014/050 | Batch 000/469 | Cost: 15194.8164
Epoch: 014/050 | Batch 050/469 | Cost: 15119.1504
Epoch: 014/050 | Batch 100/469 | Cost: 15796.2129
Epoch: 014/050 | Batch 150/469 | Cost: 14884.1680
Epoch: 014/050 | Batch 200/469 | Cost: 15225.4922
Epoch: 014/050 | Batch 250/469 | Cost: 15586.4531
Epoch: 014/050 | Batch 300/469 | Cost: 14798.0352
Epoch: 014/050 | Batch 350/469 | Cost: 15295.6680
Epoch: 014/050 | Batch 400/469 | Cost: 15782.0469
Epoch: 014/050 | Batch 450/469 | Cost: 15226.4424
Time elapsed: 1.99 min
Epoch: 015/050 | Batch 000/469 | Cost: 15213.7441
Epoch: 015/050 | Batch 050/469 | Cost: 15049.6631
Epoch: 015/050 | Batch 100/469 | Cost: 15464.3105
Epoch: 015/050 | Batch 150/469 | Cost: 15114.6406
Epoch: 015/050 | Batch 200/469 | Cost: 15309.3145
Epoch: 015/050 | Batch 250/469 | Cost: 14940.2734
Epoch: 015/050 | Batch 300/469 | Cost: 15016.6035
Epoch: 015/050 | Batch 350/469 | Cost: 15046.3008
Epoch: 015/050 | Batch 400/469 | Cost: 15167.2373
Epoch: 015/050 | Batch 450/469 | Cost: 14859.8359
Time elapsed: 2.13 min
Epoch: 016/050 | Batch 000/469 | Cost: 15028.2578
Epoch: 016/050 | Batch 050/469 | Cost: 14834.3887
Epoch: 016/050 | Batch 100/469 | Cost: 15176.1133
Epoch: 016/050 | Batch 150/469 | Cost: 15468.7812
Epoch: 016/050 | Batch 200/469 | Cost: 15083.7363
Epoch: 016/050 | Batch 250/469 | Cost: 14691.1562
Epoch: 016/050 | Batch 300/469 | Cost: 15369.2461
Epoch: 016/050 | Batch 350/469 | Cost: 14979.9854
Epoch: 016/050 | Batch 400/469 | Cost: 14710.5820
Epoch: 016/050 | Batch 450/469 | Cost: 15753.2812
Time elapsed: 2.27 min
Epoch: 017/050 | Batch 000/469 | Cost: 15676.1680
Epoch: 017/050 | Batch 050/469 | Cost: 15123.8203
Epoch: 017/050 | Batch 100/469 | Cost: 15131.5918
Epoch: 017/050 | Batch 150/469 | Cost: 14856.3496
Epoch: 017/050 | Batch 200/469 | Cost: 15176.2002
Epoch: 017/050 | Batch 250/469 | Cost: 14768.6816
Epoch: 017/050 | Batch 300/469 | Cost: 14871.7480
Epoch: 017/050 | Batch 350/469 | Cost: 14418.3633
Epoch: 017/050 | Batch 400/469 | Cost: 15398.9326
Epoch: 017/050 | Batch 450/469 | Cost: 14675.2832
Time elapsed: 2.40 min
Epoch: 018/050 | Batch 000/469 | Cost: 15558.1592
Epoch: 018/050 | Batch 050/469 | Cost: 14836.9766
Epoch: 018/050 | Batch 100/469 | Cost: 14535.8203
Epoch: 018/050 | Batch 150/469 | Cost: 15062.1992
Epoch: 018/050 | Batch 200/469 | Cost: 15094.6914
Epoch: 018/050 | Batch 250/469 | Cost: 15006.5684
Epoch: 018/050 | Batch 300/469 | Cost: 14656.5703
Epoch: 018/050 | Batch 350/469 | Cost: 15232.4990
Epoch: 018/050 | Batch 400/469 | Cost: 15159.4854
Epoch: 018/050 | Batch 450/469 | Cost: 15619.9785
Time elapsed: 2.53 min
Epoch: 019/050 | Batch 000/469 | Cost: 14647.2051
Epoch: 019/050 | Batch 050/469 | Cost: 15262.9062
Epoch: 019/050 | Batch 100/469 | Cost: 15305.6738
Epoch: 019/050 | Batch 150/469 | Cost: 14550.4102
Epoch: 019/050 | Batch 200/469 | Cost: 15431.4395
Epoch: 019/050 | Batch 250/469 | Cost: 15205.6074
Epoch: 019/050 | Batch 300/469 | Cost: 15149.4453
Epoch: 019/050 | Batch 350/469 | Cost: 14836.1543
Epoch: 019/050 | Batch 400/469 | Cost: 14699.8994
Epoch: 019/050 | Batch 450/469 | Cost: 15564.8604
Time elapsed: 2.67 min
Epoch: 020/050 | Batch 000/469 | Cost: 15190.4043
Epoch: 020/050 | Batch 050/469 | Cost: 15331.2246
Epoch: 020/050 | Batch 100/469 | Cost: 14559.0176
Epoch: 020/050 | Batch 150/469 | Cost: 14311.1699
Epoch: 020/050 | Batch 200/469 | Cost: 14561.7070
Epoch: 020/050 | Batch 250/469 | Cost: 15366.1982
Epoch: 020/050 | Batch 300/469 | Cost: 14740.9365
Epoch: 020/050 | Batch 350/469 | Cost: 14924.1406
Epoch: 020/050 | Batch 400/469 | Cost: 14399.0762
Epoch: 020/050 | Batch 450/469 | Cost: 15144.8867
Time elapsed: 2.81 min
Epoch: 021/050 | Batch 000/469 | Cost: 14497.8389
Epoch: 021/050 | Batch 050/469 | Cost: 14999.7617
Epoch: 021/050 | Batch 100/469 | Cost: 14503.3086
Epoch: 021/050 | Batch 150/469 | Cost: 15366.3564
Epoch: 021/050 | Batch 200/469 | Cost: 15190.8740
Epoch: 021/050 | Batch 250/469 | Cost: 14832.3369
Epoch: 021/050 | Batch 300/469 | Cost: 15091.1016
Epoch: 021/050 | Batch 350/469 | Cost: 14928.2930
Epoch: 021/050 | Batch 400/469 | Cost: 14790.8223
Epoch: 021/050 | Batch 450/469 | Cost: 14803.0596
Time elapsed: 2.95 min
Epoch: 022/050 | Batch 000/469 | Cost: 14677.4326
Epoch: 022/050 | Batch 050/469 | Cost: 14652.6543
Epoch: 022/050 | Batch 100/469 | Cost: 15094.6904
Epoch: 022/050 | Batch 150/469 | Cost: 14702.5977
Epoch: 022/050 | Batch 200/469 | Cost: 15014.6758
Epoch: 022/050 | Batch 250/469 | Cost: 14506.5420
Epoch: 022/050 | Batch 300/469 | Cost: 14207.6309
Epoch: 022/050 | Batch 350/469 | Cost: 14883.4453
Epoch: 022/050 | Batch 400/469 | Cost: 14935.6797
Epoch: 022/050 | Batch 450/469 | Cost: 14522.0771
Time elapsed: 3.09 min
Epoch: 023/050 | Batch 000/469 | Cost: 14545.0410
Epoch: 023/050 | Batch 050/469 | Cost: 15465.3301
Epoch: 023/050 | Batch 100/469 | Cost: 14911.1807
Epoch: 023/050 | Batch 150/469 | Cost: 14108.9902
Epoch: 023/050 | Batch 200/469 | Cost: 14171.8672
Epoch: 023/050 | Batch 250/469 | Cost: 14510.0352
Epoch: 023/050 | Batch 300/469 | Cost: 14746.7100
Epoch: 023/050 | Batch 350/469 | Cost: 15409.6055
Epoch: 023/050 | Batch 400/469 | Cost: 14423.5654
Epoch: 023/050 | Batch 450/469 | Cost: 15278.3594
Time elapsed: 3.23 min
Epoch: 024/050 | Batch 000/469 | Cost: 14552.7031
Epoch: 024/050 | Batch 050/469 | Cost: 14798.7969
Epoch: 024/050 | Batch 100/469 | Cost: 14998.7012
Epoch: 024/050 | Batch 150/469 | Cost: 14323.0811
Epoch: 024/050 | Batch 200/469 | Cost: 13328.8086
Epoch: 024/050 | Batch 250/469 | Cost: 15235.0488
Epoch: 024/050 | Batch 300/469 | Cost: 14539.9482
Epoch: 024/050 | Batch 350/469 | Cost: 13984.4404
Epoch: 024/050 | Batch 400/469 | Cost: 14394.9082
Epoch: 024/050 | Batch 450/469 | Cost: 14836.1758
Time elapsed: 3.37 min
Epoch: 025/050 | Batch 000/469 | Cost: 14210.6611
Epoch: 025/050 | Batch 050/469 | Cost: 14331.7012
Epoch: 025/050 | Batch 100/469 | Cost: 14440.1592
Epoch: 025/050 | Batch 150/469 | Cost: 14585.4521
Epoch: 025/050 | Batch 200/469 | Cost: 14941.8232
Epoch: 025/050 | Batch 250/469 | Cost: 14408.6523
Epoch: 025/050 | Batch 300/469 | Cost: 13879.6191
Epoch: 025/050 | Batch 350/469 | Cost: 14163.3799
Epoch: 025/050 | Batch 400/469 | Cost: 15489.8164
Epoch: 025/050 | Batch 450/469 | Cost: 14584.5352
Time elapsed: 3.51 min
Epoch: 026/050 | Batch 000/469 | Cost: 14449.3213
Epoch: 026/050 | Batch 050/469 | Cost: 14182.0420
Epoch: 026/050 | Batch 100/469 | Cost: 14822.8936
Epoch: 026/050 | Batch 150/469 | Cost: 15550.9629
Epoch: 026/050 | Batch 200/469 | Cost: 14777.4414
Epoch: 026/050 | Batch 250/469 | Cost: 14844.9375
Epoch: 026/050 | Batch 300/469 | Cost: 14236.6016
Epoch: 026/050 | Batch 350/469 | Cost: 14573.4326
Epoch: 026/050 | Batch 400/469 | Cost: 14540.6592
Epoch: 026/050 | Batch 450/469 | Cost: 15272.1367
Time elapsed: 3.65 min
Epoch: 027/050 | Batch 000/469 | Cost: 14737.4766
Epoch: 027/050 | Batch 050/469 | Cost: 14636.1719
Epoch: 027/050 | Batch 100/469 | Cost: 14763.8066
Epoch: 027/050 | Batch 150/469 | Cost: 14228.8965
Epoch: 027/050 | Batch 200/469 | Cost: 14508.6289
Epoch: 027/050 | Batch 250/469 | Cost: 14433.5488
Epoch: 027/050 | Batch 300/469 | Cost: 14199.0078
Epoch: 027/050 | Batch 350/469 | Cost: 14910.3555
Epoch: 027/050 | Batch 400/469 | Cost: 14825.3359
Epoch: 027/050 | Batch 450/469 | Cost: 14556.9355
Time elapsed: 3.79 min
Epoch: 028/050 | Batch 000/469 | Cost: 14801.7754
Epoch: 028/050 | Batch 050/469 | Cost: 14283.8076
Epoch: 028/050 | Batch 100/469 | Cost: 14157.8916
Epoch: 028/050 | Batch 150/469 | Cost: 14591.0586
Epoch: 028/050 | Batch 200/469 | Cost: 14707.6934
Epoch: 028/050 | Batch 250/469 | Cost: 14730.5000
Epoch: 028/050 | Batch 300/469 | Cost: 14761.3613
Epoch: 028/050 | Batch 350/469 | Cost: 15279.7812
Epoch: 028/050 | Batch 400/469 | Cost: 14528.2744
Epoch: 028/050 | Batch 450/469 | Cost: 14167.2188
Time elapsed: 3.93 min
Epoch: 029/050 | Batch 000/469 | Cost: 14382.7207
Epoch: 029/050 | Batch 050/469 | Cost: 15143.0254
Epoch: 029/050 | Batch 100/469 | Cost: 14207.4375
Epoch: 029/050 | Batch 150/469 | Cost: 15312.8730
Epoch: 029/050 | Batch 200/469 | Cost: 14714.6807
Epoch: 029/050 | Batch 250/469 | Cost: 14761.9023
Epoch: 029/050 | Batch 300/469 | Cost: 13909.5557
Epoch: 029/050 | Batch 350/469 | Cost: 15295.2285
Epoch: 029/050 | Batch 400/469 | Cost: 14590.0059
Epoch: 029/050 | Batch 450/469 | Cost: 13771.6270
Time elapsed: 4.07 min
Epoch: 030/050 | Batch 000/469 | Cost: 14302.2412
Epoch: 030/050 | Batch 050/469 | Cost: 14636.1582
Epoch: 030/050 | Batch 100/469 | Cost: 14535.5391
Epoch: 030/050 | Batch 150/469 | Cost: 14794.7129
Epoch: 030/050 | Batch 200/469 | Cost: 14745.2432
Epoch: 030/050 | Batch 250/469 | Cost: 14465.8652
Epoch: 030/050 | Batch 300/469 | Cost: 14903.6123
Epoch: 030/050 | Batch 350/469 | Cost: 14062.1025
Epoch: 030/050 | Batch 400/469 | Cost: 14659.8281
Epoch: 030/050 | Batch 450/469 | Cost: 14638.7471
Time elapsed: 4.21 min
Epoch: 031/050 | Batch 000/469 | Cost: 13900.5020
Epoch: 031/050 | Batch 050/469 | Cost: 14276.7793
Epoch: 031/050 | Batch 100/469 | Cost: 14385.0371
Epoch: 031/050 | Batch 150/469 | Cost: 15063.9482
Epoch: 031/050 | Batch 200/469 | Cost: 14061.3789
Epoch: 031/050 | Batch 250/469 | Cost: 14794.1172
Epoch: 031/050 | Batch 300/469 | Cost: 14461.4004
Epoch: 031/050 | Batch 350/469 | Cost: 14760.6582
Epoch: 031/050 | Batch 400/469 | Cost: 14211.6348
Epoch: 031/050 | Batch 450/469 | Cost: 15117.2490
Time elapsed: 4.36 min
Epoch: 032/050 | Batch 000/469 | Cost: 14433.7568
Epoch: 032/050 | Batch 050/469 | Cost: 14379.6641
Epoch: 032/050 | Batch 100/469 | Cost: 14304.6387
Epoch: 032/050 | Batch 150/469 | Cost: 13829.6826
Epoch: 032/050 | Batch 200/469 | Cost: 14619.5654
Epoch: 032/050 | Batch 250/469 | Cost: 14488.1992
Epoch: 032/050 | Batch 300/469 | Cost: 14025.6309
Epoch: 032/050 | Batch 350/469 | Cost: 14557.8555
Epoch: 032/050 | Batch 400/469 | Cost: 14625.9219
Epoch: 032/050 | Batch 450/469 | Cost: 14467.3330
Time elapsed: 4.50 min
Epoch: 033/050 | Batch 000/469 | Cost: 13708.5605
Epoch: 033/050 | Batch 050/469 | Cost: 14030.7461
Epoch: 033/050 | Batch 100/469 | Cost: 15058.2783
Epoch: 033/050 | Batch 150/469 | Cost: 14089.2373
Epoch: 033/050 | Batch 200/469 | Cost: 14830.2188
Epoch: 033/050 | Batch 250/469 | Cost: 14473.9287
Epoch: 033/050 | Batch 300/469 | Cost: 14349.3984
Epoch: 033/050 | Batch 350/469 | Cost: 14528.9199
Epoch: 033/050 | Batch 400/469 | Cost: 14033.7891
Epoch: 033/050 | Batch 450/469 | Cost: 14026.8301
Time elapsed: 4.64 min
Epoch: 034/050 | Batch 000/469 | Cost: 15065.5000
Epoch: 034/050 | Batch 050/469 | Cost: 14807.9961
Epoch: 034/050 | Batch 100/469 | Cost: 14439.8008
Epoch: 034/050 | Batch 150/469 | Cost: 14711.3418
Epoch: 034/050 | Batch 200/469 | Cost: 14689.3828
Epoch: 034/050 | Batch 250/469 | Cost: 13956.6719
Epoch: 034/050 | Batch 300/469 | Cost: 14398.5410
Epoch: 034/050 | Batch 350/469 | Cost: 14900.2051
Epoch: 034/050 | Batch 400/469 | Cost: 14035.2871
Epoch: 034/050 | Batch 450/469 | Cost: 14370.9922
Time elapsed: 4.78 min
Epoch: 035/050 | Batch 000/469 | Cost: 14394.5488
Epoch: 035/050 | Batch 050/469 | Cost: 14367.2725
Epoch: 035/050 | Batch 100/469 | Cost: 14434.9248
Epoch: 035/050 | Batch 150/469 | Cost: 14409.7148
Epoch: 035/050 | Batch 200/469 | Cost: 14353.3174
Epoch: 035/050 | Batch 250/469 | Cost: 14548.2354
Epoch: 035/050 | Batch 300/469 | Cost: 14818.1543
Epoch: 035/050 | Batch 350/469 | Cost: 13898.6777
Epoch: 035/050 | Batch 400/469 | Cost: 14176.9395
Epoch: 035/050 | Batch 450/469 | Cost: 13999.2061
Time elapsed: 4.91 min
Epoch: 036/050 | Batch 000/469 | Cost: 14288.9336
Epoch: 036/050 | Batch 050/469 | Cost: 14487.9365
Epoch: 036/050 | Batch 100/469 | Cost: 14154.0234
Epoch: 036/050 | Batch 150/469 | Cost: 14574.0762
Epoch: 036/050 | Batch 200/469 | Cost: 14200.3008
Epoch: 036/050 | Batch 250/469 | Cost: 14022.4297
Epoch: 036/050 | Batch 300/469 | Cost: 14053.5713
Epoch: 036/050 | Batch 350/469 | Cost: 14348.0186
Epoch: 036/050 | Batch 400/469 | Cost: 14567.2314
Epoch: 036/050 | Batch 450/469 | Cost: 14527.6348
Time elapsed: 5.05 min
Epoch: 037/050 | Batch 000/469 | Cost: 14948.3877
Epoch: 037/050 | Batch 050/469 | Cost: 14357.0439
Epoch: 037/050 | Batch 100/469 | Cost: 13578.9121
Epoch: 037/050 | Batch 150/469 | Cost: 14657.7266
Epoch: 037/050 | Batch 200/469 | Cost: 14293.0732
Epoch: 037/050 | Batch 250/469 | Cost: 13609.5859
Epoch: 037/050 | Batch 300/469 | Cost: 13738.5283
Epoch: 037/050 | Batch 350/469 | Cost: 14079.2803
Epoch: 037/050 | Batch 400/469 | Cost: 14029.6797
Epoch: 037/050 | Batch 450/469 | Cost: 14522.6406
Time elapsed: 5.18 min
Epoch: 038/050 | Batch 000/469 | Cost: 14005.6035
Epoch: 038/050 | Batch 050/469 | Cost: 13756.8330
Epoch: 038/050 | Batch 100/469 | Cost: 15247.8760
Epoch: 038/050 | Batch 150/469 | Cost: 14034.3789
Epoch: 038/050 | Batch 200/469 | Cost: 14204.7061
Epoch: 038/050 | Batch 250/469 | Cost: 14023.4863
Epoch: 038/050 | Batch 300/469 | Cost: 13636.5508
Epoch: 038/050 | Batch 350/469 | Cost: 14509.3711
Epoch: 038/050 | Batch 400/469 | Cost: 14496.3965
Epoch: 038/050 | Batch 450/469 | Cost: 14460.8896
Time elapsed: 5.32 min
Epoch: 039/050 | Batch 000/469 | Cost: 14317.6602
Epoch: 039/050 | Batch 050/469 | Cost: 14440.6855
Epoch: 039/050 | Batch 100/469 | Cost: 13772.3691
Epoch: 039/050 | Batch 150/469 | Cost: 14023.2480
Epoch: 039/050 | Batch 200/469 | Cost: 14576.5449
Epoch: 039/050 | Batch 250/469 | Cost: 14164.7266
Epoch: 039/050 | Batch 300/469 | Cost: 13657.8369
Epoch: 039/050 | Batch 350/469 | Cost: 14456.4014
Epoch: 039/050 | Batch 400/469 | Cost: 14202.3047
Epoch: 039/050 | Batch 450/469 | Cost: 14564.9531
Time elapsed: 5.45 min
Epoch: 040/050 | Batch 000/469 | Cost: 14392.9277
Epoch: 040/050 | Batch 050/469 | Cost: 13708.4375
Epoch: 040/050 | Batch 100/469 | Cost: 14689.3535
Epoch: 040/050 | Batch 150/469 | Cost: 13887.5840
Epoch: 040/050 | Batch 200/469 | Cost: 14047.1543
Epoch: 040/050 | Batch 250/469 | Cost: 14142.5859
Epoch: 040/050 | Batch 300/469 | Cost: 14016.5820
Epoch: 040/050 | Batch 350/469 | Cost: 14962.1387
Epoch: 040/050 | Batch 400/469 | Cost: 14433.1416
Epoch: 040/050 | Batch 450/469 | Cost: 14622.0762
Time elapsed: 5.60 min
Epoch: 041/050 | Batch 000/469 | Cost: 15024.6074
Epoch: 041/050 | Batch 050/469 | Cost: 14015.1895
Epoch: 041/050 | Batch 100/469 | Cost: 14236.3535
Epoch: 041/050 | Batch 150/469 | Cost: 13553.2012
Epoch: 041/050 | Batch 200/469 | Cost: 14393.0205
Epoch: 041/050 | Batch 250/469 | Cost: 14220.9316
Epoch: 041/050 | Batch 300/469 | Cost: 13906.4434
Epoch: 041/050 | Batch 350/469 | Cost: 13650.9873
Epoch: 041/050 | Batch 400/469 | Cost: 14031.2979
Epoch: 041/050 | Batch 450/469 | Cost: 14202.7402
Time elapsed: 5.74 min
Epoch: 042/050 | Batch 000/469 | Cost: 13856.5684
Epoch: 042/050 | Batch 050/469 | Cost: 14359.9023
Epoch: 042/050 | Batch 100/469 | Cost: 14294.4902
Epoch: 042/050 | Batch 150/469 | Cost: 14577.5811
Epoch: 042/050 | Batch 200/469 | Cost: 14028.5820
Epoch: 042/050 | Batch 250/469 | Cost: 13892.3926
Epoch: 042/050 | Batch 300/469 | Cost: 13972.0322
Epoch: 042/050 | Batch 350/469 | Cost: 14635.3506
Epoch: 042/050 | Batch 400/469 | Cost: 13453.1562
Epoch: 042/050 | Batch 450/469 | Cost: 14930.7197
Time elapsed: 5.88 min
Epoch: 043/050 | Batch 000/469 | Cost: 14080.6318
Epoch: 043/050 | Batch 050/469 | Cost: 14356.2100
Epoch: 043/050 | Batch 100/469 | Cost: 14747.7344
Epoch: 043/050 | Batch 150/469 | Cost: 14025.0693
Epoch: 043/050 | Batch 200/469 | Cost: 14294.0615
Epoch: 043/050 | Batch 250/469 | Cost: 14147.0391
Epoch: 043/050 | Batch 300/469 | Cost: 14254.3008
Epoch: 043/050 | Batch 350/469 | Cost: 13503.6582
Epoch: 043/050 | Batch 400/469 | Cost: 14689.1816
Epoch: 043/050 | Batch 450/469 | Cost: 14308.2051
Time elapsed: 6.02 min
Epoch: 044/050 | Batch 000/469 | Cost: 13875.5928
Epoch: 044/050 | Batch 050/469 | Cost: 14699.4229
Epoch: 044/050 | Batch 100/469 | Cost: 14394.9424
Epoch: 044/050 | Batch 150/469 | Cost: 14657.7197
Epoch: 044/050 | Batch 200/469 | Cost: 14011.2949
Epoch: 044/050 | Batch 250/469 | Cost: 13314.2246
Epoch: 044/050 | Batch 300/469 | Cost: 14493.9434
Epoch: 044/050 | Batch 350/469 | Cost: 13947.0000
Epoch: 044/050 | Batch 400/469 | Cost: 14538.6055
Epoch: 044/050 | Batch 450/469 | Cost: 13822.2129
Time elapsed: 6.16 min
Epoch: 045/050 | Batch 000/469 | Cost: 14430.2080
Epoch: 045/050 | Batch 050/469 | Cost: 13560.6621
Epoch: 045/050 | Batch 100/469 | Cost: 14101.0293
Epoch: 045/050 | Batch 150/469 | Cost: 13972.5605
Epoch: 045/050 | Batch 200/469 | Cost: 13934.4883
Epoch: 045/050 | Batch 250/469 | Cost: 14146.7676
Epoch: 045/050 | Batch 300/469 | Cost: 14229.7588
Epoch: 045/050 | Batch 350/469 | Cost: 14473.1758
Epoch: 045/050 | Batch 400/469 | Cost: 14182.4443
Epoch: 045/050 | Batch 450/469 | Cost: 13847.8311
Time elapsed: 6.30 min
Epoch: 046/050 | Batch 000/469 | Cost: 13579.7725
Epoch: 046/050 | Batch 050/469 | Cost: 14197.9629
Epoch: 046/050 | Batch 100/469 | Cost: 14378.0156
Epoch: 046/050 | Batch 150/469 | Cost: 13889.5391
Epoch: 046/050 | Batch 200/469 | Cost: 14234.4473
Epoch: 046/050 | Batch 250/469 | Cost: 14565.4922
Epoch: 046/050 | Batch 300/469 | Cost: 14121.4434
Epoch: 046/050 | Batch 350/469 | Cost: 13544.7070
Epoch: 046/050 | Batch 400/469 | Cost: 13669.2461
Epoch: 046/050 | Batch 450/469 | Cost: 14321.6992
Time elapsed: 6.45 min
Epoch: 047/050 | Batch 000/469 | Cost: 14563.6592
Epoch: 047/050 | Batch 050/469 | Cost: 14157.8525
Epoch: 047/050 | Batch 100/469 | Cost: 14169.9375
Epoch: 047/050 | Batch 150/469 | Cost: 14047.9561
Epoch: 047/050 | Batch 200/469 | Cost: 14237.7090
Epoch: 047/050 | Batch 250/469 | Cost: 14265.3633
Epoch: 047/050 | Batch 300/469 | Cost: 14120.1963
Epoch: 047/050 | Batch 350/469 | Cost: 13613.9072
Epoch: 047/050 | Batch 400/469 | Cost: 13844.0146
Epoch: 047/050 | Batch 450/469 | Cost: 13815.9531
Time elapsed: 6.59 min
Epoch: 048/050 | Batch 000/469 | Cost: 14768.5332
Epoch: 048/050 | Batch 050/469 | Cost: 13807.6055
Epoch: 048/050 | Batch 100/469 | Cost: 14027.3555
Epoch: 048/050 | Batch 150/469 | Cost: 14198.5234
Epoch: 048/050 | Batch 200/469 | Cost: 14043.7871
Epoch: 048/050 | Batch 250/469 | Cost: 14150.2158
Epoch: 048/050 | Batch 300/469 | Cost: 14136.1113
Epoch: 048/050 | Batch 350/469 | Cost: 13921.3516
Epoch: 048/050 | Batch 400/469 | Cost: 14452.8145
Epoch: 048/050 | Batch 450/469 | Cost: 13998.9541
Time elapsed: 6.73 min
Epoch: 049/050 | Batch 000/469 | Cost: 14730.7822
Epoch: 049/050 | Batch 050/469 | Cost: 14744.3809
Epoch: 049/050 | Batch 100/469 | Cost: 14377.9961
Epoch: 049/050 | Batch 150/469 | Cost: 13894.9863
Epoch: 049/050 | Batch 200/469 | Cost: 14319.2900
Epoch: 049/050 | Batch 250/469 | Cost: 14335.9785
Epoch: 049/050 | Batch 300/469 | Cost: 14045.4326
Epoch: 049/050 | Batch 350/469 | Cost: 14342.3359
Epoch: 049/050 | Batch 400/469 | Cost: 13990.9199
Epoch: 049/050 | Batch 450/469 | Cost: 13979.7559
Time elapsed: 6.87 min
Epoch: 050/050 | Batch 000/469 | Cost: 13741.7539
Epoch: 050/050 | Batch 050/469 | Cost: 14258.0557
Epoch: 050/050 | Batch 100/469 | Cost: 14187.6738
Epoch: 050/050 | Batch 150/469 | Cost: 14332.6895
Epoch: 050/050 | Batch 200/469 | Cost: 14304.8984
Epoch: 050/050 | Batch 250/469 | Cost: 13983.0000
Epoch: 050/050 | Batch 300/469 | Cost: 14277.3750
Epoch: 050/050 | Batch 350/469 | Cost: 13838.4023
Epoch: 050/050 | Batch 400/469 | Cost: 13978.5732
Epoch: 050/050 | Batch 450/469 | Cost: 13924.4717
Time elapsed: 7.01 min
Total Training Time: 7.01 min

Evaluation

Reconstruction

In [6]:
%matplotlib inline
import matplotlib.pyplot as plt

##########################
### VISUALIZATION
##########################

n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images, 
                         sharex=True, sharey=True, figsize=(20, 2.5))
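
# `features` and `decoded` still hold the last mini-batch from the training
# loop above; the top row shows the originals, the bottom row the reconstructions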
orig_images = features[:n_images]
decoded_images = decoded[:n_images]

for i in range(n_images):
    for ax, img in zip(axes, [orig_images, decoded_images]):
        ax[i].imshow(img[i].detach().to(torch.device('cpu')).reshape((image_width, image_width)), cmap='binary')

Generate new images

In [7]:
for i in range(10):

    ##########################
    ### RANDOM SAMPLE
    ##########################    
    
    n_images = 10
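    # Sample latent vectors from the standard normal prior N(0, I) and
    # decode them into new digit images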
    rand_features = torch.randn(n_images, num_latent).to(device)
    new_images = model.decoder(rand_features)

    ##########################
    ### VISUALIZATION
    ##########################

    image_width = 28

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    decoded_images = new_images[:n_images]

    for ax, img in zip(axes, decoded_images):
        ax.imshow(img.detach().to(torch.device('cpu')).reshape((image_width, image_width)), cmap='binary')
        
    plt.show()
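
As an optional follow-up (not part of the original notebook), one can also interpolate linearly between two random latent codes and decode each intermediate point to inspect how smoothly the learned latent space varies. The sketch below reuses the `model`, `device`, `num_latent`, and matplotlib setup from the cells above; the names `z_start`, `z_end`, and `n_steps` are illustrative.

# Minimal latent-space interpolation sketch (assumes the cells above were run)
n_steps = 10
z_start = torch.randn(1, num_latent).to(device)
z_end = torch.randn(1, num_latent).to(device)

fig, axes = plt.subplots(nrows=1, ncols=n_steps, figsize=(10, 2.5), sharey=True)
for ax, alpha in zip(axes, torch.linspace(0., 1., n_steps).tolist()):
    # blend the two latent codes and decode the intermediate point
    z = (1 - alpha) * z_start + alpha * z_end
    img = model.decoder(z)
    ax.imshow(img.detach().to(torch.device('cpu')).reshape((28, 28)), cmap='binary')
plt.show()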
In [8]:
%watermark -iv
numpy       1.15.4
torch       1.0.0