Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

In [1]:
# Record the author plus the Python, IPython and torch versions used for this run.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka 

CPython 3.7.1
IPython 7.2.0

torch 1.0.0
  • Runs on CPU or GPU (if available)

Model Zoo -- Conditional Variational Autoencoder

(without labels in reconstruction loss)

A simple conditional variational autoencoder that compresses 784-pixel (28×28) MNIST images down to a 35-dimensional latent vector representation.

This implementation DOES NOT concatenate the inputs with the class labels when computing the reconstruction loss, in contrast to how it is commonly done in non-convolutional conditional variational autoencoders. Not considering the class labels in the reconstruction loss leads to slightly worse results than the implementation that does concatenate the labels with the inputs to compute the reconstruction loss. For reference, see the implementation in ./autoencoder-cvae.ipynb.

Imports

In [2]:
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


# When a GPU is present, ask cuDNN to use deterministic algorithms so that
# repeated runs produce the same results; nothing to configure on the CPU.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
In [3]:
##########################
### SETTINGS
##########################

# Device
# The original run used the second GPU ("cuda:1"). Hard-coding that index
# makes every `.to(device)` call fail on machines with a single GPU, so fall
# back to the last available GPU when fewer devices exist.
cuda_index = 1
if torch.cuda.is_available():
    cuda_index = min(cuda_index, torch.cuda.device_count() - 1)
    device = torch.device('cuda:%d' % cuda_index)
else:
    device = torch.device('cpu')
print('Device:', device)

# Hyperparameters
random_seed = 0
learning_rate = 0.001
num_epochs = 50
batch_size = 128

# Architecture
num_classes = 10
num_features = 784  # 28 x 28 MNIST pixels, flattened
num_hidden_1 = 500
num_latent = 35


##########################
### MNIST DATASET
##########################

# ToTensor() turns each PIL image into a float tensor with pixel
# intensities scaled to the [0, 1] range.
train_dataset = datasets.MNIST(root='data', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='data', train=False,
                              transform=transforms.ToTensor())

# Only the training data is shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: pull a single training batch and report its shapes.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
Device: cuda:1
Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])

Model

In [4]:
##########################
### MODEL
##########################


def to_onehot(labels, num_classes, device):
    """Encode integer class labels as one-hot float vectors.

    Parameters
    ----------
    labels : torch.Tensor
        Integer (int64) class indices, one per sample. They must live on
        ``device`` because they are used as the ``scatter_`` index.
    num_classes : int
        Width of the one-hot encoding.
    device : torch.device
        Device on which the returned tensor is allocated.

    Returns
    -------
    torch.Tensor
        Float tensor of shape ``(labels.size(0), num_classes)`` with a single
        1 per row at the label's column.
    """
    # Allocate directly on the target device instead of building the zero
    # matrix on the CPU and then copying it over with ``.to(device)``.
    labels_onehot = torch.zeros(labels.size(0), num_classes, device=device)
    # reshape (not view) so non-contiguous label tensors are accepted too.
    labels_onehot.scatter_(1, labels.reshape(-1, 1), 1)

    return labels_onehot


class ConditionalVariationalAutoencoder(torch.nn.Module):
    """Fully connected conditional VAE for flattened images.

    Both the encoder and the decoder are conditioned on the class label by
    concatenating a one-hot encoding of the label to their inputs. The
    decoder outputs only the ``num_features`` pixel values (no label).

    Parameters
    ----------
    num_features : int
        Length of a flattened input image.
    num_hidden_1 : int
        Width of the single hidden layer in the encoder and in the decoder.
    num_latent : int
        Dimensionality of the latent vector z.
    num_classes : int
        Number of class labels used for conditioning.
    """

    def __init__(self, num_features, num_hidden_1, num_latent, num_classes):
        super(ConditionalVariationalAutoencoder, self).__init__()

        self.num_classes = num_classes

        ### ENCODER
        self.hidden_1 = torch.nn.Linear(num_features+num_classes, num_hidden_1)
        self.z_mean = torch.nn.Linear(num_hidden_1, num_latent)
        # The original VAE paper (Kingma & Welling, 2013) parameterizes the
        # latent Gaussian by a mean and a variance. A linear layer can output
        # negative numbers, which is invalid for a variance and would break
        # the log in the KL term. Hence the network predicts log(variance)
        # and exponentiates it whenever the variance or std. dev. is needed.
        self.z_log_var = torch.nn.Linear(num_hidden_1, num_latent)

        ### DECODER
        self.linear_3 = torch.nn.Linear(num_latent+num_classes, num_hidden_1)
        # This variant deliberately reconstructs only the pixels; the label is
        # not part of the decoder output (nor of the reconstruction target).
        self.linear_4 = torch.nn.Linear(num_hidden_1, num_features)

    def reparameterize(self, z_mu, z_log_var):
        """Sample z ~ N(z_mu, exp(z_log_var)) via the reparameterization trick."""
        # Draw eps with the same shape, dtype and device as z_mu. This avoids
        # a module-level `device` global and a CPU->GPU copy on every call.
        eps = torch.randn_like(z_mu)
        # log(std^2) = 2*log(std), so std = exp(log_var / 2).
        z = z_mu + eps * torch.exp(z_log_var/2.)
        return z

    def encoder(self, features, targets):
        """Encode flattened images conditioned on their labels.

        Returns ``(z_mean, z_log_var, encoded)``, where ``encoded`` is a
        sample drawn with :meth:`reparameterize`.
        """
        ### Add condition (one-hot label on the same device as the inputs)
        onehot_targets = to_onehot(targets, self.num_classes, features.device)
        x = torch.cat((features, onehot_targets), dim=1)

        ### ENCODER
        x = self.hidden_1(x)
        x = F.leaky_relu(x)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        encoded = self.reparameterize(z_mean, z_log_var)
        return z_mean, z_log_var, encoded

    def decoder(self, encoded, targets):
        """Decode latent codes conditioned on labels into pixel values in (0, 1)."""
        ### Add condition (one-hot label on the same device as the codes)
        onehot_targets = to_onehot(targets, self.num_classes, encoded.device)
        encoded = torch.cat((encoded, onehot_targets), dim=1)

        ### DECODER
        x = self.linear_3(encoded)
        x = F.leaky_relu(x)
        x = self.linear_4(x)
        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        """Return ``(z_mean, z_log_var, encoded, decoded)`` for one batch."""
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)

        return z_mean, z_log_var, encoded, decoded

    
