Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka 

CPython 3.6.8
IPython 7.2.0

torch 1.0.0
  • Runs on CPU or GPU (if available)

Model Zoo -- Convolutional Conditional Variational Autoencoder

(without labels in reconstruction loss)

A simple convolutional conditional variational autoencoder that compresses 784-pixel (28x28) MNIST images down to a 50-dimensional latent vector representation.

This implementation DOES NOT concatenate the inputs with the class labels when computing the reconstruction loss, in contrast to how it is commonly done in non-convolutional conditional variational autoencoders. Not considering class labels in the reconstruction loss leads to substantially better results compared to the implementation that does concatenate the labels with the inputs to compute the reconstruction loss. For reference, see the implementation ./autoencoder-cnn-cvae.ipynb

Imports

In [2]:
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


# When CUDA is available, switch on cuDNN's `deterministic` flag
# (intended to make repeated GPU runs reproducible).
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
In [3]:
##########################
### SETTINGS
##########################

# Device
# The original run used the second GPU ("cuda:1"). Select it only when at
# least two GPUs are visible; otherwise fall back to the first GPU, or to
# the CPU when CUDA is unavailable. This keeps the notebook runnable on
# single-GPU machines, where "cuda:1" does not exist.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)

# Hyperparameters
random_seed = 0        # passed to torch.manual_seed before the model is built
learning_rate = 0.001  # learning rate of the Adam optimizer
num_epochs = 50        # number of passes over the training set
batch_size = 128       # batch size of both data loaders

# Architecture
num_classes = 10       # number of class labels used for conditioning
num_features = 784     # 28*28 pixels per image
num_latent = 50        # size of the latent vector


##########################
### MNIST DATASET
##########################

# transforms.ToTensor() turns each image into a tensor whose
# values are scaled to the 0-1 range.
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
)

# Training batches are shuffled; test batches keep the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: report the shapes of the first training batch only.
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Device: cuda:1
Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])

Model

In [4]:
##########################
### MODEL
##########################


def to_onehot(labels, num_classes, device):
    """Convert integer class labels into one-hot encoded rows.

    Parameters
    ----------
    labels : torch.Tensor
        Integer class indices, one per example.
    num_classes : int
        Width of the one-hot encoding.
    device : torch.device
        Device the result is moved to.

    Returns
    -------
    torch.Tensor
        A float tensor with one row per label and `num_classes` columns,
        holding a 1 at each label's index and 0 elsewhere.
    """
    n_rows = labels.shape[0]
    encoded = torch.zeros((n_rows, num_classes)).to(device)
    # Write a 1 into the column selected by each row's label.
    column_index = labels.view(-1, 1)
    encoded.scatter_(1, column_index, 1)
    return encoded


class ConditionalVariationalAutoencoder(torch.nn.Module):
    """Convolutional conditional variational autoencoder.

    The encoder receives the image concatenated (along the channel axis)
    with `num_classes` constant feature maps that one-hot encode the class
    label. The decoder receives the latent sample concatenated with the
    one-hot label vector.

    The fully connected layers assume that the convolutional stack produces
    a 64x2x2 feature map, which is the case for 28x28 inputs.

    All tensors created inside the model are placed on the device of the
    tensors passed in, so the model does not depend on a global `device`.

    Parameters
    ----------
    num_features : int
        Accepted for interface compatibility; not used by the layers below.
    num_latent : int
        Size of the latent vector.
    num_classes : int
        Number of class labels used for conditioning.
    """

    def __init__(self, num_features, num_latent, num_classes):
        super(ConditionalVariationalAutoencoder, self).__init__()

        self.num_classes = num_classes

        ###############
        # ENCODER
        ##############

        # Output size of a convolution (no dilation):
        # o = floor((w - k + 2*p)/s) + 1
        # For 28x28 inputs: 28 -> 12 -> 5 -> 2

        # The input has 1 image channel plus one channel per class
        self.enc_conv_1 = torch.nn.Conv2d(in_channels=1+self.num_classes,
                                          out_channels=16,
                                          kernel_size=(6, 6),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_2 = torch.nn.Conv2d(in_channels=16,
                                          out_channels=32,
                                          kernel_size=(4, 4),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_3 = torch.nn.Conv2d(in_channels=32,
                                          out_channels=64,
                                          kernel_size=(2, 2),
                                          stride=(2, 2),
                                          padding=0)

        self.z_mean = torch.nn.Linear(64*2*2, num_latent)
        # The original VAE formulation (Kingma & Welling) describes the
        # latent distribution by a mean and a variance. A linear layer
        # predicting the variance directly could output negative values,
        # which would break the log in the KL term. Hence the layer
        # predicts the log-variance instead, and the variance or standard
        # deviation is recovered with an exponential where needed.
        self.z_log_var = torch.nn.Linear(64*2*2, num_latent)

        ###############
        # DECODER
        ##############

        # The decoder input is the latent vector plus the one-hot label
        self.dec_linear_1 = torch.nn.Linear(num_latent+self.num_classes, 64*2*2)

        # Spatial sizes: 2 -> 4 -> 11 -> 28
        self.dec_deconv_1 = torch.nn.ConvTranspose2d(in_channels=64,
                                                     out_channels=32,
                                                     kernel_size=(2, 2),
                                                     stride=(2, 2),
                                                     padding=0)

        self.dec_deconv_2 = torch.nn.ConvTranspose2d(in_channels=32,
                                                     out_channels=16,
                                                     kernel_size=(4, 4),
                                                     stride=(3, 3),
                                                     padding=1)

        self.dec_deconv_3 = torch.nn.ConvTranspose2d(in_channels=16,
                                                     out_channels=1,
                                                     kernel_size=(6, 6),
                                                     stride=(3, 3),
                                                     padding=4)

    def reparameterize(self, z_mu, z_log_var):
        """Draw z = z_mu + eps * std_dev with eps sampled from N(0, 1).

        Parameters
        ----------
        z_mu : torch.Tensor
            Means, indexed as (batch, latent).
        z_log_var : torch.Tensor
            Log-variances with the same shape as `z_mu`.

        Returns
        -------
        torch.Tensor
            The sampled latent vectors, same shape as `z_mu`.
        """
        # Sample epsilon from a standard normal distribution. It is drawn
        # with the default (CPU) generator and then moved, as before, but
        # the target is now the device of z_mu rather than a global.
        eps = torch.randn(z_mu.size(0), z_mu.size(1)).to(z_mu.device)
        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev
        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)
        z = z_mu + eps * torch.exp(z_log_var/2.)
        return z

    def encoder(self, features, targets):
        """Encode a batch of images conditioned on their class labels.

        Parameters
        ----------
        features : torch.Tensor
            Image batch; the code indexes it as (batch, channel, height, width).
        targets : torch.Tensor
            Integer class labels, expected on the same device as `features`.

        Returns
        -------
        tuple of torch.Tensor
            (z_mean, z_log_var, encoded), where `encoded` is a sample drawn
            via `reparameterize`.
        """
        input_device = features.device

        ### Add condition
        # Broadcast each one-hot label to `num_classes` constant feature
        # maps of the image's spatial size and append them as channels.
        onehot_targets = to_onehot(targets, self.num_classes, input_device)
        onehot_targets = onehot_targets.view(-1, self.num_classes, 1, 1)

        ones = torch.ones(features.size()[0],
                          self.num_classes,
                          features.size()[2],
                          features.size()[3],
                          dtype=features.dtype,
                          device=input_device)
        ones = ones * onehot_targets
        x = torch.cat((features, ones), dim=1)

        x = self.enc_conv_1(x)
        x = F.leaky_relu(x)
        # conv1 out: (N, 16, 12, 12) for 28x28 inputs

        x = self.enc_conv_2(x)
        x = F.leaky_relu(x)
        # conv2 out: (N, 32, 5, 5)

        x = self.enc_conv_3(x)
        x = F.leaky_relu(x)
        # conv3 out: (N, 64, 2, 2)

        flattened = x.view(-1, 64*2*2)
        z_mean = self.z_mean(flattened)
        z_log_var = self.z_log_var(flattened)
        encoded = self.reparameterize(z_mean, z_log_var)
        return z_mean, z_log_var, encoded

    def decoder(self, encoded, targets):
        """Decode latent vectors conditioned on class labels.

        Parameters
        ----------
        encoded : torch.Tensor
            Latent vectors, indexed as (batch, num_latent).
        targets : torch.Tensor
            Integer class labels, expected on the same device as `encoded`.

        Returns
        -------
        torch.Tensor
            Reconstructions after a sigmoid, i.e. with values in (0, 1).
        """
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, encoded.device)
        encoded = torch.cat((encoded, onehot_targets), dim=1)

        x = self.dec_linear_1(encoded)
        x = x.view(-1, 64, 2, 2)

        x = self.dec_deconv_1(x)
        x = F.leaky_relu(x)
        # deconv1 out: (N, 32, 4, 4)

        x = self.dec_deconv_2(x)
        x = F.leaky_relu(x)
        # deconv2 out: (N, 16, 11, 11)

        x = self.dec_deconv_3(x)
        # NOTE(review): the leaky ReLU ahead of the sigmoid is kept to
        # preserve the original architecture. It shrinks negative
        # pre-activations, so outputs close to 0 require very negative
        # inputs -- consider removing it and retraining.
        x = F.leaky_relu(x)
        # deconv3 out: (N, 1, 28, 28)

        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        """Run the encoder and the decoder on one batch.

        Returns
        -------
        tuple of torch.Tensor
            (z_mean, z_log_var, encoded, decoded)
        """
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)

        return z_mean, z_log_var, encoded, decoded

    
