Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka 

CPython 3.6.8
IPython 7.2.0

torch 1.0.0
  • Runs on CPU or GPU (if available)

Model Zoo -- Conditional Variational Autoencoder

(with labels in reconstruction loss)

A simple conditional variational autoencoder that compresses 768-pixel MNIST images down to a 35-pixel latent vector representation.

This implementation concatenates the inputs with the class labels when computing the reconstruction loss as it is commonly done in non-convolutional conditional variational autoencoders. This leads to sightly better results compared to the implementation that does NOT concatenate the labels with the inputs to compute the reconstruction loss. For reference, see the implementation ./autoencoder-cvae_no-out-concat.ipynb

Imports

In [2]:
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
In [3]:
##########################
### SETTINGS
##########################

# Device
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print('Device:', device)

# Hyperparameters
random_seed = 0
learning_rate = 0.001
num_epochs = 50
batch_size = 128

# Architecture
num_classes = 10
num_features = 784
num_hidden_1 = 500
num_latent = 35


##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data', 
                              train=False, 
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size, 
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Device: cuda:1
Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])

Model

In [4]:
##########################
### MODEL
##########################


def to_onehot(labels, num_classes, device):

    labels_onehot = torch.zeros(labels.size()[0], num_classes).to(device)
    labels_onehot.scatter_(1, labels.view(-1, 1), 1)

    return labels_onehot


class ConditionalVariationalAutoencoder(torch.nn.Module):

    def __init__(self, num_features, num_hidden_1, num_latent, num_classes):
        super(ConditionalVariationalAutoencoder, self).__init__()
        
        self.num_classes = num_classes
        
        ### ENCODER
        self.hidden_1 = torch.nn.Linear(num_features+num_classes, num_hidden_1)
        self.z_mean = torch.nn.Linear(num_hidden_1, num_latent)
        # in the original paper (Kingma & Welling 2015, we use
        # have a z_mean and z_var, but the problem is that
        # the z_var can be negative, which would cause issues
        # in the log later. Hence we assume that latent vector
        # has a z_mean and z_log_var component, and when we need
        # the regular variance or std_dev, we simply use 
        # an exponential function
        self.z_log_var = torch.nn.Linear(num_hidden_1, num_latent)
        
        
        ### DECODER
        self.linear_3 = torch.nn.Linear(num_latent+num_classes, num_hidden_1)
        self.linear_4 = torch.nn.Linear(num_hidden_1, num_features+num_classes)

    def reparameterize(self, z_mu, z_log_var):
        # Sample epsilon from standard normal distribution
        eps = torch.randn(z_mu.size(0), z_mu.size(1)).to(device)
        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev
        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)
        z = z_mu + eps * torch.exp(z_log_var/2.) 
        return z
        
    def encoder(self, features, targets):
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, device)
        x = torch.cat((features, onehot_targets), dim=1)

        ### ENCODER
        x = self.hidden_1(x)
        x = F.leaky_relu(x)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        encoded = self.reparameterize(z_mean, z_log_var)
        return z_mean, z_log_var, encoded
    
    def decoder(self, encoded, targets):
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, device)
        encoded = torch.cat((encoded, onehot_targets), dim=1)        
        
        ### DECODER
        x = self.linear_3(encoded)
        x = F.leaky_relu(x)
        x = self.linear_4(x)
        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)
        
        return z_mean, z_log_var, encoded, decoded

    
torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(num_features,
                                          num_hidden_1,
                                          num_latent,
                                          num_classes)
model = model.to(device)
    

##########################
### COST AND OPTIMIZER
##########################

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

Training