# Seed right before construction so the weight initialization is reproducible.
torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(num_features=num_features,
                                          num_hidden_1=num_hidden_1,
                                          num_latent=num_latent,
                                          num_classes=num_classes).to(device)


##########################
### COST AND OPTIMIZER
##########################

# The cost (KL divergence + pixel-wise BCE) is computed in the training loop;
# only the optimizer is configured here.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

Training

In [5]:
start_time = time.time()
for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):

        # Flatten the (1, 28, 28) images into vectors of length num_features
        # to match the fully connected encoder's input layer.
        features = features.view(-1, num_features).to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence.
        # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over
        # the latent dimensions and over the batch.
        kl_divergence = (0.5 * (z_mean**2 +
                                torch.exp(z_log_var) - z_log_var - 1)).sum()

        # Reconstruction loss on the pixels only: this variant does not append
        # the one-hot class label to the target (the decoder does not output
        # it either). See ./autoencoder-cvae.ipynb for the label-augmented loss.
        pixelwise_bce = F.binary_cross_entropy(decoded, features, reduction='sum')
        cost = kl_divergence + pixelwise_bce

        ### UPDATE MODEL PARAMETERS
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        ### LOGGING
        # The cost is summed over the batch, so its magnitude scales with
        # the batch size.
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost.item()))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/469 | Cost: 70192.2188
Epoch: 001/050 | Batch 050/469 | Cost: 25749.1211
Epoch: 001/050 | Batch 100/469 | Cost: 21997.0703
Epoch: 001/050 | Batch 150/469 | Cost: 19867.7832
Epoch: 001/050 | Batch 200/469 | Cost: 19859.0391
Epoch: 001/050 | Batch 250/469 | Cost: 18637.0820
Epoch: 001/050 | Batch 300/469 | Cost: 17863.7227
Epoch: 001/050 | Batch 350/469 | Cost: 16946.2188
Epoch: 001/050 | Batch 400/469 | Cost: 16602.0566
Epoch: 001/050 | Batch 450/469 | Cost: 16563.0742
Epoch: 002/050 | Batch 000/469 | Cost: 15998.6934
Epoch: 002/050 | Batch 050/469 | Cost: 15980.9590
Epoch: 002/050 | Batch 100/469 | Cost: 15842.0508
Epoch: 002/050 | Batch 150/469 | Cost: 15789.3096
Epoch: 002/050 | Batch 200/469 | Cost: 15471.0068
Epoch: 002/050 | Batch 250/469 | Cost: 15047.9609
Epoch: 002/050 | Batch 300/469 | Cost: 14812.4609
Epoch: 002/050 | Batch 350/469 | Cost: 15064.4570
Epoch: 002/050 | Batch 400/469 | Cost: 14684.2832
Epoch: 002/050 | Batch 450/469 | Cost: 14621.3662
Epoch: 003/050 | Batch 000/469 | Cost: 14662.3740
Epoch: 003/050 | Batch 050/469 | Cost: 14373.9258
Epoch: 003/050 | Batch 100/469 | Cost: 14580.2539
Epoch: 003/050 | Batch 150/469 | Cost: 14757.9639
Epoch: 003/050 | Batch 200/469 | Cost: 14678.1953
Epoch: 003/050 | Batch 250/469 | Cost: 14471.2031
Epoch: 003/050 | Batch 300/469 | Cost: 14082.8926
Epoch: 003/050 | Batch 350/469 | Cost: 14371.0566
Epoch: 003/050 | Batch 400/469 | Cost: 13371.3496
Epoch: 003/050 | Batch 450/469 | Cost: 14689.0518
Epoch: 004/050 | Batch 000/469 | Cost: 13579.7764
Epoch: 004/050 | Batch 050/469 | Cost: 14012.3291
Epoch: 004/050 | Batch 100/469 | Cost: 13688.7236
Epoch: 004/050 | Batch 150/469 | Cost: 13726.9590
Epoch: 004/050 | Batch 200/469 | Cost: 13938.3027
Epoch: 004/050 | Batch 250/469 | Cost: 14040.9287
Epoch: 004/050 | Batch 300/469 | Cost: 13934.2998
Epoch: 004/050 | Batch 350/469 | Cost: 13701.2197
Epoch: 004/050 | Batch 400/469 | Cost: 13695.0098
Epoch: 004/050 | Batch 450/469 | Cost: 13170.8828
Epoch: 005/050 | Batch 000/469 | Cost: 13500.4805
Epoch: 005/050 | Batch 050/469 | Cost: 13655.4971
Epoch: 005/050 | Batch 100/469 | Cost: 13458.8867
Epoch: 005/050 | Batch 150/469 | Cost: 13754.9385
Epoch: 005/050 | Batch 200/469 | Cost: 13421.4209
Epoch: 005/050 | Batch 250/469 | Cost: 13213.7803
Epoch: 005/050 | Batch 300/469 | Cost: 12693.9590
Epoch: 005/050 | Batch 350/469 | Cost: 13030.9766
Epoch: 005/050 | Batch 400/469 | Cost: 13811.0107
Epoch: 005/050 | Batch 450/469 | Cost: 14092.8613
Epoch: 006/050 | Batch 000/469 | Cost: 13308.8340
Epoch: 006/050 | Batch 050/469 | Cost: 13082.6172
Epoch: 006/050 | Batch 100/469 | Cost: 13904.7197
Epoch: 006/050 | Batch 150/469 | Cost: 13171.1230
Epoch: 006/050 | Batch 200/469 | Cost: 13504.8125
Epoch: 006/050 | Batch 250/469 | Cost: 13535.9785
Epoch: 006/050 | Batch 300/469 | Cost: 13284.6123
Epoch: 006/050 | Batch 350/469 | Cost: 13442.9844
Epoch: 006/050 | Batch 400/469 | Cost: 13444.6738
Epoch: 006/050 | Batch 450/469 | Cost: 13511.4160
Epoch: 007/050 | Batch 000/469 | Cost: 13908.5977
Epoch: 007/050 | Batch 050/469 | Cost: 13451.8223
Epoch: 007/050 | Batch 100/469 | Cost: 13165.7402
Epoch: 007/050 | Batch 150/469 | Cost: 13328.9355
Epoch: 007/050 | Batch 200/469 | Cost: 12998.3633
Epoch: 007/050 | Batch 250/469 | Cost: 13683.5605
Epoch: 007/050 | Batch 300/469 | Cost: 13152.0830
Epoch: 007/050 | Batch 350/469 | Cost: 13329.2920
Epoch: 007/050 | Batch 400/469 | Cost: 13330.4443
Epoch: 007/050 | Batch 450/469 | Cost: 13523.7051
Epoch: 008/050 | Batch 000/469 | Cost: 13326.9102
Epoch: 008/050 | Batch 050/469 | Cost: 12950.0498
Epoch: 008/050 | Batch 100/469 | Cost: 13583.9219
Epoch: 008/050 | Batch 150/469 | Cost: 12776.9805
Epoch: 008/050 | Batch 200/469 | Cost: 12999.6387
Epoch: 008/050 | Batch 250/469 | Cost: 13324.9883
Epoch: 008/050 | Batch 300/469 | Cost: 13418.3408
Epoch: 008/050 | Batch 350/469 | Cost: 13043.9551
Epoch: 008/050 | Batch 400/469 | Cost: 13222.0293
Epoch: 008/050 | Batch 450/469 | Cost: 12615.9102
Epoch: 009/050 | Batch 000/469 | Cost: 13419.1953
Epoch: 009/050 | Batch 050/469 | Cost: 13427.7383
Epoch: 009/050 | Batch 100/469 | Cost: 12853.0498
Epoch: 009/050 | Batch 150/469 | Cost: 13082.6865
Epoch: 009/050 | Batch 200/469 | Cost: 12926.8877
Epoch: 009/050 | Batch 250/469 | Cost: 13223.6982
Epoch: 009/050 | Batch 300/469 | Cost: 12966.2803
Epoch: 009/050 | Batch 350/469 | Cost: 12672.2607
Epoch: 009/050 | Batch 400/469 | Cost: 13285.1992
Epoch: 009/050 | Batch 450/469 | Cost: 12638.7812
Epoch: 010/050 | Batch 000/469 | Cost: 13139.2168
Epoch: 010/050 | Batch 050/469 | Cost: 12674.6816
Epoch: 010/050 | Batch 100/469 | Cost: 13080.3828
Epoch: 010/050 | Batch 150/469 | Cost: 12448.9199
Epoch: 010/050 | Batch 200/469 | Cost: 12761.3613
Epoch: 010/050 | Batch 250/469 | Cost: 13139.7520
Epoch: 010/050 | Batch 300/469 | Cost: 12969.9932
Epoch: 010/050 | Batch 350/469 | Cost: 12518.5615
Epoch: 010/050 | Batch 400/469 | Cost: 13042.4551
Epoch: 010/050 | Batch 450/469 | Cost: 13296.8926
Epoch: 011/050 | Batch 000/469 | Cost: 13233.5322
Epoch: 011/050 | Batch 050/469 | Cost: 13085.0918
Epoch: 011/050 | Batch 100/469 | Cost: 12664.7422
Epoch: 011/050 | Batch 150/469 | Cost: 13344.7686
Epoch: 011/050 | Batch 200/469 | Cost: 12498.5938
Epoch: 011/050 | Batch 250/469 | Cost: 13314.7920
Epoch: 011/050 | Batch 300/469 | Cost: 13175.9463
Epoch: 011/050 | Batch 350/469 | Cost: 13034.4180
Epoch: 011/050 | Batch 400/469 | Cost: 12425.1221
Epoch: 011/050 | Batch 450/469 | Cost: 12548.4668
Epoch: 012/050 | Batch 000/469 | Cost: 12779.4053
Epoch: 012/050 | Batch 050/469 | Cost: 13129.1328
Epoch: 012/050 | Batch 100/469 | Cost: 12274.7422
Epoch: 012/050 | Batch 150/469 | Cost: 13289.4688
Epoch: 012/050 | Batch 200/469 | Cost: 13256.5312
Epoch: 012/050 | Batch 250/469 | Cost: 12437.4629
Epoch: 012/050 | Batch 300/469 | Cost: 12500.7627
Epoch: 012/050 | Batch 350/469 | Cost: 13362.0430
Epoch: 012/050 | Batch 400/469 | Cost: 13271.1768
Epoch: 012/050 | Batch 450/469 | Cost: 13070.1992
Epoch: 013/050 | Batch 000/469 | Cost: 12979.0723
Epoch: 013/050 | Batch 050/469 | Cost: 12714.0527
Epoch: 013/050 | Batch 100/469 | Cost: 12925.5879
Epoch: 013/050 | Batch 150/469 | Cost: 13068.3555
Epoch: 013/050 | Batch 200/469 | Cost: 12462.0791
Epoch: 013/050 | Batch 250/469 | Cost: 12443.1250
Epoch: 013/050 | Batch 300/469 | Cost: 12773.1631
Epoch: 013/050 | Batch 350/469 | Cost: 12435.6836
Epoch: 013/050 | Batch 400/469 | Cost: 12659.2441
Epoch: 013/050 | Batch 450/469 | Cost: 12680.4297
Epoch: 014/050 | Batch 000/469 | Cost: 12963.3291
Epoch: 014/050 | Batch 050/469 | Cost: 12406.1680
Epoch: 014/050 | Batch 100/469 | Cost: 13342.7998
Epoch: 014/050 | Batch 150/469 | Cost: 13050.4004
Epoch: 014/050 | Batch 200/469 | Cost: 12695.7129
Epoch: 014/050 | Batch 250/469 | Cost: 12899.9678
Epoch: 014/050 | Batch 300/469 | Cost: 12568.9746
Epoch: 014/050 | Batch 350/469 | Cost: 12800.3164
Epoch: 014/050 | Batch 400/469 | Cost: 12908.6758
Epoch: 014/050 | Batch 450/469 | Cost: 13055.2197
Epoch: 015/050 | Batch 000/469 | Cost: 12697.0527
Epoch: 015/050 | Batch 050/469 | Cost: 13206.3340
Epoch: 015/050 | Batch 100/469 | Cost: 12505.6865
Epoch: 015/050 | Batch 150/469 | Cost: 12765.6504
Epoch: 015/050 | Batch 200/469 | Cost: 12692.5625
Epoch: 015/050 | Batch 250/469 | Cost: 12564.1904
Epoch: 015/050 | Batch 300/469 | Cost: 12480.1055
Epoch: 015/050 | Batch 350/469 | Cost: 12703.9590
Epoch: 015/050 | Batch 400/469 | Cost: 12782.6943
Epoch: 015/050 | Batch 450/469 | Cost: 12501.6982
Epoch: 016/050 | Batch 000/469 | Cost: 12316.8369
Epoch: 016/050 | Batch 050/469 | Cost: 12879.1367
Epoch: 016/050 | Batch 100/469 | Cost: 12799.4814
Epoch: 016/050 | Batch 150/469 | Cost: 13116.8818
Epoch: 016/050 | Batch 200/469 | Cost: 12788.3652
Epoch: 016/050 | Batch 250/469 | Cost: 12618.8379
Epoch: 016/050 | Batch 300/469 | Cost: 13378.3730
Epoch: 016/050 | Batch 350/469 | Cost: 12751.9121
Epoch: 016/050 | Batch 400/469 | Cost: 12654.6123
Epoch: 016/050 | Batch 450/469 | Cost: 12693.1211
Epoch: 017/050 | Batch 000/469 | Cost: 13261.9746
Epoch: 017/050 | Batch 050/469 | Cost: 13040.6025
Epoch: 017/050 | Batch 100/469 | Cost: 12892.2832
Epoch: 017/050 | Batch 150/469 | Cost: 12776.0957
Epoch: 017/050 | Batch 200/469 | Cost: 12676.0645
Epoch: 017/050 | Batch 250/469 | Cost: 13100.1250
Epoch: 017/050 | Batch 300/469 | Cost: 12229.5742
Epoch: 017/050 | Batch 350/469 | Cost: 12896.2207
Epoch: 017/050 | Batch 400/469 | Cost: 12986.7246
Epoch: 017/050 | Batch 450/469 | Cost: 12528.6777
Epoch: 018/050 | Batch 000/469 | Cost: 12395.8604
Epoch: 018/050 | Batch 050/469 | Cost: 12674.4678
Epoch: 018/050 | Batch 100/469 | Cost: 12528.5469
Epoch: 018/050 | Batch 150/469 | Cost: 13454.7070
Epoch: 018/050 | Batch 200/469 | Cost: 12878.5322
Epoch: 018/050 | Batch 250/469 | Cost: 12682.8457
Epoch: 018/050 | Batch 300/469 | Cost: 12604.6943
Epoch: 018/050 | Batch 350/469 | Cost: 13185.8828
Epoch: 018/050 | Batch 400/469 | Cost: 12933.7246
Epoch: 018/050 | Batch 450/469 | Cost: 12973.7314
Epoch: 019/050 | Batch 000/469 | Cost: 12347.1924
Epoch: 019/050 | Batch 050/469 | Cost: 12655.2314
Epoch: 019/050 | Batch 100/469 | Cost: 12840.0889
Epoch: 019/050 | Batch 150/469 | Cost: 12790.1152
Epoch: 019/050 | Batch 200/469 | Cost: 12546.3301
Epoch: 019/050 | Batch 250/469 | Cost: 12630.3662
Epoch: 019/050 | Batch 300/469 | Cost: 12877.1553
Epoch: 019/050 | Batch 350/469 | Cost: 12754.5049
Epoch: 019/050 | Batch 400/469 | Cost: 12562.9287
Epoch: 019/050 | Batch 450/469 | Cost: 12670.7939
Epoch: 020/050 | Batch 000/469 | Cost: 13078.5391
Epoch: 020/050 | Batch 050/469 | Cost: 13251.5137
Epoch: 020/050 | Batch 100/469 | Cost: 12222.6816
Epoch: 020/050 | Batch 150/469 | Cost: 13020.2549
Epoch: 020/050 | Batch 200/469 | Cost: 12660.7695
Epoch: 020/050 | Batch 250/469 | Cost: 12797.1309
Epoch: 020/050 | Batch 300/469 | Cost: 12559.7441
Epoch: 020/050 | Batch 350/469 | Cost: 12983.9473
Epoch: 020/050 | Batch 400/469 | Cost: 12665.8516
Epoch: 020/050 | Batch 450/469 | Cost: 12557.4512
Epoch: 021/050 | Batch 000/469 | Cost: 12259.2539
Epoch: 021/050 | Batch 050/469 | Cost: 12225.6787
Epoch: 021/050 | Batch 100/469 | Cost: 13265.6328
Epoch: 021/050 | Batch 150/469 | Cost: 12958.9795
Epoch: 021/050 | Batch 200/469 | Cost: 13201.1504
Epoch: 021/050 | Batch 250/469 | Cost: 12173.3027
Epoch: 021/050 | Batch 300/469 | Cost: 11880.8125
Epoch: 021/050 | Batch 350/469 | Cost: 12684.7500
Epoch: 021/050 | Batch 400/469 | Cost: 12973.6250
Epoch: 021/050 | Batch 450/469 | Cost: 12326.9854
Epoch: 022/050 | Batch 000/469 | Cost: 12506.0596
Epoch: 022/050 | Batch 050/469 | Cost: 12992.8047
Epoch: 022/050 | Batch 100/469 | Cost: 12908.5557
Epoch: 022/050 | Batch 150/469 | Cost: 12658.6768
Epoch: 022/050 | Batch 200/469 | Cost: 13097.6426
Epoch: 022/050 | Batch 250/469 | Cost: 12514.5166
Epoch: 022/050 | Batch 300/469 | Cost: 13067.9795
Epoch: 022/050 | Batch 350/469 | Cost: 13335.4814
Epoch: 022/050 | Batch 400/469 | Cost: 12482.6094
Epoch: 022/050 | Batch 450/469 | Cost: 12887.1328
Epoch: 023/050 | Batch 000/469 | Cost: 12895.0732
Epoch: 023/050 | Batch 050/469 | Cost: 12596.9219
Epoch: 023/050 | Batch 100/469 | Cost: 12961.1699
Epoch: 023/050 | Batch 150/469 | Cost: 12497.6240
Epoch: 023/050 | Batch 200/469 | Cost: 12390.3174
Epoch: 023/050 | Batch 250/469 | Cost: 12916.2070
Epoch: 023/050 | Batch 300/469 | Cost: 12608.6494
Epoch: 023/050 | Batch 350/469 | Cost: 12270.3037
Epoch: 023/050 | Batch 400/469 | Cost: 12774.8906
Epoch: 023/050 | Batch 450/469 | Cost: 12438.0068
Epoch: 024/050 | Batch 000/469 | Cost: 12060.3467
Epoch: 024/050 | Batch 050/469 | Cost: 12482.3770
Epoch: 024/050 | Batch 100/469 | Cost: 12389.7715
Epoch: 024/050 | Batch 150/469 | Cost: 13020.0859
Epoch: 024/050 | Batch 200/469 | Cost: 12233.1670
Epoch: 024/050 | Batch 250/469 | Cost: 12507.4473
Epoch: 024/050 | Batch 300/469 | Cost: 12403.1035
Epoch: 024/050 | Batch 350/469 | Cost: 12475.9551
Epoch: 024/050 | Batch 400/469 | Cost: 12369.6104
Epoch: 024/050 | Batch 450/469 | Cost: 12104.8066
Epoch: 025/050 | Batch 000/469 | Cost: 12380.4355
Epoch: 025/050 | Batch 050/469 | Cost: 12826.3662
Epoch: 025/050 | Batch 100/469 | Cost: 12431.5898
Epoch: 025/050 | Batch 150/469 | Cost: 12982.6113
Epoch: 025/050 | Batch 200/469 | Cost: 12823.1465
Epoch: 025/050 | Batch 250/469 | Cost: 12800.5156
Epoch: 025/050 | Batch 300/469 | Cost: 13140.7812
Epoch: 025/050 | Batch 350/469 | Cost: 12483.5723
Epoch: 025/050 | Batch 400/469 | Cost: 12694.3594
Epoch: 025/050 | Batch 450/469 | Cost: 12767.1543
Epoch: 026/050 | Batch 000/469 | Cost: 11855.4678
Epoch: 026/050 | Batch 050/469 | Cost: 12363.9590
Epoch: 026/050 | Batch 100/469 | Cost: 13079.2793
Epoch: 026/050 | Batch 150/469 | Cost: 12977.3594
Epoch: 026/050 | Batch 200/469 | Cost: 12642.0938
Epoch: 026/050 | Batch 250/469 | Cost: 12530.8447
Epoch: 026/050 | Batch 300/469 | Cost: 12514.3311
Epoch: 026/050 | Batch 350/469 | Cost: 12100.2314
Epoch: 026/050 | Batch 400/469 | Cost: 12814.0479
Epoch: 026/050 | Batch 450/469 | Cost: 12364.0166
Epoch: 027/050 | Batch 000/469 | Cost: 12499.8721
Epoch: 027/050 | Batch 050/469 | Cost: 12678.9111
Epoch: 027/050 | Batch 100/469 | Cost: 12261.5918
Epoch: 027/050 | Batch 150/469 | Cost: 12901.1641
Epoch: 027/050 | Batch 200/469 | Cost: 12548.0469
Epoch: 027/050 | Batch 250/469 | Cost: 12211.9111
Epoch: 027/050 | Batch 300/469 | Cost: 13003.7646
Epoch: 027/050 | Batch 350/469 | Cost: 12214.5781
Epoch: 027/050 | Batch 400/469 | Cost: 12604.0361
Epoch: 027/050 | Batch 450/469 | Cost: 12504.3213
Epoch: 028/050 | Batch 000/469 | Cost: 12680.8613
Epoch: 028/050 | Batch 050/469 | Cost: 13018.3525
Epoch: 028/050 | Batch 100/469 | Cost: 13040.8760
Epoch: 028/050 | Batch 150/469 | Cost: 12745.8643
Epoch: 028/050 | Batch 200/469 | Cost: 12417.4248
Epoch: 028/050 | Batch 250/469 | Cost: 12684.0645
Epoch: 028/050 | Batch 300/469 | Cost: 12119.3633
Epoch: 028/050 | Batch 350/469 | Cost: 12281.8008
Epoch: 028/050 | Batch 400/469 | Cost: 12434.8438
Epoch: 028/050 | Batch 450/469 | Cost: 12379.5928
Epoch: 029/050 | Batch 000/469 | Cost: 12527.9355
Epoch: 029/050 | Batch 050/469 | Cost: 12694.2578
Epoch: 029/050 | Batch 100/469 | Cost: 12318.5742
Epoch: 029/050 | Batch 150/469 | Cost: 12357.7070
Epoch: 029/050 | Batch 200/469 | Cost: 12823.2246
Epoch: 029/050 | Batch 250/469 | Cost: 12532.8555
Epoch: 029/050 | Batch 300/469 | Cost: 12343.1777
Epoch: 029/050 | Batch 350/469 | Cost: 12207.8662
Epoch: 029/050 | Batch 400/469 | Cost: 12553.4434
Epoch: 029/050 | Batch 450/469 | Cost: 12426.8096
Epoch: 030/050 | Batch 000/469 | Cost: 12391.7988
Epoch: 030/050 | Batch 050/469 | Cost: 12414.6650
Epoch: 030/050 | Batch 100/469 | Cost: 12213.8281
Epoch: 030/050 | Batch 150/469 | Cost: 12527.5752
Epoch: 030/050 | Batch 200/469 | Cost: 12135.3281
Epoch: 030/050 | Batch 250/469 | Cost: 12099.4062
Epoch: 030/050 | Batch 300/469 | Cost: 12891.4102
Epoch: 030/050 | Batch 350/469 | Cost: 12546.6768
Epoch: 030/050 | Batch 400/469 | Cost: 12653.6172
Epoch: 030/050 | Batch 450/469 | Cost: 12576.2285
Epoch: 031/050 | Batch 000/469 | Cost: 12499.4316
Epoch: 031/050 | Batch 050/469 | Cost: 12517.8770
Epoch: 031/050 | Batch 100/469 | Cost: 12340.2480
Epoch: 031/050 | Batch 150/469 | Cost: 12368.0469
Epoch: 031/050 | Batch 200/469 | Cost: 12331.4121
Epoch: 031/050 | Batch 250/469 | Cost: 12736.1953
Epoch: 031/050 | Batch 300/469 | Cost: 12985.6914
Epoch: 031/050 | Batch 350/469 | Cost: 12383.8086
Epoch: 031/050 | Batch 400/469 | Cost: 12270.4277
Epoch: 031/050 | Batch 450/469 | Cost: 12418.8633
Epoch: 032/050 | Batch 000/469 | Cost: 12244.7559
Epoch: 032/050 | Batch 050/469 | Cost: 12531.9453
Epoch: 032/050 | Batch 100/469 | Cost: 12477.5752
Epoch: 032/050 | Batch 150/469 | Cost: 12838.6650
Epoch: 032/050 | Batch 200/469 | Cost: 12590.4707
Epoch: 032/050 | Batch 250/469 | Cost: 12658.5674
Epoch: 032/050 | Batch 300/469 | Cost: 12619.9316
Epoch: 032/050 | Batch 350/469 | Cost: 12790.5488
Epoch: 032/050 | Batch 400/469 | Cost: 12336.5918
Epoch: 032/050 | Batch 450/469 | Cost: 11956.5361
Epoch: 033/050 | Batch 000/469 | Cost: 12257.5645
Epoch: 033/050 | Batch 050/469 | Cost: 12238.9277
Epoch: 033/050 | Batch 100/469 | Cost: 12166.1533
Epoch: 033/050 | Batch 150/469 | Cost: 12442.1953
Epoch: 033/050 | Batch 200/469 | Cost: 12383.0957
Epoch: 033/050 | Batch 250/469 | Cost: 12242.8730
Epoch: 033/050 | Batch 300/469 | Cost: 12493.3262
Epoch: 033/050 | Batch 350/469 | Cost: 12194.9941
Epoch: 033/050 | Batch 400/469 | Cost: 12441.2207
Epoch: 033/050 | Batch 450/469 | Cost: 12835.3838
Epoch: 034/050 | Batch 000/469 | Cost: 12413.8838
Epoch: 034/050 | Batch 050/469 | Cost: 12801.7031
Epoch: 034/050 | Batch 100/469 | Cost: 12464.5234
Epoch: 034/050 | Batch 150/469 | Cost: 12432.2822
Epoch: 034/050 | Batch 200/469 | Cost: 12561.4375
Epoch: 034/050 | Batch 250/469 | Cost: 12854.5889
Epoch: 034/050 | Batch 300/469 | Cost: 12125.7393
Epoch: 034/050 | Batch 350/469 | Cost: 12752.3701
Epoch: 034/050 | Batch 400/469 | Cost: 12496.3652
Epoch: 034/050 | Batch 450/469 | Cost: 12751.6465
Epoch: 035/050 | Batch 000/469 | Cost: 12277.0820
Epoch: 035/050 | Batch 050/469 | Cost: 12367.7256
Epoch: 035/050 | Batch 100/469 | Cost: 12402.5156
Epoch: 035/050 | Batch 150/469 | Cost: 12334.3750
Epoch: 035/050 | Batch 200/469 | Cost: 12532.5967
Epoch: 035/050 | Batch 250/469 | Cost: 12294.9727
Epoch: 035/050 | Batch 300/469 | Cost: 12221.8359
Epoch: 035/050 | Batch 350/469 | Cost: 12979.2939
Epoch: 035/050 | Batch 400/469 | Cost: 12789.4639
Epoch: 035/050 | Batch 450/469 | Cost: 12396.4160
Epoch: 036/050 | Batch 000/469 | Cost: 12536.0049
Epoch: 036/050 | Batch 050/469 | Cost: 12159.3613
Epoch: 036/050 | Batch 100/469 | Cost: 12361.6260
Epoch: 036/050 | Batch 150/469 | Cost: 12638.1709
Epoch: 036/050 | Batch 200/469 | Cost: 12634.9355
Epoch: 036/050 | Batch 250/469 | Cost: 12643.7432
Epoch: 036/050 | Batch 300/469 | Cost: 12563.5137
Epoch: 036/050 | Batch 350/469 | Cost: 12375.0566
Epoch: 036/050 | Batch 400/469 | Cost: 12551.1367
Epoch: 036/050 | Batch 450/469 | Cost: 12317.5762
Epoch: 037/050 | Batch 000/469 | Cost: 12063.4453
Epoch: 037/050 | Batch 050/469 | Cost: 11987.3984
Epoch: 037/050 | Batch 100/469 | Cost: 12577.7441
Epoch: 037/050 | Batch 150/469 | Cost: 12403.6309
Epoch: 037/050 | Batch 200/469 | Cost: 12922.1729
Epoch: 037/050 | Batch 250/469 | Cost: 12302.4805
Epoch: 037/050 | Batch 300/469 | Cost: 12353.8057
Epoch: 037/050 | Batch 350/469 | Cost: 12627.0859
Epoch: 037/050 | Batch 400/469 | Cost: 12517.3809
Epoch: 037/050 | Batch 450/469 | Cost: 11899.2090
Epoch: 038/050 | Batch 000/469 | Cost: 11766.3467
Epoch: 038/050 | Batch 050/469 | Cost: 12509.6875
Epoch: 038/050 | Batch 100/469 | Cost: 12706.8721
Epoch: 038/050 | Batch 150/469 | Cost: 12288.3730
Epoch: 038/050 | Batch 200/469 | Cost: 12531.9883
Epoch: 038/050 | Batch 250/469 | Cost: 12904.4297
Epoch: 038/050 | Batch 300/469 | Cost: 12279.0957
Epoch: 038/050 | Batch 350/469 | Cost: 13053.5732
Epoch: 038/050 | Batch 400/469 | Cost: 12317.9678
Epoch: 038/050 | Batch 450/469 | Cost: 12069.1924
Epoch: 039/050 | Batch 000/469 | Cost: 12420.7734
Epoch: 039/050 | Batch 050/469 | Cost: 12101.2764
Epoch: 039/050 | Batch 100/469 | Cost: 12663.4492
Epoch: 039/050 | Batch 150/469 | Cost: 12434.3320
Epoch: 039/050 | Batch 200/469 | Cost: 12394.2676
Epoch: 039/050 | Batch 250/469 | Cost: 12588.5234
Epoch: 039/050 | Batch 300/469 | Cost: 12016.5742
Epoch: 039/050 | Batch 350/469 | Cost: 11895.7480
Epoch: 039/050 | Batch 400/469 | Cost: 12270.1885
Epoch: 039/050 | Batch 450/469 | Cost: 12623.2764
Epoch: 040/050 | Batch 000/469 | Cost: 12347.0195
Epoch: 040/050 | Batch 050/469 | Cost: 12172.0439
Epoch: 040/050 | Batch 100/469 | Cost: 12112.8770
Epoch: 040/050 | Batch 150/469 | Cost: 12661.4824
Epoch: 040/050 | Batch 200/469 | Cost: 12516.9434
Epoch: 040/050 | Batch 250/469 | Cost: 11665.0059
Epoch: 040/050 | Batch 300/469 | Cost: 12424.7168
Epoch: 040/050 | Batch 350/469 | Cost: 12546.3516
Epoch: 040/050 | Batch 400/469 | Cost: 12085.0430
Epoch: 040/050 | Batch 450/469 | Cost: 12052.1777
Epoch: 041/050 | Batch 000/469 | Cost: 12553.8594
Epoch: 041/050 | Batch 050/469 | Cost: 12719.8916
Epoch: 041/050 | Batch 100/469 | Cost: 12318.2598
Epoch: 041/050 | Batch 150/469 | Cost: 12868.4424
Epoch: 041/050 | Batch 200/469 | Cost: 12110.4648
Epoch: 041/050 | Batch 250/469 | Cost: 12877.4014
Epoch: 041/050 | Batch 300/469 | Cost: 12044.2422
Epoch: 041/050 | Batch 350/469 | Cost: 12094.7090
Epoch: 041/050 | Batch 400/469 | Cost: 12124.3301
Epoch: 041/050 | Batch 450/469 | Cost: 12671.7217
Epoch: 042/050 | Batch 000/469 | Cost: 12054.0957
Epoch: 042/050 | Batch 050/469 | Cost: 12345.2227
Epoch: 042/050 | Batch 100/469 | Cost: 12810.0957
Epoch: 042/050 | Batch 150/469 | Cost: 11998.7207
Epoch: 042/050 | Batch 200/469 | Cost: 12693.5879
Epoch: 042/050 | Batch 250/469 | Cost: 11996.5615
Epoch: 042/050 | Batch 300/469 | Cost: 12084.2832
Epoch: 042/050 | Batch 350/469 | Cost: 12159.6025
Epoch: 042/050 | Batch 400/469 | Cost: 12514.6943
Epoch: 042/050 | Batch 450/469 | Cost: 12273.8809
Epoch: 043/050 | Batch 000/469 | Cost: 12472.9395
Epoch: 043/050 | Batch 050/469 | Cost: 12462.2734
Epoch: 043/050 | Batch 100/469 | Cost: 12303.0898
Epoch: 043/050 | Batch 150/469 | Cost: 12641.2676
Epoch: 043/050 | Batch 200/469 | Cost: 11870.0820
Epoch: 043/050 | Batch 250/469 | Cost: 12087.6504
Epoch: 043/050 | Batch 300/469 | Cost: 12615.6992
Epoch: 043/050 | Batch 350/469 | Cost: 12327.5391
Epoch: 043/050 | Batch 400/469 | Cost: 12761.4795
Epoch: 043/050 | Batch 450/469 | Cost: 12429.0576
Epoch: 044/050 | Batch 000/469 | Cost: 12172.6055
Epoch: 044/050 | Batch 050/469 | Cost: 12338.0742
Epoch: 044/050 | Batch 100/469 | Cost: 12473.4297
Epoch: 044/050 | Batch 150/469 | Cost: 12260.2695
Epoch: 044/050 | Batch 200/469 | Cost: 12475.7871
Epoch: 044/050 | Batch 250/469 | Cost: 12570.5645
Epoch: 044/050 | Batch 300/469 | Cost: 12297.6982
Epoch: 044/050 | Batch 350/469 | Cost: 12525.9111
Epoch: 044/050 | Batch 400/469 | Cost: 12596.0791
Epoch: 044/050 | Batch 450/469 | Cost: 11957.3623
Epoch: 045/050 | Batch 000/469 | Cost: 12849.4238
Epoch: 045/050 | Batch 050/469 | Cost: 12080.3203
Epoch: 045/050 | Batch 100/469 | Cost: 12260.8994
Epoch: 045/050 | Batch 150/469 | Cost: 12638.3770
Epoch: 045/050 | Batch 200/469 | Cost: 12635.9248
Epoch: 045/050 | Batch 250/469 | Cost: 12265.3184
Epoch: 045/050 | Batch 300/469 | Cost: 12359.3242
Epoch: 045/050 | Batch 350/469 | Cost: 12409.3135
Epoch: 045/050 | Batch 400/469 | Cost: 12485.5879
Epoch: 045/050 | Batch 450/469 | Cost: 12399.2988
Epoch: 046/050 | Batch 000/469 | Cost: 12027.0762
Epoch: 046/050 | Batch 050/469 | Cost: 12070.3789
Epoch: 046/050 | Batch 100/469 | Cost: 12531.2441
Epoch: 046/050 | Batch 150/469 | Cost: 12265.9395
Epoch: 046/050 | Batch 200/469 | Cost: 12452.6680
Epoch: 046/050 | Batch 250/469 | Cost: 13118.8711
Epoch: 046/050 | Batch 300/469 | Cost: 12208.3818
Epoch: 046/050 | Batch 350/469 | Cost: 12624.9814
Epoch: 046/050 | Batch 400/469 | Cost: 12488.5791
Epoch: 046/050 | Batch 450/469 | Cost: 12633.9775
Epoch: 047/050 | Batch 000/469 | Cost: 12152.1914
Epoch: 047/050 | Batch 050/469 | Cost: 12525.3857
Epoch: 047/050 | Batch 100/469 | Cost: 12195.7227
Epoch: 047/050 | Batch 150/469 | Cost: 12642.2949
Epoch: 047/050 | Batch 200/469 | Cost: 12667.8174
Epoch: 047/050 | Batch 250/469 | Cost: 12729.5176
Epoch: 047/050 | Batch 300/469 | Cost: 12052.5898
Epoch: 047/050 | Batch 350/469 | Cost: 12097.2480
Epoch: 047/050 | Batch 400/469 | Cost: 12530.8574
Epoch: 047/050 | Batch 450/469 | Cost: 12496.5098
Epoch: 048/050 | Batch 000/469 | Cost: 12613.0137
Epoch: 048/050 | Batch 050/469 | Cost: 12692.5273
Epoch: 048/050 | Batch 100/469 | Cost: 12363.4863
Epoch: 048/050 | Batch 150/469 | Cost: 11625.2861
Epoch: 048/050 | Batch 200/469 | Cost: 12005.9697
Epoch: 048/050 | Batch 250/469 | Cost: 12227.3750
Epoch: 048/050 | Batch 300/469 | Cost: 12684.3359
Epoch: 048/050 | Batch 350/469 | Cost: 12430.2783
Epoch: 048/050 | Batch 400/469 | Cost: 12213.2578
Epoch: 048/050 | Batch 450/469 | Cost: 13208.1133
Epoch: 049/050 | Batch 000/469 | Cost: 12118.3057
Epoch: 049/050 | Batch 050/469 | Cost: 12340.2715
Epoch: 049/050 | Batch 100/469 | Cost: 12029.6094
Epoch: 049/050 | Batch 150/469 | Cost: 12366.4453
Epoch: 049/050 | Batch 200/469 | Cost: 12537.2998
Epoch: 049/050 | Batch 250/469 | Cost: 12324.0312
Epoch: 049/050 | Batch 300/469 | Cost: 12378.3457
Epoch: 049/050 | Batch 350/469 | Cost: 12218.6914
Epoch: 049/050 | Batch 400/469 | Cost: 12550.4785
Epoch: 049/050 | Batch 450/469 | Cost: 12444.4463
Epoch: 050/050 | Batch 000/469 | Cost: 12246.0020
Epoch: 050/050 | Batch 050/469 | Cost: 12554.1836
Epoch: 050/050 | Batch 100/469 | Cost: 12373.2930
Epoch: 050/050 | Batch 150/469 | Cost: 12895.8096
Epoch: 050/050 | Batch 200/469 | Cost: 12233.0605
Epoch: 050/050 | Batch 250/469 | Cost: 12621.2920
Epoch: 050/050 | Batch 300/469 | Cost: 12492.7812
Epoch: 050/050 | Batch 350/469 | Cost: 12525.1934
Epoch: 050/050 | Batch 400/469 | Cost: 13032.4062
Epoch: 050/050 | Batch 450/469 | Cost: 12618.7773