# Seed PyTorch's random number generator right before the model is
# built, then move the model to the selected device.
torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(
    num_features=num_features,
    num_latent=num_latent,
    num_classes=num_classes,
).to(device)


##########################
### COST AND OPTIMIZER
##########################

# Adam over all model parameters; the cost itself is computed
# inside the training loop.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Training

In [5]:
# Train the model for `num_epochs` passes over `train_loader`, minimizing
# the sum of the KL divergence and the pixel-wise binary cross-entropy.
# NOTE(review): the loop variables (features, targets, decoded, cost, ...)
# stay defined in the notebook namespace after the loop; later cells may
# rely on them -- confirm before renaming anything here.
start_time = time.time()

for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        # Move the batch to the device selected in the settings cell
        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
        # KL term: closed form for N(z_mean, exp(z_log_var)) against a
        # standard normal, summed over all latent dimensions and examples
        kl_divergence = (0.5 * (z_mean**2 + 
                                torch.exp(z_log_var) - z_log_var - 1)).sum()
        
        
        ### Add condition
        # Disabled for reconstruction loss as it gives poor results
        # (the string literal below is a no-op expression statement that
        # preserves the disabled label-concatenation code)
        """
        onehot_targets = to_onehot(targets, num_classes, device)
        onehot_targets = onehot_targets.view(-1, num_classes, 1, 1)
        
        ones = torch.ones(features.size()[0], 
                          num_classes,
                          features.size()[2], 
                          features.size()[3], 
                          dtype=features.dtype).to(device)
        ones = ones * onehot_targets
        x_con = torch.cat((features, ones), dim=1)
        """
        
        ### Compute loss
        # Reconstruction term: binary cross-entropy between the decoded
        # images and the raw inputs (without label channels), summed over
        # all pixels and examples in the batch
        #pixelwise_bce = F.binary_cross_entropy(decoded, x_con, reduction='sum')
        pixelwise_bce = F.binary_cross_entropy(decoded, features, reduction='sum')
        cost = kl_divergence + pixelwise_bce
        
        ### UPDATE MODEL PARAMETERS
        # Clear old gradients, backpropagate, then take one optimizer step
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        
        ### LOGGING
        # Report the cost of the current batch every 50 batches
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' 
                   %(epoch+1, num_epochs, batch_idx, 
                     len(train_loader), cost))
            
    # Cumulative wall-clock time since training started, after each epoch
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/469 | Cost: 69516.4531
Epoch: 001/050 | Batch 050/469 | Cost: 65598.3438
Epoch: 001/050 | Batch 100/469 | Cost: 38653.3398
Epoch: 001/050 | Batch 150/469 | Cost: 31647.0234
Epoch: 001/050 | Batch 200/469 | Cost: 29562.4902
Epoch: 001/050 | Batch 250/469 | Cost: 27448.7969
Epoch: 001/050 | Batch 300/469 | Cost: 27149.4609
Epoch: 001/050 | Batch 350/469 | Cost: 25922.6641
Epoch: 001/050 | Batch 400/469 | Cost: 25652.2539
Epoch: 001/050 | Batch 450/469 | Cost: 25503.6504
Time elapsed: 0.15 min
Epoch: 002/050 | Batch 000/469 | Cost: 25003.4570
Epoch: 002/050 | Batch 050/469 | Cost: 24463.9062
Epoch: 002/050 | Batch 100/469 | Cost: 24639.4102
Epoch: 002/050 | Batch 150/469 | Cost: 25084.6973
Epoch: 002/050 | Batch 200/469 | Cost: 24361.1699
Epoch: 002/050 | Batch 250/469 | Cost: 23067.3418
Epoch: 002/050 | Batch 300/469 | Cost: 23340.0000
Epoch: 002/050 | Batch 350/469 | Cost: 22182.6523
Epoch: 002/050 | Batch 400/469 | Cost: 23848.3574
Epoch: 002/050 | Batch 450/469 | Cost: 22587.1523
Time elapsed: 0.29 min
Epoch: 003/050 | Batch 000/469 | Cost: 23024.4629
Epoch: 003/050 | Batch 050/469 | Cost: 22733.5449
Epoch: 003/050 | Batch 100/469 | Cost: 21335.1270
Epoch: 003/050 | Batch 150/469 | Cost: 22196.0801
Epoch: 003/050 | Batch 200/469 | Cost: 21967.5898
Epoch: 003/050 | Batch 250/469 | Cost: 21690.1719
Epoch: 003/050 | Batch 300/469 | Cost: 21659.4297
Epoch: 003/050 | Batch 350/469 | Cost: 20510.7891
Epoch: 003/050 | Batch 400/469 | Cost: 21242.7695
Epoch: 003/050 | Batch 450/469 | Cost: 20785.4922
Time elapsed: 0.44 min
Epoch: 004/050 | Batch 000/469 | Cost: 20328.3145
Epoch: 004/050 | Batch 050/469 | Cost: 20273.4453
Epoch: 004/050 | Batch 100/469 | Cost: 20539.0879
Epoch: 004/050 | Batch 150/469 | Cost: 20137.0156
Epoch: 004/050 | Batch 200/469 | Cost: 19641.2148
Epoch: 004/050 | Batch 250/469 | Cost: 20138.8418
Epoch: 004/050 | Batch 300/469 | Cost: 18882.3086
Epoch: 004/050 | Batch 350/469 | Cost: 19263.8516
Epoch: 004/050 | Batch 400/469 | Cost: 19991.0703
Epoch: 004/050 | Batch 450/469 | Cost: 18806.1582
Time elapsed: 0.58 min
Epoch: 005/050 | Batch 000/469 | Cost: 19555.3555
Epoch: 005/050 | Batch 050/469 | Cost: 19259.2910
Epoch: 005/050 | Batch 100/469 | Cost: 18361.7383
Epoch: 005/050 | Batch 150/469 | Cost: 18087.8281
Epoch: 005/050 | Batch 200/469 | Cost: 18091.0488
Epoch: 005/050 | Batch 250/469 | Cost: 19610.8906
Epoch: 005/050 | Batch 300/469 | Cost: 17971.9766
Epoch: 005/050 | Batch 350/469 | Cost: 18295.7988
Epoch: 005/050 | Batch 400/469 | Cost: 18726.7246
Epoch: 005/050 | Batch 450/469 | Cost: 17738.4590
Time elapsed: 0.73 min
Epoch: 006/050 | Batch 000/469 | Cost: 18221.9043
Epoch: 006/050 | Batch 050/469 | Cost: 18209.2578
Epoch: 006/050 | Batch 100/469 | Cost: 17747.5586
Epoch: 006/050 | Batch 150/469 | Cost: 16824.4941
Epoch: 006/050 | Batch 200/469 | Cost: 17604.7949
Epoch: 006/050 | Batch 250/469 | Cost: 17186.1855
Epoch: 006/050 | Batch 300/469 | Cost: 17752.4570
Epoch: 006/050 | Batch 350/469 | Cost: 17146.1660
Epoch: 006/050 | Batch 400/469 | Cost: 17564.4121
Epoch: 006/050 | Batch 450/469 | Cost: 16982.0527
Time elapsed: 0.88 min
Epoch: 007/050 | Batch 000/469 | Cost: 17229.5156
Epoch: 007/050 | Batch 050/469 | Cost: 17952.2988
Epoch: 007/050 | Batch 100/469 | Cost: 17679.9102
Epoch: 007/050 | Batch 150/469 | Cost: 16431.1602
Epoch: 007/050 | Batch 200/469 | Cost: 16707.6699
Epoch: 007/050 | Batch 250/469 | Cost: 16626.2344
Epoch: 007/050 | Batch 300/469 | Cost: 17091.8594
Epoch: 007/050 | Batch 350/469 | Cost: 16410.7461
Epoch: 007/050 | Batch 400/469 | Cost: 16464.0039
Epoch: 007/050 | Batch 450/469 | Cost: 17185.2910
Time elapsed: 1.03 min
Epoch: 008/050 | Batch 000/469 | Cost: 16310.5146
Epoch: 008/050 | Batch 050/469 | Cost: 16510.9883
Epoch: 008/050 | Batch 100/469 | Cost: 16409.1504
Epoch: 008/050 | Batch 150/469 | Cost: 16645.9414
Epoch: 008/050 | Batch 200/469 | Cost: 16140.7637
Epoch: 008/050 | Batch 250/469 | Cost: 16261.8232
Epoch: 008/050 | Batch 300/469 | Cost: 15731.7832
Epoch: 008/050 | Batch 350/469 | Cost: 16438.5391
Epoch: 008/050 | Batch 400/469 | Cost: 16522.8516
Epoch: 008/050 | Batch 450/469 | Cost: 16674.2656
Time elapsed: 1.18 min
Epoch: 009/050 | Batch 000/469 | Cost: 15663.1904
Epoch: 009/050 | Batch 050/469 | Cost: 15857.3770
Epoch: 009/050 | Batch 100/469 | Cost: 16233.4707
Epoch: 009/050 | Batch 150/469 | Cost: 16635.9785
Epoch: 009/050 | Batch 200/469 | Cost: 16294.5547
Epoch: 009/050 | Batch 250/469 | Cost: 15947.5801
Epoch: 009/050 | Batch 300/469 | Cost: 16139.6113
Epoch: 009/050 | Batch 350/469 | Cost: 16081.8906
Epoch: 009/050 | Batch 400/469 | Cost: 16331.2500
Epoch: 009/050 | Batch 450/469 | Cost: 15352.7773
Time elapsed: 1.32 min
Epoch: 010/050 | Batch 000/469 | Cost: 15708.1094
Epoch: 010/050 | Batch 050/469 | Cost: 16146.4141
Epoch: 010/050 | Batch 100/469 | Cost: 16003.0078
Epoch: 010/050 | Batch 150/469 | Cost: 14990.0430
Epoch: 010/050 | Batch 200/469 | Cost: 15628.8242
Epoch: 010/050 | Batch 250/469 | Cost: 15949.3691
Epoch: 010/050 | Batch 300/469 | Cost: 14936.1875
Epoch: 010/050 | Batch 350/469 | Cost: 15196.0625
Epoch: 010/050 | Batch 400/469 | Cost: 15594.7686
Epoch: 010/050 | Batch 450/469 | Cost: 16577.6230
Time elapsed: 1.47 min
Epoch: 011/050 | Batch 000/469 | Cost: 15626.6035
Epoch: 011/050 | Batch 050/469 | Cost: 15766.5859
Epoch: 011/050 | Batch 100/469 | Cost: 16134.7734
Epoch: 011/050 | Batch 150/469 | Cost: 15299.3574
Epoch: 011/050 | Batch 200/469 | Cost: 15611.9248
Epoch: 011/050 | Batch 250/469 | Cost: 16024.9580
Epoch: 011/050 | Batch 300/469 | Cost: 15840.3047
Epoch: 011/050 | Batch 350/469 | Cost: 15214.9883
Epoch: 011/050 | Batch 400/469 | Cost: 15782.8574
Epoch: 011/050 | Batch 450/469 | Cost: 15268.2646
Time elapsed: 1.62 min
Epoch: 012/050 | Batch 000/469 | Cost: 15584.0322
Epoch: 012/050 | Batch 050/469 | Cost: 15752.2871
Epoch: 012/050 | Batch 100/469 | Cost: 15630.2695
Epoch: 012/050 | Batch 150/469 | Cost: 15513.7822
Epoch: 012/050 | Batch 200/469 | Cost: 15230.1543
Epoch: 012/050 | Batch 250/469 | Cost: 14979.2441
Epoch: 012/050 | Batch 300/469 | Cost: 15661.2822
Epoch: 012/050 | Batch 350/469 | Cost: 14583.6562
Epoch: 012/050 | Batch 400/469 | Cost: 15692.2637
Epoch: 012/050 | Batch 450/469 | Cost: 15444.3398
Time elapsed: 1.76 min
Epoch: 013/050 | Batch 000/469 | Cost: 15403.8340
Epoch: 013/050 | Batch 050/469 | Cost: 15334.1270
Epoch: 013/050 | Batch 100/469 | Cost: 14869.9678
Epoch: 013/050 | Batch 150/469 | Cost: 14967.7734
Epoch: 013/050 | Batch 200/469 | Cost: 15038.3467
Epoch: 013/050 | Batch 250/469 | Cost: 14940.0566
Epoch: 013/050 | Batch 300/469 | Cost: 15861.9902
Epoch: 013/050 | Batch 350/469 | Cost: 15016.0215
Epoch: 013/050 | Batch 400/469 | Cost: 15020.5508
Epoch: 013/050 | Batch 450/469 | Cost: 14558.9678
Time elapsed: 1.91 min
Epoch: 014/050 | Batch 000/469 | Cost: 15571.9609
Epoch: 014/050 | Batch 050/469 | Cost: 14986.8027
Epoch: 014/050 | Batch 100/469 | Cost: 14748.6660
Epoch: 014/050 | Batch 150/469 | Cost: 15177.5010
Epoch: 014/050 | Batch 200/469 | Cost: 15166.5283
Epoch: 014/050 | Batch 250/469 | Cost: 14866.0449
Epoch: 014/050 | Batch 300/469 | Cost: 15227.5977
Epoch: 014/050 | Batch 350/469 | Cost: 15148.1973
Epoch: 014/050 | Batch 400/469 | Cost: 15003.9395
Epoch: 014/050 | Batch 450/469 | Cost: 15571.9531
Time elapsed: 2.06 min
Epoch: 015/050 | Batch 000/469 | Cost: 15100.7773
Epoch: 015/050 | Batch 050/469 | Cost: 14556.3730
Epoch: 015/050 | Batch 100/469 | Cost: 15114.8965
Epoch: 015/050 | Batch 150/469 | Cost: 15237.2412
Epoch: 015/050 | Batch 200/469 | Cost: 15173.2842
Epoch: 015/050 | Batch 250/469 | Cost: 15283.6016
Epoch: 015/050 | Batch 300/469 | Cost: 14717.9834
Epoch: 015/050 | Batch 350/469 | Cost: 15098.9512
Epoch: 015/050 | Batch 400/469 | Cost: 14483.3516
Epoch: 015/050 | Batch 450/469 | Cost: 14874.4346
Time elapsed: 2.21 min
Epoch: 016/050 | Batch 000/469 | Cost: 14778.4883
Epoch: 016/050 | Batch 050/469 | Cost: 14571.5068
Epoch: 016/050 | Batch 100/469 | Cost: 14361.2773
Epoch: 016/050 | Batch 150/469 | Cost: 14580.1055
Epoch: 016/050 | Batch 200/469 | Cost: 14950.4766
Epoch: 016/050 | Batch 250/469 | Cost: 14357.0742
Epoch: 016/050 | Batch 300/469 | Cost: 15067.2119
Epoch: 016/050 | Batch 350/469 | Cost: 14431.0293
Epoch: 016/050 | Batch 400/469 | Cost: 15010.4941
Epoch: 016/050 | Batch 450/469 | Cost: 14981.4385
Time elapsed: 2.35 min
Epoch: 017/050 | Batch 000/469 | Cost: 14213.7207
Epoch: 017/050 | Batch 050/469 | Cost: 14254.8223
Epoch: 017/050 | Batch 100/469 | Cost: 14608.7031
Epoch: 017/050 | Batch 150/469 | Cost: 14804.6738
Epoch: 017/050 | Batch 200/469 | Cost: 15223.3574
Epoch: 017/050 | Batch 250/469 | Cost: 15073.8105
Epoch: 017/050 | Batch 300/469 | Cost: 14488.2256
Epoch: 017/050 | Batch 350/469 | Cost: 15285.3438
Epoch: 017/050 | Batch 400/469 | Cost: 14768.0410
Epoch: 017/050 | Batch 450/469 | Cost: 14246.4082
Time elapsed: 2.49 min
Epoch: 018/050 | Batch 000/469 | Cost: 14446.7607
Epoch: 018/050 | Batch 050/469 | Cost: 14307.9512
Epoch: 018/050 | Batch 100/469 | Cost: 14979.2393
Epoch: 018/050 | Batch 150/469 | Cost: 14640.7529
Epoch: 018/050 | Batch 200/469 | Cost: 14336.5176
Epoch: 018/050 | Batch 250/469 | Cost: 14856.0244
Epoch: 018/050 | Batch 300/469 | Cost: 14236.4883
Epoch: 018/050 | Batch 350/469 | Cost: 14293.7129
Epoch: 018/050 | Batch 400/469 | Cost: 14989.7578
Epoch: 018/050 | Batch 450/469 | Cost: 14645.5918
Time elapsed: 2.63 min
Epoch: 019/050 | Batch 000/469 | Cost: 14769.7305
Epoch: 019/050 | Batch 050/469 | Cost: 14644.3301
Epoch: 019/050 | Batch 100/469 | Cost: 14153.6289
Epoch: 019/050 | Batch 150/469 | Cost: 15014.8457
Epoch: 019/050 | Batch 200/469 | Cost: 14531.8291
Epoch: 019/050 | Batch 250/469 | Cost: 14103.4414
Epoch: 019/050 | Batch 300/469 | Cost: 14499.4141
Epoch: 019/050 | Batch 350/469 | Cost: 14517.2227
Epoch: 019/050 | Batch 400/469 | Cost: 14708.0664
Epoch: 019/050 | Batch 450/469 | Cost: 14042.7529
Time elapsed: 2.77 min
Epoch: 020/050 | Batch 000/469 | Cost: 15051.2266
Epoch: 020/050 | Batch 050/469 | Cost: 14537.1982
Epoch: 020/050 | Batch 100/469 | Cost: 13989.1104
Epoch: 020/050 | Batch 150/469 | Cost: 14822.6094
Epoch: 020/050 | Batch 200/469 | Cost: 15177.9668
Epoch: 020/050 | Batch 250/469 | Cost: 14710.3174
Epoch: 020/050 | Batch 300/469 | Cost: 13794.1641
Epoch: 020/050 | Batch 350/469 | Cost: 14262.4473
Epoch: 020/050 | Batch 400/469 | Cost: 14950.7432
Epoch: 020/050 | Batch 450/469 | Cost: 14864.3555
Time elapsed: 2.91 min
Epoch: 021/050 | Batch 000/469 | Cost: 15020.9473
Epoch: 021/050 | Batch 050/469 | Cost: 14729.3340
Epoch: 021/050 | Batch 100/469 | Cost: 14100.7500
Epoch: 021/050 | Batch 150/469 | Cost: 14151.6641
Epoch: 021/050 | Batch 200/469 | Cost: 14153.0459
Epoch: 021/050 | Batch 250/469 | Cost: 14365.5645
Epoch: 021/050 | Batch 300/469 | Cost: 14539.5244
Epoch: 021/050 | Batch 350/469 | Cost: 14018.8398
Epoch: 021/050 | Batch 400/469 | Cost: 14032.9209
Epoch: 021/050 | Batch 450/469 | Cost: 13872.8320
Time elapsed: 3.06 min
Epoch: 022/050 | Batch 000/469 | Cost: 14742.1719
Epoch: 022/050 | Batch 050/469 | Cost: 14320.2646
Epoch: 022/050 | Batch 100/469 | Cost: 14856.3320
Epoch: 022/050 | Batch 150/469 | Cost: 14376.0273
Epoch: 022/050 | Batch 200/469 | Cost: 14115.4121
Epoch: 022/050 | Batch 250/469 | Cost: 13767.6973
Epoch: 022/050 | Batch 300/469 | Cost: 13885.6768
Epoch: 022/050 | Batch 350/469 | Cost: 15135.5273
Epoch: 022/050 | Batch 400/469 | Cost: 14869.7598
Epoch: 022/050 | Batch 450/469 | Cost: 13792.0283
Time elapsed: 3.20 min
Epoch: 023/050 | Batch 000/469 | Cost: 14404.2324
Epoch: 023/050 | Batch 050/469 | Cost: 14076.9844
Epoch: 023/050 | Batch 100/469 | Cost: 14239.1904
Epoch: 023/050 | Batch 150/469 | Cost: 14376.3242
Epoch: 023/050 | Batch 200/469 | Cost: 13941.0156
Epoch: 023/050 | Batch 250/469 | Cost: 13948.4395
Epoch: 023/050 | Batch 300/469 | Cost: 15119.5137
Epoch: 023/050 | Batch 350/469 | Cost: 14480.1211
Epoch: 023/050 | Batch 400/469 | Cost: 14310.3594
Epoch: 023/050 | Batch 450/469 | Cost: 14712.5039
Time elapsed: 3.34 min
Epoch: 024/050 | Batch 000/469 | Cost: 14535.5488
Epoch: 024/050 | Batch 050/469 | Cost: 14241.1660
Epoch: 024/050 | Batch 100/469 | Cost: 14769.8477
Epoch: 024/050 | Batch 150/469 | Cost: 15056.7559
Epoch: 024/050 | Batch 200/469 | Cost: 14387.6484
Epoch: 024/050 | Batch 250/469 | Cost: 14316.7148
Epoch: 024/050 | Batch 300/469 | Cost: 14848.7793
Epoch: 024/050 | Batch 350/469 | Cost: 14909.2490
Epoch: 024/050 | Batch 400/469 | Cost: 14848.7090
Epoch: 024/050 | Batch 450/469 | Cost: 14461.7627
Time elapsed: 3.48 min
Epoch: 025/050 | Batch 000/469 | Cost: 14212.7168
Epoch: 025/050 | Batch 050/469 | Cost: 14333.6973
Epoch: 025/050 | Batch 100/469 | Cost: 14074.0586
Epoch: 025/050 | Batch 150/469 | Cost: 14331.3789
Epoch: 025/050 | Batch 200/469 | Cost: 13657.7471
Epoch: 025/050 | Batch 250/469 | Cost: 14190.0117
Epoch: 025/050 | Batch 300/469 | Cost: 13733.5908
Epoch: 025/050 | Batch 350/469 | Cost: 14021.1602
Epoch: 025/050 | Batch 400/469 | Cost: 13840.4336
Epoch: 025/050 | Batch 450/469 | Cost: 14060.3848
Time elapsed: 3.63 min
Epoch: 026/050 | Batch 000/469 | Cost: 15362.9629
Epoch: 026/050 | Batch 050/469 | Cost: 14140.0303
Epoch: 026/050 | Batch 100/469 | Cost: 13597.3838
Epoch: 026/050 | Batch 150/469 | Cost: 14821.4492
Epoch: 026/050 | Batch 200/469 | Cost: 14879.7930
Epoch: 026/050 | Batch 250/469 | Cost: 14080.9072
Epoch: 026/050 | Batch 300/469 | Cost: 14645.4023
Epoch: 026/050 | Batch 350/469 | Cost: 13696.6152
Epoch: 026/050 | Batch 400/469 | Cost: 14472.7656
Epoch: 026/050 | Batch 450/469 | Cost: 14059.6641
Time elapsed: 3.78 min
Epoch: 027/050 | Batch 000/469 | Cost: 14369.2246
Epoch: 027/050 | Batch 050/469 | Cost: 13632.5137
Epoch: 027/050 | Batch 100/469 | Cost: 13472.9004
Epoch: 027/050 | Batch 150/469 | Cost: 13673.4121
Epoch: 027/050 | Batch 200/469 | Cost: 14124.0625
Epoch: 027/050 | Batch 250/469 | Cost: 13920.0332
Epoch: 027/050 | Batch 300/469 | Cost: 13909.5391
Epoch: 027/050 | Batch 350/469 | Cost: 14398.0977
Epoch: 027/050 | Batch 400/469 | Cost: 14438.4854
Epoch: 027/050 | Batch 450/469 | Cost: 14019.9814
Time elapsed: 3.93 min
Epoch: 028/050 | Batch 000/469 | Cost: 14063.9189
Epoch: 028/050 | Batch 050/469 | Cost: 14298.8477
Epoch: 028/050 | Batch 100/469 | Cost: 13534.4980
Epoch: 028/050 | Batch 150/469 | Cost: 13799.8779
Epoch: 028/050 | Batch 200/469 | Cost: 13730.7334
Epoch: 028/050 | Batch 250/469 | Cost: 13006.5938
Epoch: 028/050 | Batch 300/469 | Cost: 14268.8652
Epoch: 028/050 | Batch 350/469 | Cost: 13673.4648
Epoch: 028/050 | Batch 400/469 | Cost: 13597.6719
Epoch: 028/050 | Batch 450/469 | Cost: 13925.3242
Time elapsed: 4.08 min
Epoch: 029/050 | Batch 000/469 | Cost: 14032.7266
Epoch: 029/050 | Batch 050/469 | Cost: 14527.6777
Epoch: 029/050 | Batch 100/469 | Cost: 14219.7266
Epoch: 029/050 | Batch 150/469 | Cost: 13933.3320
Epoch: 029/050 | Batch 200/469 | Cost: 14406.4668
Epoch: 029/050 | Batch 250/469 | Cost: 13692.3379
Epoch: 029/050 | Batch 300/469 | Cost: 13557.2705
Epoch: 029/050 | Batch 350/469 | Cost: 14528.8633
Epoch: 029/050 | Batch 400/469 | Cost: 14413.3438
Epoch: 029/050 | Batch 450/469 | Cost: 14293.6504
Time elapsed: 4.23 min
Epoch: 030/050 | Batch 000/469 | Cost: 14673.0938
Epoch: 030/050 | Batch 050/469 | Cost: 14199.3184
Epoch: 030/050 | Batch 100/469 | Cost: 14027.1729
Epoch: 030/050 | Batch 150/469 | Cost: 14117.5713
Epoch: 030/050 | Batch 200/469 | Cost: 13543.0605
Epoch: 030/050 | Batch 250/469 | Cost: 14418.0820
Epoch: 030/050 | Batch 300/469 | Cost: 13932.8691
Epoch: 030/050 | Batch 350/469 | Cost: 13475.8350
Epoch: 030/050 | Batch 400/469 | Cost: 14393.7646
Epoch: 030/050 | Batch 450/469 | Cost: 14195.9902
Time elapsed: 4.37 min
Epoch: 031/050 | Batch 000/469 | Cost: 13865.0762
Epoch: 031/050 | Batch 050/469 | Cost: 13816.7061
Epoch: 031/050 | Batch 100/469 | Cost: 13752.8525
Epoch: 031/050 | Batch 150/469 | Cost: 14141.7930
Epoch: 031/050 | Batch 200/469 | Cost: 14415.1172
Epoch: 031/050 | Batch 250/469 | Cost: 13907.3770
Epoch: 031/050 | Batch 300/469 | Cost: 13910.6807
Epoch: 031/050 | Batch 350/469 | Cost: 13633.5596
Epoch: 031/050 | Batch 400/469 | Cost: 13621.3359
Epoch: 031/050 | Batch 450/469 | Cost: 13538.8291
Time elapsed: 4.52 min
Epoch: 032/050 | Batch 000/469 | Cost: 14009.0742
Epoch: 032/050 | Batch 050/469 | Cost: 13491.7461
Epoch: 032/050 | Batch 100/469 | Cost: 13270.1104
Epoch: 032/050 | Batch 150/469 | Cost: 14276.8320
Epoch: 032/050 | Batch 200/469 | Cost: 13928.1875
Epoch: 032/050 | Batch 250/469 | Cost: 13973.2520
Epoch: 032/050 | Batch 300/469 | Cost: 14112.7969
Epoch: 032/050 | Batch 350/469 | Cost: 14247.1250
Epoch: 032/050 | Batch 400/469 | Cost: 14020.4355
Epoch: 032/050 | Batch 450/469 | Cost: 13671.0029
Time elapsed: 4.67 min
Epoch: 033/050 | Batch 000/469 | Cost: 14114.7676
Epoch: 033/050 | Batch 050/469 | Cost: 14096.6172
Epoch: 033/050 | Batch 100/469 | Cost: 14510.5137
Epoch: 033/050 | Batch 150/469 | Cost: 14087.4746
Epoch: 033/050 | Batch 200/469 | Cost: 13874.9834
Epoch: 033/050 | Batch 250/469 | Cost: 14145.5840
Epoch: 033/050 | Batch 300/469 | Cost: 13861.3926
Epoch: 033/050 | Batch 350/469 | Cost: 14629.8486
Epoch: 033/050 | Batch 400/469 | Cost: 14538.3857
Epoch: 033/050 | Batch 450/469 | Cost: 13830.5381
Time elapsed: 4.82 min
Epoch: 034/050 | Batch 000/469 | Cost: 13836.5195
Epoch: 034/050 | Batch 050/469 | Cost: 13860.2246
Epoch: 034/050 | Batch 100/469 | Cost: 14087.6016
Epoch: 034/050 | Batch 150/469 | Cost: 14019.4785
Epoch: 034/050 | Batch 200/469 | Cost: 13451.0508
Epoch: 034/050 | Batch 250/469 | Cost: 13142.4326
Epoch: 034/050 | Batch 300/469 | Cost: 14079.7734
Epoch: 034/050 | Batch 350/469 | Cost: 13413.0859
Epoch: 034/050 | Batch 400/469 | Cost: 14405.9668
Epoch: 034/050 | Batch 450/469 | Cost: 14408.2139
Time elapsed: 4.97 min
Epoch: 035/050 | Batch 000/469 | Cost: 13902.5938
Epoch: 035/050 | Batch 050/469 | Cost: 13920.2412
Epoch: 035/050 | Batch 100/469 | Cost: 13912.0137
Epoch: 035/050 | Batch 150/469 | Cost: 13720.4482
Epoch: 035/050 | Batch 200/469 | Cost: 13858.9121
Epoch: 035/050 | Batch 250/469 | Cost: 13355.0986
Epoch: 035/050 | Batch 300/469 | Cost: 13733.6855
Epoch: 035/050 | Batch 350/469 | Cost: 14387.2490
Epoch: 035/050 | Batch 400/469 | Cost: 14289.1094
Epoch: 035/050 | Batch 450/469 | Cost: 13157.4883
Time elapsed: 5.11 min
Epoch: 036/050 | Batch 000/469 | Cost: 13923.4131
Epoch: 036/050 | Batch 050/469 | Cost: 13152.2998
Epoch: 036/050 | Batch 100/469 | Cost: 13996.1729
Epoch: 036/050 | Batch 150/469 | Cost: 13884.8965
Epoch: 036/050 | Batch 200/469 | Cost: 13887.7607
Epoch: 036/050 | Batch 250/469 | Cost: 13652.5996
Epoch: 036/050 | Batch 300/469 | Cost: 13951.4346
Epoch: 036/050 | Batch 350/469 | Cost: 13787.7617
Epoch: 036/050 | Batch 400/469 | Cost: 14097.5078
Epoch: 036/050 | Batch 450/469 | Cost: 13684.4854
Time elapsed: 5.26 min
Epoch: 037/050 | Batch 000/469 | Cost: 14580.7109
Epoch: 037/050 | Batch 050/469 | Cost: 13706.5557
Epoch: 037/050 | Batch 100/469 | Cost: 14079.7070
Epoch: 037/050 | Batch 150/469 | Cost: 14231.3975
Epoch: 037/050 | Batch 200/469 | Cost: 13724.7275
Epoch: 037/050 | Batch 250/469 | Cost: 14127.0488
Epoch: 037/050 | Batch 300/469 | Cost: 14432.3828
Epoch: 037/050 | Batch 350/469 | Cost: 13770.9668
Epoch: 037/050 | Batch 400/469 | Cost: 14457.6172
Epoch: 037/050 | Batch 450/469 | Cost: 13425.8623
Time elapsed: 5.41 min
Epoch: 038/050 | Batch 000/469 | Cost: 13763.5371
Epoch: 038/050 | Batch 050/469 | Cost: 13891.8945
Epoch: 038/050 | Batch 100/469 | Cost: 13626.1357
Epoch: 038/050 | Batch 150/469 | Cost: 14679.0449
Epoch: 038/050 | Batch 200/469 | Cost: 13221.4004
Epoch: 038/050 | Batch 250/469 | Cost: 13140.2148
Epoch: 038/050 | Batch 300/469 | Cost: 13809.6084
Epoch: 038/050 | Batch 350/469 | Cost: 13575.6592
Epoch: 038/050 | Batch 400/469 | Cost: 14249.9180
Epoch: 038/050 | Batch 450/469 | Cost: 14097.8291
Time elapsed: 5.56 min
Epoch: 039/050 | Batch 000/469 | Cost: 14015.1768
Epoch: 039/050 | Batch 050/469 | Cost: 13973.9795
Epoch: 039/050 | Batch 100/469 | Cost: 13633.8730
Epoch: 039/050 | Batch 150/469 | Cost: 14055.6895
Epoch: 039/050 | Batch 200/469 | Cost: 13871.2949
Epoch: 039/050 | Batch 250/469 | Cost: 13746.9258
Epoch: 039/050 | Batch 300/469 | Cost: 13203.3242
Epoch: 039/050 | Batch 350/469 | Cost: 13911.6846
Epoch: 039/050 | Batch 400/469 | Cost: 14241.5703
Epoch: 039/050 | Batch 450/469 | Cost: 13677.2559
Time elapsed: 5.70 min
Epoch: 040/050 | Batch 000/469 | Cost: 14490.0547
Epoch: 040/050 | Batch 050/469 | Cost: 13689.6680
Epoch: 040/050 | Batch 100/469 | Cost: 14046.6895
Epoch: 040/050 | Batch 150/469 | Cost: 13632.8125
Epoch: 040/050 | Batch 200/469 | Cost: 13456.0918
Epoch: 040/050 | Batch 250/469 | Cost: 13832.4795
Epoch: 040/050 | Batch 300/469 | Cost: 13813.2939
Epoch: 040/050 | Batch 350/469 | Cost: 13484.2520
Epoch: 040/050 | Batch 400/469 | Cost: 13600.7803
Epoch: 040/050 | Batch 450/469 | Cost: 13492.7578
Time elapsed: 5.85 min
Epoch: 041/050 | Batch 000/469 | Cost: 13993.5547
Epoch: 041/050 | Batch 050/469 | Cost: 13833.7031
Epoch: 041/050 | Batch 100/469 | Cost: 13798.5264
Epoch: 041/050 | Batch 150/469 | Cost: 14379.4717
Epoch: 041/050 | Batch 200/469 | Cost: 13919.1445
Epoch: 041/050 | Batch 250/469 | Cost: 13361.4160
Epoch: 041/050 | Batch 300/469 | Cost: 14154.9043
Epoch: 041/050 | Batch 350/469 | Cost: 13858.2715
Epoch: 041/050 | Batch 400/469 | Cost: 14078.7451
Epoch: 041/050 | Batch 450/469 | Cost: 13970.0488
Time elapsed: 6.00 min
Epoch: 042/050 | Batch 000/469 | Cost: 14093.0371
Epoch: 042/050 | Batch 050/469 | Cost: 14073.4688
Epoch: 042/050 | Batch 100/469 | Cost: 13645.2754
Epoch: 042/050 | Batch 150/469 | Cost: 13464.0029
Epoch: 042/050 | Batch 200/469 | Cost: 13615.8643
Epoch: 042/050 | Batch 250/469 | Cost: 13301.9805
Epoch: 042/050 | Batch 300/469 | Cost: 13605.0020
Epoch: 042/050 | Batch 350/469 | Cost: 14035.0498
Epoch: 042/050 | Batch 400/469 | Cost: 13637.4297
Epoch: 042/050 | Batch 450/469 | Cost: 14165.7686
Time elapsed: 6.15 min
Epoch: 043/050 | Batch 000/469 | Cost: 13715.1055
Epoch: 043/050 | Batch 050/469 | Cost: 14122.5898
Epoch: 043/050 | Batch 100/469 | Cost: 14184.3633
Epoch: 043/050 | Batch 150/469 | Cost: 13745.1133
Epoch: 043/050 | Batch 200/469 | Cost: 13448.2559
Epoch: 043/050 | Batch 250/469 | Cost: 13323.3438
Epoch: 043/050 | Batch 300/469 | Cost: 13835.5723
Epoch: 043/050 | Batch 350/469 | Cost: 13462.5098
Epoch: 043/050 | Batch 400/469 | Cost: 14195.2227
Epoch: 043/050 | Batch 450/469 | Cost: 13253.4600
Time elapsed: 6.30 min
Epoch: 044/050 | Batch 000/469 | Cost: 14028.9277
Epoch: 044/050 | Batch 050/469 | Cost: 13369.4111
Epoch: 044/050 | Batch 100/469 | Cost: 13645.9971
Epoch: 044/050 | Batch 150/469 | Cost: 13864.8613
Epoch: 044/050 | Batch 200/469 | Cost: 13508.7471
Epoch: 044/050 | Batch 250/469 | Cost: 14534.7754
Epoch: 044/050 | Batch 300/469 | Cost: 13565.7900
Epoch: 044/050 | Batch 350/469 | Cost: 13719.3438
Epoch: 044/050 | Batch 400/469 | Cost: 13678.1367
Epoch: 044/050 | Batch 450/469 | Cost: 14057.3779
Time elapsed: 6.44 min
Epoch: 045/050 | Batch 000/469 | Cost: 13414.4121
Epoch: 045/050 | Batch 050/469 | Cost: 13531.8555
Epoch: 045/050 | Batch 100/469 | Cost: 13470.2266
Epoch: 045/050 | Batch 150/469 | Cost: 13866.7627
Epoch: 045/050 | Batch 200/469 | Cost: 13438.2832
Epoch: 045/050 | Batch 250/469 | Cost: 14194.3691
Epoch: 045/050 | Batch 300/469 | Cost: 14172.3320
Epoch: 045/050 | Batch 350/469 | Cost: 13798.1680
Epoch: 045/050 | Batch 400/469 | Cost: 13684.1064
Epoch: 045/050 | Batch 450/469 | Cost: 13255.7441
Time elapsed: 6.59 min
Epoch: 046/050 | Batch 000/469 | Cost: 13833.5371
Epoch: 046/050 | Batch 050/469 | Cost: 13982.0898
Epoch: 046/050 | Batch 100/469 | Cost: 13699.0674
Epoch: 046/050 | Batch 150/469 | Cost: 13579.7803
Epoch: 046/050 | Batch 200/469 | Cost: 13611.3682
Epoch: 046/050 | Batch 250/469 | Cost: 14532.4092
Epoch: 046/050 | Batch 300/469 | Cost: 13690.0381
Epoch: 046/050 | Batch 350/469 | Cost: 13886.2227
Epoch: 046/050 | Batch 400/469 | Cost: 13716.4883
Epoch: 046/050 | Batch 450/469 | Cost: 13887.5723
Time elapsed: 6.74 min
Epoch: 047/050 | Batch 000/469 | Cost: 13460.0312
Epoch: 047/050 | Batch 050/469 | Cost: 13862.8320
Epoch: 047/050 | Batch 100/469 | Cost: 13045.7754
Epoch: 047/050 | Batch 150/469 | Cost: 13520.7910
Epoch: 047/050 | Batch 200/469 | Cost: 13966.8848
Epoch: 047/050 | Batch 250/469 | Cost: 14337.5615
Epoch: 047/050 | Batch 300/469 | Cost: 13835.9805
Epoch: 047/050 | Batch 350/469 | Cost: 13705.6699
Epoch: 047/050 | Batch 400/469 | Cost: 14085.0215
Epoch: 047/050 | Batch 450/469 | Cost: 13708.9961
Time elapsed: 6.89 min
Epoch: 048/050 | Batch 000/469 | Cost: 13683.8477
Epoch: 048/050 | Batch 050/469 | Cost: 14290.2441
Epoch: 048/050 | Batch 100/469 | Cost: 13824.9033
Epoch: 048/050 | Batch 150/469 | Cost: 13902.4424
Epoch: 048/050 | Batch 200/469 | Cost: 13742.8066
Epoch: 048/050 | Batch 250/469 | Cost: 13804.6270
Epoch: 048/050 | Batch 300/469 | Cost: 14011.4414
Epoch: 048/050 | Batch 350/469 | Cost: 13902.3428
Epoch: 048/050 | Batch 400/469 | Cost: 13671.2607
Epoch: 048/050 | Batch 450/469 | Cost: 13533.4326
Time elapsed: 7.03 min
Epoch: 049/050 | Batch 000/469 | Cost: 13808.8584
Epoch: 049/050 | Batch 050/469 | Cost: 14385.1328
Epoch: 049/050 | Batch 100/469 | Cost: 13595.7334
Epoch: 049/050 | Batch 150/469 | Cost: 13449.9658
Epoch: 049/050 | Batch 200/469 | Cost: 13782.0635
Epoch: 049/050 | Batch 250/469 | Cost: 13681.0293
Epoch: 049/050 | Batch 300/469 | Cost: 14259.2988
Epoch: 049/050 | Batch 350/469 | Cost: 13350.5176
Epoch: 049/050 | Batch 400/469 | Cost: 12788.5156
Epoch: 049/050 | Batch 450/469 | Cost: 13642.1787
Time elapsed: 7.18 min
Epoch: 050/050 | Batch 000/469 | Cost: 13596.1172
Epoch: 050/050 | Batch 050/469 | Cost: 13988.5371
Epoch: 050/050 | Batch 100/469 | Cost: 14061.5742
Epoch: 050/050 | Batch 150/469 | Cost: 13996.9111
Epoch: 050/050 | Batch 200/469 | Cost: 13628.2070
Epoch: 050/050 | Batch 250/469 | Cost: 13667.3203
Epoch: 050/050 | Batch 300/469 | Cost: 13978.5820
Epoch: 050/050 | Batch 350/469 | Cost: 13589.2910
Epoch: 050/050 | Batch 400/469 | Cost: 13307.0566
Epoch: 050/050 | Batch 450/469 | Cost: 13997.9141
Time elapsed: 7.33 min
Total Training Time: 7.33 min