In [5]:
start_time = time.time()
for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        features = features.view(-1, 28*28).to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
        kl_divergence = (0.5 * (z_mean**2 + 
                                torch.exp(z_log_var) - z_log_var - 1)).sum()
        
        # add condition
        x_con = torch.cat((features, to_onehot(targets, num_classes, device)), dim=1)
        
        pixelwise_bce = F.binary_cross_entropy(decoded, x_con, reduction='sum')
        cost = kl_divergence + pixelwise_bce
        
        optimizer.zero_grad()
        cost.backward()
        
        ### UPDATE MODEL PARAMETERS
        optimizer.step()
        
        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' 
                   %(epoch+1, num_epochs, batch_idx, 
                     len(train_loader)//batch_size, cost))
        
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/003 | Cost: 71038.5859
Epoch: 001/050 | Batch 050/003 | Cost: 27687.3066
Epoch: 001/050 | Batch 100/003 | Cost: 22471.4355
Epoch: 001/050 | Batch 150/003 | Cost: 20541.7598
Epoch: 001/050 | Batch 200/003 | Cost: 19670.0254
Epoch: 001/050 | Batch 250/003 | Cost: 18063.1172
Epoch: 001/050 | Batch 300/003 | Cost: 18197.8105
Epoch: 001/050 | Batch 350/003 | Cost: 17140.0234
Epoch: 001/050 | Batch 400/003 | Cost: 17088.2832
Epoch: 001/050 | Batch 450/003 | Cost: 16489.5078
Time elapsed: 0.12 min
Epoch: 002/050 | Batch 000/003 | Cost: 16543.6680
Epoch: 002/050 | Batch 050/003 | Cost: 15903.9375
Epoch: 002/050 | Batch 100/003 | Cost: 15624.4824
Epoch: 002/050 | Batch 150/003 | Cost: 15635.6016
Epoch: 002/050 | Batch 200/003 | Cost: 15245.4121
Epoch: 002/050 | Batch 250/003 | Cost: 15592.2266
Epoch: 002/050 | Batch 300/003 | Cost: 14563.4980
Epoch: 002/050 | Batch 350/003 | Cost: 15090.5811
Epoch: 002/050 | Batch 400/003 | Cost: 14832.8682
Epoch: 002/050 | Batch 450/003 | Cost: 15250.2480
Time elapsed: 0.24 min
Epoch: 003/050 | Batch 000/003 | Cost: 14894.0859
Epoch: 003/050 | Batch 050/003 | Cost: 14885.7168
Epoch: 003/050 | Batch 100/003 | Cost: 14998.2500
Epoch: 003/050 | Batch 150/003 | Cost: 14456.1309
Epoch: 003/050 | Batch 200/003 | Cost: 14737.8740
Epoch: 003/050 | Batch 250/003 | Cost: 14265.4639
Epoch: 003/050 | Batch 300/003 | Cost: 13826.1016
Epoch: 003/050 | Batch 350/003 | Cost: 14193.5059
Epoch: 003/050 | Batch 400/003 | Cost: 14308.9688
Epoch: 003/050 | Batch 450/003 | Cost: 13530.7305
Time elapsed: 0.37 min
Epoch: 004/050 | Batch 000/003 | Cost: 14149.7422
Epoch: 004/050 | Batch 050/003 | Cost: 14044.5645
Epoch: 004/050 | Batch 100/003 | Cost: 13909.8926
Epoch: 004/050 | Batch 150/003 | Cost: 14146.3789
Epoch: 004/050 | Batch 200/003 | Cost: 13917.0820
Epoch: 004/050 | Batch 250/003 | Cost: 13873.8262
Epoch: 004/050 | Batch 300/003 | Cost: 13358.3555
Epoch: 004/050 | Batch 350/003 | Cost: 13450.4238
Epoch: 004/050 | Batch 400/003 | Cost: 13677.4062
Epoch: 004/050 | Batch 450/003 | Cost: 14247.6426
Time elapsed: 0.49 min
Epoch: 005/050 | Batch 000/003 | Cost: 13364.3730
Epoch: 005/050 | Batch 050/003 | Cost: 13535.6279
Epoch: 005/050 | Batch 100/003 | Cost: 13827.6201
Epoch: 005/050 | Batch 150/003 | Cost: 13421.4111
Epoch: 005/050 | Batch 200/003 | Cost: 13467.9238
Epoch: 005/050 | Batch 250/003 | Cost: 13812.4131
Epoch: 005/050 | Batch 300/003 | Cost: 13457.0234
Epoch: 005/050 | Batch 350/003 | Cost: 14082.8926
Epoch: 005/050 | Batch 400/003 | Cost: 13224.9209
Epoch: 005/050 | Batch 450/003 | Cost: 13151.8311
Time elapsed: 0.61 min
Epoch: 006/050 | Batch 000/003 | Cost: 13850.8398
Epoch: 006/050 | Batch 050/003 | Cost: 13386.6289
Epoch: 006/050 | Batch 100/003 | Cost: 13522.1650
Epoch: 006/050 | Batch 150/003 | Cost: 12865.0898
Epoch: 006/050 | Batch 200/003 | Cost: 13949.3652
Epoch: 006/050 | Batch 250/003 | Cost: 12964.2607
Epoch: 006/050 | Batch 300/003 | Cost: 13692.9707
Epoch: 006/050 | Batch 350/003 | Cost: 13319.4219
Epoch: 006/050 | Batch 400/003 | Cost: 13278.6582
Epoch: 006/050 | Batch 450/003 | Cost: 13224.0107
Time elapsed: 0.73 min
Epoch: 007/050 | Batch 000/003 | Cost: 13686.9590
Epoch: 007/050 | Batch 050/003 | Cost: 13188.3828
Epoch: 007/050 | Batch 100/003 | Cost: 13173.3457
Epoch: 007/050 | Batch 150/003 | Cost: 13309.3906
Epoch: 007/050 | Batch 200/003 | Cost: 13174.3359
Epoch: 007/050 | Batch 250/003 | Cost: 13169.7100
Epoch: 007/050 | Batch 300/003 | Cost: 13255.1074
Epoch: 007/050 | Batch 350/003 | Cost: 13153.2480
Epoch: 007/050 | Batch 400/003 | Cost: 13683.6426
Epoch: 007/050 | Batch 450/003 | Cost: 13518.1055
Time elapsed: 0.85 min
Epoch: 008/050 | Batch 000/003 | Cost: 13531.5840
Epoch: 008/050 | Batch 050/003 | Cost: 12898.9746
Epoch: 008/050 | Batch 100/003 | Cost: 13434.8193
Epoch: 008/050 | Batch 150/003 | Cost: 13045.6172
Epoch: 008/050 | Batch 200/003 | Cost: 13175.8867
Epoch: 008/050 | Batch 250/003 | Cost: 12952.2139
Epoch: 008/050 | Batch 300/003 | Cost: 13100.8184
Epoch: 008/050 | Batch 350/003 | Cost: 13419.0410
Epoch: 008/050 | Batch 400/003 | Cost: 13768.5234
Epoch: 008/050 | Batch 450/003 | Cost: 13300.7402
Time elapsed: 0.97 min
Epoch: 009/050 | Batch 000/003 | Cost: 13035.2109
Epoch: 009/050 | Batch 050/003 | Cost: 13260.9902
Epoch: 009/050 | Batch 100/003 | Cost: 13446.3652
Epoch: 009/050 | Batch 150/003 | Cost: 13241.1523
Epoch: 009/050 | Batch 200/003 | Cost: 13139.6904
Epoch: 009/050 | Batch 250/003 | Cost: 13418.0098
Epoch: 009/050 | Batch 300/003 | Cost: 12743.2139
Epoch: 009/050 | Batch 350/003 | Cost: 13347.6465
Epoch: 009/050 | Batch 400/003 | Cost: 13353.1543
Epoch: 009/050 | Batch 450/003 | Cost: 13161.7549
Time elapsed: 1.09 min
Epoch: 010/050 | Batch 000/003 | Cost: 13043.7656
Epoch: 010/050 | Batch 050/003 | Cost: 13368.1816
Epoch: 010/050 | Batch 100/003 | Cost: 12742.3281
Epoch: 010/050 | Batch 150/003 | Cost: 13022.2832
Epoch: 010/050 | Batch 200/003 | Cost: 13076.5967
Epoch: 010/050 | Batch 250/003 | Cost: 12765.7480
Epoch: 010/050 | Batch 300/003 | Cost: 12605.2461
Epoch: 010/050 | Batch 350/003 | Cost: 12561.1895
Epoch: 010/050 | Batch 400/003 | Cost: 13314.3623
Epoch: 010/050 | Batch 450/003 | Cost: 13048.6670
Time elapsed: 1.21 min
Epoch: 011/050 | Batch 000/003 | Cost: 12769.8984
Epoch: 011/050 | Batch 050/003 | Cost: 13002.4824
Epoch: 011/050 | Batch 100/003 | Cost: 12989.9756
Epoch: 011/050 | Batch 150/003 | Cost: 12345.3877
Epoch: 011/050 | Batch 200/003 | Cost: 12599.8975
Epoch: 011/050 | Batch 250/003 | Cost: 13375.6543
Epoch: 011/050 | Batch 300/003 | Cost: 12829.1621
Epoch: 011/050 | Batch 350/003 | Cost: 12245.9316
Epoch: 011/050 | Batch 400/003 | Cost: 12745.7324
Epoch: 011/050 | Batch 450/003 | Cost: 12833.6787
Time elapsed: 1.33 min
Epoch: 012/050 | Batch 000/003 | Cost: 13168.7441
Epoch: 012/050 | Batch 050/003 | Cost: 12641.6289
Epoch: 012/050 | Batch 100/003 | Cost: 12435.3809
Epoch: 012/050 | Batch 150/003 | Cost: 12635.2422
Epoch: 012/050 | Batch 200/003 | Cost: 12657.0459
Epoch: 012/050 | Batch 250/003 | Cost: 12965.5098
Epoch: 012/050 | Batch 300/003 | Cost: 13022.1289
Epoch: 012/050 | Batch 350/003 | Cost: 13023.3535
Epoch: 012/050 | Batch 400/003 | Cost: 12723.5039
Epoch: 012/050 | Batch 450/003 | Cost: 12593.6680
Time elapsed: 1.45 min
Epoch: 013/050 | Batch 000/003 | Cost: 13062.8193
Epoch: 013/050 | Batch 050/003 | Cost: 12942.4668
Epoch: 013/050 | Batch 100/003 | Cost: 13116.7461
Epoch: 013/050 | Batch 150/003 | Cost: 12458.0615
Epoch: 013/050 | Batch 200/003 | Cost: 12586.8535
Epoch: 013/050 | Batch 250/003 | Cost: 12546.0703
Epoch: 013/050 | Batch 300/003 | Cost: 12817.3428
Epoch: 013/050 | Batch 350/003 | Cost: 13081.0586
Epoch: 013/050 | Batch 400/003 | Cost: 13006.0645
Epoch: 013/050 | Batch 450/003 | Cost: 12553.9902
Time elapsed: 1.57 min
Epoch: 014/050 | Batch 000/003 | Cost: 12912.1406
Epoch: 014/050 | Batch 050/003 | Cost: 12546.7812
Epoch: 014/050 | Batch 100/003 | Cost: 12928.7334
Epoch: 014/050 | Batch 150/003 | Cost: 12532.9268
Epoch: 014/050 | Batch 200/003 | Cost: 12385.5430
Epoch: 014/050 | Batch 250/003 | Cost: 12737.9053
Epoch: 014/050 | Batch 300/003 | Cost: 12179.8184
Epoch: 014/050 | Batch 350/003 | Cost: 12855.8301
Epoch: 014/050 | Batch 400/003 | Cost: 12984.9512
Epoch: 014/050 | Batch 450/003 | Cost: 12996.2832
Time elapsed: 1.69 min
Epoch: 015/050 | Batch 000/003 | Cost: 12615.2100
Epoch: 015/050 | Batch 050/003 | Cost: 13215.7402
Epoch: 015/050 | Batch 100/003 | Cost: 12532.8447
Epoch: 015/050 | Batch 150/003 | Cost: 12863.5928
Epoch: 015/050 | Batch 200/003 | Cost: 12683.9209
Epoch: 015/050 | Batch 250/003 | Cost: 12564.1309
Epoch: 015/050 | Batch 300/003 | Cost: 12535.2500
Epoch: 015/050 | Batch 350/003 | Cost: 12853.1445
Epoch: 015/050 | Batch 400/003 | Cost: 12600.9902
Epoch: 015/050 | Batch 450/003 | Cost: 12867.0781
Time elapsed: 1.81 min
Epoch: 016/050 | Batch 000/003 | Cost: 12429.7402
Epoch: 016/050 | Batch 050/003 | Cost: 13151.1377
Epoch: 016/050 | Batch 100/003 | Cost: 12885.5371
Epoch: 016/050 | Batch 150/003 | Cost: 12601.3242
Epoch: 016/050 | Batch 200/003 | Cost: 13196.4834
Epoch: 016/050 | Batch 250/003 | Cost: 12570.6836
Epoch: 016/050 | Batch 300/003 | Cost: 12942.7861
Epoch: 016/050 | Batch 350/003 | Cost: 12389.7363
Epoch: 016/050 | Batch 400/003 | Cost: 12576.1445
Epoch: 016/050 | Batch 450/003 | Cost: 12140.7900
Time elapsed: 1.93 min
Epoch: 017/050 | Batch 000/003 | Cost: 12978.2227
Epoch: 017/050 | Batch 050/003 | Cost: 12447.1230
Epoch: 017/050 | Batch 100/003 | Cost: 12980.0459
Epoch: 017/050 | Batch 150/003 | Cost: 12901.1045
Epoch: 017/050 | Batch 200/003 | Cost: 12381.2070
Epoch: 017/050 | Batch 250/003 | Cost: 12956.8857
Epoch: 017/050 | Batch 300/003 | Cost: 12341.9512
Epoch: 017/050 | Batch 350/003 | Cost: 12692.6270
Epoch: 017/050 | Batch 400/003 | Cost: 12316.9727
Epoch: 017/050 | Batch 450/003 | Cost: 12857.4844
Time elapsed: 2.05 min
Epoch: 018/050 | Batch 000/003 | Cost: 12531.7217
Epoch: 018/050 | Batch 050/003 | Cost: 12500.7012
Epoch: 018/050 | Batch 100/003 | Cost: 12458.2969
Epoch: 018/050 | Batch 150/003 | Cost: 12838.6758
Epoch: 018/050 | Batch 200/003 | Cost: 12678.9072
Epoch: 018/050 | Batch 250/003 | Cost: 12199.8320
Epoch: 018/050 | Batch 300/003 | Cost: 12352.8457
Epoch: 018/050 | Batch 350/003 | Cost: 12980.6797
Epoch: 018/050 | Batch 400/003 | Cost: 11996.0254
Epoch: 018/050 | Batch 450/003 | Cost: 12993.7158
Time elapsed: 2.17 min
Epoch: 019/050 | Batch 000/003 | Cost: 12559.5215
Epoch: 019/050 | Batch 050/003 | Cost: 12757.8662
Epoch: 019/050 | Batch 100/003 | Cost: 13118.6836
Epoch: 019/050 | Batch 150/003 | Cost: 13059.1777
Epoch: 019/050 | Batch 200/003 | Cost: 13048.2031
Epoch: 019/050 | Batch 250/003 | Cost: 12460.9688
Epoch: 019/050 | Batch 300/003 | Cost: 12757.1748
Epoch: 019/050 | Batch 350/003 | Cost: 12035.1006
Epoch: 019/050 | Batch 400/003 | Cost: 12802.9883
Epoch: 019/050 | Batch 450/003 | Cost: 12858.0488
Time elapsed: 2.29 min
Epoch: 020/050 | Batch 000/003 | Cost: 12934.9336
Epoch: 020/050 | Batch 050/003 | Cost: 12974.0488
Epoch: 020/050 | Batch 100/003 | Cost: 12512.7949
Epoch: 020/050 | Batch 150/003 | Cost: 12980.7275
Epoch: 020/050 | Batch 200/003 | Cost: 12930.8789
Epoch: 020/050 | Batch 250/003 | Cost: 12721.5723
Epoch: 020/050 | Batch 300/003 | Cost: 12933.5918
Epoch: 020/050 | Batch 350/003 | Cost: 12456.5488
Epoch: 020/050 | Batch 400/003 | Cost: 12849.0654
Epoch: 020/050 | Batch 450/003 | Cost: 12522.7480
Time elapsed: 2.41 min
Epoch: 021/050 | Batch 000/003 | Cost: 12122.0742
Epoch: 021/050 | Batch 050/003 | Cost: 13007.4697
Epoch: 021/050 | Batch 100/003 | Cost: 12690.0635
Epoch: 021/050 | Batch 150/003 | Cost: 12353.1621
Epoch: 021/050 | Batch 200/003 | Cost: 13007.6445
Epoch: 021/050 | Batch 250/003 | Cost: 12202.5176
Epoch: 021/050 | Batch 300/003 | Cost: 12154.6660
Epoch: 021/050 | Batch 350/003 | Cost: 12723.7158
Epoch: 021/050 | Batch 400/003 | Cost: 12902.6582
Epoch: 021/050 | Batch 450/003 | Cost: 12786.1484
Time elapsed: 2.53 min
Epoch: 022/050 | Batch 000/003 | Cost: 12648.4844
Epoch: 022/050 | Batch 050/003 | Cost: 12816.2891
Epoch: 022/050 | Batch 100/003 | Cost: 12544.2412
Epoch: 022/050 | Batch 150/003 | Cost: 12651.0215
Epoch: 022/050 | Batch 200/003 | Cost: 12831.6562
Epoch: 022/050 | Batch 250/003 | Cost: 12590.9326
Epoch: 022/050 | Batch 300/003 | Cost: 12396.2373
Epoch: 022/050 | Batch 350/003 | Cost: 12948.6094
Epoch: 022/050 | Batch 400/003 | Cost: 12553.6816
Epoch: 022/050 | Batch 450/003 | Cost: 12309.2637
Time elapsed: 2.65 min
Epoch: 023/050 | Batch 000/003 | Cost: 12857.9453
Epoch: 023/050 | Batch 050/003 | Cost: 12910.1377
Epoch: 023/050 | Batch 100/003 | Cost: 12449.3242
Epoch: 023/050 | Batch 150/003 | Cost: 12278.5312
Epoch: 023/050 | Batch 200/003 | Cost: 12971.6885
Epoch: 023/050 | Batch 250/003 | Cost: 13084.1699
Epoch: 023/050 | Batch 300/003 | Cost: 12463.8232
Epoch: 023/050 | Batch 350/003 | Cost: 12589.3398
Epoch: 023/050 | Batch 400/003 | Cost: 12732.2168
Epoch: 023/050 | Batch 450/003 | Cost: 12196.4492
Time elapsed: 2.76 min
Epoch: 024/050 | Batch 000/003 | Cost: 12342.8789
Epoch: 024/050 | Batch 050/003 | Cost: 12255.4883
Epoch: 024/050 | Batch 100/003 | Cost: 12158.4902
Epoch: 024/050 | Batch 150/003 | Cost: 12731.5742
Epoch: 024/050 | Batch 200/003 | Cost: 12789.7168
Epoch: 024/050 | Batch 250/003 | Cost: 12213.1104
Epoch: 024/050 | Batch 300/003 | Cost: 12613.8281
Epoch: 024/050 | Batch 350/003 | Cost: 12530.3096
Epoch: 024/050 | Batch 400/003 | Cost: 12475.6035
Epoch: 024/050 | Batch 450/003 | Cost: 12182.7178
Time elapsed: 2.88 min
Epoch: 025/050 | Batch 000/003 | Cost: 12929.5430
Epoch: 025/050 | Batch 050/003 | Cost: 12472.9180
Epoch: 025/050 | Batch 100/003 | Cost: 11870.2754
Epoch: 025/050 | Batch 150/003 | Cost: 12619.1641
Epoch: 025/050 | Batch 200/003 | Cost: 12285.2344
Epoch: 025/050 | Batch 250/003 | Cost: 12557.6367
Epoch: 025/050 | Batch 300/003 | Cost: 12409.3574
Epoch: 025/050 | Batch 350/003 | Cost: 12889.7910
Epoch: 025/050 | Batch 400/003 | Cost: 12708.5605
Epoch: 025/050 | Batch 450/003 | Cost: 12577.1514
Time elapsed: 3.00 min
Epoch: 026/050 | Batch 000/003 | Cost: 12301.7139
Epoch: 026/050 | Batch 050/003 | Cost: 12692.7188
Epoch: 026/050 | Batch 100/003 | Cost: 12601.7607
Epoch: 026/050 | Batch 150/003 | Cost: 12460.5254
Epoch: 026/050 | Batch 200/003 | Cost: 12769.4287
Epoch: 026/050 | Batch 250/003 | Cost: 12428.0000
Epoch: 026/050 | Batch 300/003 | Cost: 12987.5449
Epoch: 026/050 | Batch 350/003 | Cost: 12646.5908
Epoch: 026/050 | Batch 400/003 | Cost: 12335.6738
Epoch: 026/050 | Batch 450/003 | Cost: 12613.5449
Time elapsed: 3.12 min
Epoch: 027/050 | Batch 000/003 | Cost: 12647.8809
Epoch: 027/050 | Batch 050/003 | Cost: 12970.0127
Epoch: 027/050 | Batch 100/003 | Cost: 12870.9219
Epoch: 027/050 | Batch 150/003 | Cost: 12435.1553
Epoch: 027/050 | Batch 200/003 | Cost: 12810.8418
Epoch: 027/050 | Batch 250/003 | Cost: 12727.6777
Epoch: 027/050 | Batch 300/003 | Cost: 12762.6055
Epoch: 027/050 | Batch 350/003 | Cost: 12970.4414
Epoch: 027/050 | Batch 400/003 | Cost: 12745.8652
Epoch: 027/050 | Batch 450/003 | Cost: 12442.3232
Time elapsed: 3.24 min
Epoch: 028/050 | Batch 000/003 | Cost: 12199.5078
Epoch: 028/050 | Batch 050/003 | Cost: 12707.5625
Epoch: 028/050 | Batch 100/003 | Cost: 12289.9277
Epoch: 028/050 | Batch 150/003 | Cost: 12375.3242
Epoch: 028/050 | Batch 200/003 | Cost: 12023.3887
Epoch: 028/050 | Batch 250/003 | Cost: 12776.7168
Epoch: 028/050 | Batch 300/003 | Cost: 12680.4668
Epoch: 028/050 | Batch 350/003 | Cost: 12701.3281
Epoch: 028/050 | Batch 400/003 | Cost: 12561.7227
Epoch: 028/050 | Batch 450/003 | Cost: 12763.8447
Time elapsed: 3.36 min
Epoch: 029/050 | Batch 000/003 | Cost: 12721.6777
Epoch: 029/050 | Batch 050/003 | Cost: 12443.0645
Epoch: 029/050 | Batch 100/003 | Cost: 12057.7822
Epoch: 029/050 | Batch 150/003 | Cost: 12504.2529
Epoch: 029/050 | Batch 200/003 | Cost: 12310.3965
Epoch: 029/050 | Batch 250/003 | Cost: 13202.6211
Epoch: 029/050 | Batch 300/003 | Cost: 12117.8008
Epoch: 029/050 | Batch 350/003 | Cost: 12538.9092
Epoch: 029/050 | Batch 400/003 | Cost: 12451.9180
Epoch: 029/050 | Batch 450/003 | Cost: 12649.5537
Time elapsed: 3.48 min
Epoch: 030/050 | Batch 000/003 | Cost: 13177.7520
Epoch: 030/050 | Batch 050/003 | Cost: 12634.2002
Epoch: 030/050 | Batch 100/003 | Cost: 12582.4863
Epoch: 030/050 | Batch 150/003 | Cost: 12516.8877
Epoch: 030/050 | Batch 200/003 | Cost: 12460.7139
Epoch: 030/050 | Batch 250/003 | Cost: 12385.2090
Epoch: 030/050 | Batch 300/003 | Cost: 12847.1113
Epoch: 030/050 | Batch 350/003 | Cost: 12123.1543
Epoch: 030/050 | Batch 400/003 | Cost: 12427.7227
Epoch: 030/050 | Batch 450/003 | Cost: 12904.1279
Time elapsed: 3.60 min
Epoch: 031/050 | Batch 000/003 | Cost: 12551.2178
Epoch: 031/050 | Batch 050/003 | Cost: 13006.6875
Epoch: 031/050 | Batch 100/003 | Cost: 12672.4551
Epoch: 031/050 | Batch 150/003 | Cost: 12577.4131
Epoch: 031/050 | Batch 200/003 | Cost: 12595.9150
Epoch: 031/050 | Batch 250/003 | Cost: 12294.5635
Epoch: 031/050 | Batch 300/003 | Cost: 12491.6406
Epoch: 031/050 | Batch 350/003 | Cost: 12726.5947
Epoch: 031/050 | Batch 400/003 | Cost: 12449.7246
Epoch: 031/050 | Batch 450/003 | Cost: 12771.0234
Time elapsed: 3.72 min
Epoch: 032/050 | Batch 000/003 | Cost: 12567.2988
Epoch: 032/050 | Batch 050/003 | Cost: 11960.7676
Epoch: 032/050 | Batch 100/003 | Cost: 12276.4648
Epoch: 032/050 | Batch 150/003 | Cost: 13205.2402
Epoch: 032/050 | Batch 200/003 | Cost: 12931.1514
Epoch: 032/050 | Batch 250/003 | Cost: 12975.4473
Epoch: 032/050 | Batch 300/003 | Cost: 12364.8164
Epoch: 032/050 | Batch 350/003 | Cost: 13167.5020
Epoch: 032/050 | Batch 400/003 | Cost: 12439.9355
Epoch: 032/050 | Batch 450/003 | Cost: 12526.5000
Time elapsed: 3.84 min
Epoch: 033/050 | Batch 000/003 | Cost: 12447.2412
Epoch: 033/050 | Batch 050/003 | Cost: 12603.1895
Epoch: 033/050 | Batch 100/003 | Cost: 11774.2539
Epoch: 033/050 | Batch 150/003 | Cost: 12859.1406
Epoch: 033/050 | Batch 200/003 | Cost: 12321.5195
Epoch: 033/050 | Batch 250/003 | Cost: 12458.5352
Epoch: 033/050 | Batch 300/003 | Cost: 12871.9336
Epoch: 033/050 | Batch 350/003 | Cost: 12373.0039
Epoch: 033/050 | Batch 400/003 | Cost: 12674.4531
Epoch: 033/050 | Batch 450/003 | Cost: 12425.8633
Time elapsed: 3.96 min
Epoch: 034/050 | Batch 000/003 | Cost: 12650.1582
Epoch: 034/050 | Batch 050/003 | Cost: 12587.3613
Epoch: 034/050 | Batch 100/003 | Cost: 12675.3105
Epoch: 034/050 | Batch 150/003 | Cost: 12833.5391
Epoch: 034/050 | Batch 200/003 | Cost: 12441.7305
Epoch: 034/050 | Batch 250/003 | Cost: 12733.2959
Epoch: 034/050 | Batch 300/003 | Cost: 12329.9219
Epoch: 034/050 | Batch 350/003 | Cost: 12802.7354
Epoch: 034/050 | Batch 400/003 | Cost: 11619.9902
Epoch: 034/050 | Batch 450/003 | Cost: 12474.8330
Time elapsed: 4.08 min
Epoch: 035/050 | Batch 000/003 | Cost: 12535.9199
Epoch: 035/050 | Batch 050/003 | Cost: 12450.3691
Epoch: 035/050 | Batch 100/003 | Cost: 12693.5137
Epoch: 035/050 | Batch 150/003 | Cost: 12554.5850
Epoch: 035/050 | Batch 200/003 | Cost: 12123.2129
Epoch: 035/050 | Batch 250/003 | Cost: 12241.1787
Epoch: 035/050 | Batch 300/003 | Cost: 12227.1240
Epoch: 035/050 | Batch 350/003 | Cost: 12536.8457
Epoch: 035/050 | Batch 400/003 | Cost: 12318.9238
Epoch: 035/050 | Batch 450/003 | Cost: 12698.3252
Time elapsed: 4.20 min
Epoch: 036/050 | Batch 000/003 | Cost: 12412.0586
Epoch: 036/050 | Batch 050/003 | Cost: 12438.2656
Epoch: 036/050 | Batch 100/003 | Cost: 12465.1973
Epoch: 036/050 | Batch 150/003 | Cost: 12164.7285
Epoch: 036/050 | Batch 200/003 | Cost: 12492.0488
Epoch: 036/050 | Batch 250/003 | Cost: 12558.4023
Epoch: 036/050 | Batch 300/003 | Cost: 12648.7227
Epoch: 036/050 | Batch 350/003 | Cost: 12408.2832
Epoch: 036/050 | Batch 400/003 | Cost: 12169.1201
Epoch: 036/050 | Batch 450/003 | Cost: 12748.4707
Time elapsed: 4.32 min
Epoch: 037/050 | Batch 000/003 | Cost: 12361.0439
Epoch: 037/050 | Batch 050/003 | Cost: 12553.2051
Epoch: 037/050 | Batch 100/003 | Cost: 12548.9307
Epoch: 037/050 | Batch 150/003 | Cost: 12064.2344
Epoch: 037/050 | Batch 200/003 | Cost: 12219.1270
Epoch: 037/050 | Batch 250/003 | Cost: 12353.5186
Epoch: 037/050 | Batch 300/003 | Cost: 12224.3682
Epoch: 037/050 | Batch 350/003 | Cost: 12668.8193
Epoch: 037/050 | Batch 400/003 | Cost: 12274.8691
Epoch: 037/050 | Batch 450/003 | Cost: 12393.6182
Time elapsed: 4.44 min
Epoch: 038/050 | Batch 000/003 | Cost: 12669.8281
Epoch: 038/050 | Batch 050/003 | Cost: 12271.3066
Epoch: 038/050 | Batch 100/003 | Cost: 12387.7930
Epoch: 038/050 | Batch 150/003 | Cost: 11961.1123
Epoch: 038/050 | Batch 200/003 | Cost: 11976.5498
Epoch: 038/050 | Batch 250/003 | Cost: 12573.2324
Epoch: 038/050 | Batch 300/003 | Cost: 12608.4854
Epoch: 038/050 | Batch 350/003 | Cost: 12112.7617
Epoch: 038/050 | Batch 400/003 | Cost: 12275.8418
Epoch: 038/050 | Batch 450/003 | Cost: 12417.7549
Time elapsed: 4.56 min
Epoch: 039/050 | Batch 000/003 | Cost: 12874.1016
Epoch: 039/050 | Batch 050/003 | Cost: 12372.2070
Epoch: 039/050 | Batch 100/003 | Cost: 12446.2695
Epoch: 039/050 | Batch 150/003 | Cost: 13140.2764
Epoch: 039/050 | Batch 200/003 | Cost: 12825.3037
Epoch: 039/050 | Batch 250/003 | Cost: 12165.7451
Epoch: 039/050 | Batch 300/003 | Cost: 12430.3340
Epoch: 039/050 | Batch 350/003 | Cost: 12702.8613
Epoch: 039/050 | Batch 400/003 | Cost: 12374.5752
Epoch: 039/050 | Batch 450/003 | Cost: 12414.6475
Time elapsed: 4.68 min
Epoch: 040/050 | Batch 000/003 | Cost: 12147.0693
Epoch: 040/050 | Batch 050/003 | Cost: 12907.0684
Epoch: 040/050 | Batch 100/003 | Cost: 11664.0801
Epoch: 040/050 | Batch 150/003 | Cost: 12443.9512
Epoch: 040/050 | Batch 200/003 | Cost: 12112.6250
Epoch: 040/050 | Batch 250/003 | Cost: 12146.6133
Epoch: 040/050 | Batch 300/003 | Cost: 12050.8281
Epoch: 040/050 | Batch 350/003 | Cost: 12762.5020
Epoch: 040/050 | Batch 400/003 | Cost: 12517.0771
Epoch: 040/050 | Batch 450/003 | Cost: 12191.8916
Time elapsed: 4.80 min
Epoch: 041/050 | Batch 000/003 | Cost: 12098.7090
Epoch: 041/050 | Batch 050/003 | Cost: 12562.2539
Epoch: 041/050 | Batch 100/003 | Cost: 12609.0088
Epoch: 041/050 | Batch 150/003 | Cost: 12270.9854
Epoch: 041/050 | Batch 200/003 | Cost: 12757.7578
Epoch: 041/050 | Batch 250/003 | Cost: 12277.8584
Epoch: 041/050 | Batch 300/003 | Cost: 12219.4121
Epoch: 041/050 | Batch 350/003 | Cost: 12215.5977
Epoch: 041/050 | Batch 400/003 | Cost: 12476.9668
Epoch: 041/050 | Batch 450/003 | Cost: 12220.5449
Time elapsed: 4.92 min
Epoch: 042/050 | Batch 000/003 | Cost: 12391.3916
Epoch: 042/050 | Batch 050/003 | Cost: 12526.2002
Epoch: 042/050 | Batch 100/003 | Cost: 12929.2305
Epoch: 042/050 | Batch 150/003 | Cost: 12575.8652
Epoch: 042/050 | Batch 200/003 | Cost: 12588.2656
Epoch: 042/050 | Batch 250/003 | Cost: 12191.2520
Epoch: 042/050 | Batch 300/003 | Cost: 12416.6172
Epoch: 042/050 | Batch 350/003 | Cost: 12398.3096
Epoch: 042/050 | Batch 400/003 | Cost: 12093.6777
Epoch: 042/050 | Batch 450/003 | Cost: 12391.6504
Time elapsed: 5.04 min
Epoch: 043/050 | Batch 000/003 | Cost: 12748.3145
Epoch: 043/050 | Batch 050/003 | Cost: 12076.2344
Epoch: 043/050 | Batch 100/003 | Cost: 11953.4248
Epoch: 043/050 | Batch 150/003 | Cost: 12420.9707
Epoch: 043/050 | Batch 200/003 | Cost: 12428.2041
Epoch: 043/050 | Batch 250/003 | Cost: 12447.2324
Epoch: 043/050 | Batch 300/003 | Cost: 12404.4004
Epoch: 043/050 | Batch 350/003 | Cost: 13077.8926
Epoch: 043/050 | Batch 400/003 | Cost: 13071.2891
Epoch: 043/050 | Batch 450/003 | Cost: 12398.8311
Time elapsed: 5.16 min
Epoch: 044/050 | Batch 000/003 | Cost: 12913.1631
Epoch: 044/050 | Batch 050/003 | Cost: 12169.1523
Epoch: 044/050 | Batch 100/003 | Cost: 11856.3672
Epoch: 044/050 | Batch 150/003 | Cost: 12280.6045
Epoch: 044/050 | Batch 200/003 | Cost: 12343.7998
Epoch: 044/050 | Batch 250/003 | Cost: 12746.3164
Epoch: 044/050 | Batch 300/003 | Cost: 12279.5156
Epoch: 044/050 | Batch 350/003 | Cost: 12548.7598
Epoch: 044/050 | Batch 400/003 | Cost: 12430.6104
Epoch: 044/050 | Batch 450/003 | Cost: 12775.3105
Time elapsed: 5.28 min
Epoch: 045/050 | Batch 000/003 | Cost: 12245.9482
Epoch: 045/050 | Batch 050/003 | Cost: 12547.1270
Epoch: 045/050 | Batch 100/003 | Cost: 12214.5732
Epoch: 045/050 | Batch 150/003 | Cost: 12484.7715
Epoch: 045/050 | Batch 200/003 | Cost: 12552.6123
Epoch: 045/050 | Batch 250/003 | Cost: 12510.6064
Epoch: 045/050 | Batch 300/003 | Cost: 12465.5566
Epoch: 045/050 | Batch 350/003 | Cost: 12111.4629
Epoch: 045/050 | Batch 400/003 | Cost: 12435.2578
Epoch: 045/050 | Batch 450/003 | Cost: 12433.7461
Time elapsed: 5.40 min
Epoch: 046/050 | Batch 000/003 | Cost: 12598.7734
Epoch: 046/050 | Batch 050/003 | Cost: 12568.2207
Epoch: 046/050 | Batch 100/003 | Cost: 12853.6328
Epoch: 046/050 | Batch 150/003 | Cost: 12230.1533
Epoch: 046/050 | Batch 200/003 | Cost: 12326.4727
Epoch: 046/050 | Batch 250/003 | Cost: 12770.4814
Epoch: 046/050 | Batch 300/003 | Cost: 12082.1006
Epoch: 046/050 | Batch 350/003 | Cost: 11948.0430
Epoch: 046/050 | Batch 400/003 | Cost: 12462.7773
Epoch: 046/050 | Batch 450/003 | Cost: 12145.4609
Time elapsed: 5.52 min
Epoch: 047/050 | Batch 000/003 | Cost: 12146.9893
Epoch: 047/050 | Batch 050/003 | Cost: 11917.3779
Epoch: 047/050 | Batch 100/003 | Cost: 12135.3457
Epoch: 047/050 | Batch 150/003 | Cost: 11984.8691
Epoch: 047/050 | Batch 200/003 | Cost: 12437.4248
Epoch: 047/050 | Batch 250/003 | Cost: 12444.1416
Epoch: 047/050 | Batch 300/003 | Cost: 12078.9043
Epoch: 047/050 | Batch 350/003 | Cost: 12138.3818
Epoch: 047/050 | Batch 400/003 | Cost: 12556.3086
Epoch: 047/050 | Batch 450/003 | Cost: 12726.3828
Time elapsed: 5.64 min
Epoch: 048/050 | Batch 000/003 | Cost: 12927.1035
Epoch: 048/050 | Batch 050/003 | Cost: 12747.8994
Epoch: 048/050 | Batch 100/003 | Cost: 12002.6406
Epoch: 048/050 | Batch 150/003 | Cost: 12093.2988
Epoch: 048/050 | Batch 200/003 | Cost: 11541.1982
Epoch: 048/050 | Batch 250/003 | Cost: 12530.8398
Epoch: 048/050 | Batch 300/003 | Cost: 12200.4463
Epoch: 048/050 | Batch 350/003 | Cost: 12818.4082
Epoch: 048/050 | Batch 400/003 | Cost: 12760.9844
Epoch: 048/050 | Batch 450/003 | Cost: 12186.3496
Time elapsed: 5.76 min
Epoch: 049/050 | Batch 000/003 | Cost: 12394.8848
Epoch: 049/050 | Batch 050/003 | Cost: 12406.0801
Epoch: 049/050 | Batch 100/003 | Cost: 12390.3223
Epoch: 049/050 | Batch 150/003 | Cost: 12788.7949
Epoch: 049/050 | Batch 200/003 | Cost: 12015.3936
Epoch: 049/050 | Batch 250/003 | Cost: 12259.9609
Epoch: 049/050 | Batch 300/003 | Cost: 12216.9688
Epoch: 049/050 | Batch 350/003 | Cost: 12795.6113
Epoch: 049/050 | Batch 400/003 | Cost: 11885.8955
Epoch: 049/050 | Batch 450/003 | Cost: 12445.2891
Time elapsed: 5.88 min
Epoch: 050/050 | Batch 000/003 | Cost: 12782.9785
Epoch: 050/050 | Batch 050/003 | Cost: 12291.9814
Epoch: 050/050 | Batch 100/003 | Cost: 12721.1641
Epoch: 050/050 | Batch 150/003 | Cost: 12482.5762
Epoch: 050/050 | Batch 200/003 | Cost: 12581.8125
Epoch: 050/050 | Batch 250/003 | Cost: 12422.5527
Epoch: 050/050 | Batch 300/003 | Cost: 12384.3047
Epoch: 050/050 | Batch 350/003 | Cost: 12031.4541
Epoch: 050/050 | Batch 400/003 | Cost: 12278.7617
Epoch: 050/050 | Batch 450/003 | Cost: 12190.8232
Time elapsed: 6.00 min
Total Training Time: 6.00 min

Evaluation

Reconstruction

In [6]:
%matplotlib inline
import matplotlib.pyplot as plt

##########################
### VISUALIZATION
##########################

n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images, 
                         sharex=True, sharey=True, figsize=(20, 2.5))
orig_images = features[:n_images]
decoded_images = decoded[:n_images][:, :-num_classes]

for i in range(n_images):
    for ax, img in zip(axes, [orig_images, decoded_images]):
        curr_img = img[i].detach().to(torch.device('cpu'))
        ax[i].imshow(curr_img.view((image_width, image_width)), cmap='binary')

New random-conditional images

In [7]:
for i in range(10):

    ##########################
    ### RANDOM SAMPLE
    ##########################    
    
    labels = torch.tensor([i]*10).to(device)
    n_images = labels.size()[0]
    rand_features = torch.randn(n_images, num_latent).to(device)
    new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    image_width = 28

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    decoded_images = new_images[:n_images][:, :-num_classes]

    print('Class Label %d' % i)

    for ax, img in zip(axes, decoded_images):
        curr_img = img.detach().to(torch.device('cpu'))
        ax.imshow(curr_img.view((image_width, image_width)), cmap='binary')
        
    plt.show()
Class Label 0
Class Label 1
Class Label 2
Class Label 3
Class Label 4
Class Label 5
Class Label 6
Class Label 7
Class Label 8
Class Label 9
In [8]:
%watermark -iv
numpy       1.15.4
torch       1.0.0