Evaluation

Reconstruction

In [6]:
%matplotlib inline
import matplotlib.pyplot as plt

##########################
### VISUALIZATION
##########################

n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images, 
                         sharex=True, sharey=True, figsize=(20, 2.5))
orig_images = features[:n_images]
decoded_images = decoded[:n_images]

for i in range(n_images):
    for ax, img in zip(axes, [orig_images, decoded_images]):
        curr_img = img[i].detach().to(torch.device('cpu'))
        ax[i].imshow(curr_img.view((image_width, image_width)), cmap='binary')

New random-conditional images

In [7]:
image_width = 28

# For every class, decode random latent codes under that class label.
for i in range(num_classes):

    ##########################
    ### RANDOM SAMPLE
    ##########################

    # Ten samples of class i: latent codes drawn from the N(0, I) prior.
    labels = torch.full((10,), i, dtype=torch.long, device=device)
    n_images = labels.size(0)
    rand_features = torch.randn(n_images, num_latent, device=device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    decoded_images = new_images[:n_images]

    print('Class Label %d' % i)

    for ax, img in zip(axes, decoded_images):
        curr_img = img.to(torch.device('cpu'))
        ax.imshow(curr_img.view((image_width, image_width)), cmap='binary')

    plt.show()
Class Label 0
Class Label 1
Class Label 2
Class Label 3
Class Label 4
Class Label 5
Class Label 6
Class Label 7
Class Label 8
Class Label 9
In [8]:
# Report the versions of the packages imported in this notebook.
%watermark -iv
numpy       1.15.4
torch       1.0.0