Evaluation

Reconstruction

In [6]:
%matplotlib inline
import matplotlib.pyplot as plt


##########################
### VISUALIZATION
##########################

# Plot a 2 x n_images grid: inputs in the top row, the corresponding
# decoder outputs in the bottom row.
# NOTE(review): `features` and `decoded` are not defined in this cell;
# presumably they are left over from the last training batch -- confirm.
n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images,
                         sharex=True, sharey=True, figsize=(20, 2.5))
orig_images = features[:n_images]
decoded_images = decoded[:n_images, 0]

cpu = torch.device('cpu')
# Walk the grid row by row; every (row, column) pair addresses its own axes,
# so the drawing order does not affect the resulting figure.
for row_axes, batch in zip(axes, (orig_images, decoded_images)):
    for col in range(n_images):
        # Detach from autograd and move to host memory before plotting.
        pixels = batch[col].detach().to(cpu)
        pixels = pixels.view((image_width, image_width))
        row_axes[col].imshow(pixels, cmap='binary')

New random class-conditional images

In [7]:
# Decode random latent vectors once per class label and show one row of
# generated images per label.
# NOTE(review): `model` is defined outside this cell; if it contains dropout or
# batch-norm layers it may need to be switched to eval mode first -- confirm.

image_width = 28       # loop-invariant, so it is set once up front
images_per_class = 10  # number of samples generated for each label

# Iterate over every class declared in the settings cell instead of a literal 10.
for i in range(num_classes):

    ##########################
    ### RANDOM SAMPLE
    ##########################

    labels = torch.tensor([i] * images_per_class).to(device)
    n_images = labels.size()[0]
    # Sample on the CPU and then move to `device` (as before), so the random
    # stream does not depend on which device is in use.
    rand_features = torch.randn(n_images, num_latent).to(device)
    # Pure generation needs no gradients; avoid building an autograd graph.
    with torch.no_grad():
        new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    decoded_images = new_images[:n_images, 0]

    print('Class Label %d' % i)

    for ax, img in zip(axes, decoded_images):
        cpu_img = img.detach().to(torch.device('cpu'))
        ax.imshow(cpu_img.view((image_width, image_width)), cmap='binary')

    plt.show()
Class Label 0
Class Label 1
Class Label 2
Class Label 3
Class Label 4
Class Label 5
Class Label 6
Class Label 7
Class Label 8
Class Label 9
In [8]:
watermark -iv
numpy       1.15.4
torch       1.0.0