Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka 

CPython 3.6.8
IPython 7.2.0

torch 1.0.0
  • Runs on CPU or GPU (if available)

Model Zoo -- Convolutional Conditional Variational Autoencoder

(with labels in reconstruction loss)

A simple convolutional conditional variational autoencoder that compresses 784-pixel MNIST images down to a 50-dimensional latent vector representation.

This implementation concatenates the inputs with the class labels when computing the reconstruction loss, as is commonly done in non-convolutional conditional variational autoencoders. This leads to substantially poorer results than the implementation that does NOT concatenate the labels with the inputs in the reconstruction loss. For reference, see the implementation ./autoencoder-cnn-cvae_no-out-concat.ipynb
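
The condition enters the network as extra image channels: each class label is one-hot encoded and broadcast into ten constant 28x28 feature maps, which are concatenated with the input image (and, in this variant, also with the reconstruction target). The following minimal, self-contained sketch illustrates the trick on dummy tensors; it is not part of the notebook's pipeline, whose real version appears in the encoder and training loop below.

import torch

batch_size, num_classes = 2, 10
features = torch.rand(batch_size, 1, 28, 28)   # dummy images in [0, 1]
targets = torch.tensor([3, 7])                 # dummy class labels

# one-hot encode the labels
onehot = torch.zeros(batch_size, num_classes)
onehot.scatter_(1, targets.view(-1, 1), 1)

# broadcast each one-hot vector into ten constant 28x28 channels
label_maps = onehot.view(-1, num_classes, 1, 1) * torch.ones(
    batch_size, num_classes, 28, 28)

# concatenate along the channel dimension: 1 image + 10 label channels
x = torch.cat((features, label_maps), dim=1)
print(x.shape)  # torch.Size([2, 11, 28, 28])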

Imports

In [2]:
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
In [3]:
##########################
### SETTINGS
##########################

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)

# Hyperparameters
random_seed = 0
learning_rate = 0.001
num_epochs = 50
batch_size = 128

# Architecture
num_classes = 10
num_features = 784
num_latent = 50


##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
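# (uint8 pixel intensities 0-255 are converted to float32 values in [0., 1.])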
train_dataset = datasets.MNIST(root='data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data', 
                              train=False, 
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size, 
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Device: cuda:0
Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])

Model

In [4]:
##########################
### MODEL
##########################


def to_onehot(labels, num_classes, device):

    labels_onehot = torch.zeros(labels.size()[0], num_classes).to(device)
    labels_onehot.scatter_(1, labels.view(-1, 1), 1)
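    # e.g., labels=[3, 0] with num_classes=4 yields
    # [[0., 0., 0., 1.],
    #  [1., 0., 0., 0.]]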

    return labels_onehot


class ConditionalVariationalAutoencoder(torch.nn.Module):

    def __init__(self, num_features, num_latent, num_classes):
        super(ConditionalVariationalAutoencoder, self).__init__()
        
        self.num_classes = num_classes
        
        
        ###############
        # ENCODER
        ##############
        
        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2
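        # With padding=0 below, each conv output size is floor((w - k)/s) + 1,
        # so the feature maps shrink as 28x28 -> 12x12 -> 5x5 -> 2x2.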

        self.enc_conv_1 = torch.nn.Conv2d(in_channels=1+self.num_classes,
                                          out_channels=16,
                                          kernel_size=(6, 6),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_2 = torch.nn.Conv2d(in_channels=16,
                                          out_channels=32,
                                          kernel_size=(4, 4),
                                          stride=(2, 2),
                                          padding=0)                 
        
        self.enc_conv_3 = torch.nn.Conv2d(in_channels=32,
                                          out_channels=64,
                                          kernel_size=(2, 2),
                                          stride=(2, 2),
                                          padding=0)                     
        
        self.z_mean = torch.nn.Linear(64*2*2, num_latent)
        # In the original paper (Kingma & Welling, 2014), the encoder
        # predicts a z_mean and a z_var. The problem is that a predicted
        # z_var can be negative, which would cause issues in the log
        # later. Hence, we let the latent vector have a z_mean and a
        # z_log_var component instead, and when we need the regular
        # variance or std_dev, we simply apply the exponential function.
        self.z_log_var = torch.nn.Linear(64*2*2, num_latent)
        
        
        
        ###############
        # DECODER
        ##############
        
        self.dec_linear_1 = torch.nn.Linear(num_latent+self.num_classes, 64*2*2)
               
        self.dec_deconv_1 = torch.nn.ConvTranspose2d(in_channels=64,
                                                     out_channels=32,
                                                     kernel_size=(2, 2),
                                                     stride=(2, 2),
                                                     padding=0)
                                 
        self.dec_deconv_2 = torch.nn.ConvTranspose2d(in_channels=32,
                                                     out_channels=16,
                                                     kernel_size=(4, 4),
                                                     stride=(3, 3),
                                                     padding=1)
        
        self.dec_deconv_3 = torch.nn.ConvTranspose2d(in_channels=16,
                                                     out_channels=11,
                                                     kernel_size=(6, 6),
                                                     stride=(3, 3),
                                                     padding=4)        
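
        # ConvTranspose2d output size: (w - 1)*s - 2*p + k, so the feature
        # maps grow as 2x2 -> 4x4 -> 11x11 -> 28x28; the 11 output channels
        # correspond to the 1 image channel plus 10 label channels that the
        # reconstruction loss is computed against.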


    def reparameterize(self, z_mu, z_log_var):
        # Sample epsilon from standard normal distribution
        eps = torch.randn(z_mu.size(0), z_mu.size(1)).to(device)
        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev
        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)
        z = z_mu + eps * torch.exp(z_log_var/2.) 
        return z
        
    def encoder(self, features, targets):
        
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, device)
        onehot_targets = onehot_targets.view(-1, self.num_classes, 1, 1)
        
        ones = torch.ones(features.size()[0], 
                          self.num_classes,
                          features.size()[2], 
                          features.size()[3], 
                          dtype=features.dtype).to(device)
        ones = ones * onehot_targets
        x = torch.cat((features, ones), dim=1)
        
        x = self.enc_conv_1(x)
        x = F.leaky_relu(x)
        #print('conv1 out:', x.size())
        
        x = self.enc_conv_2(x)
        x = F.leaky_relu(x)
        #print('conv2 out:', x.size())
        
        x = self.enc_conv_3(x)
        x = F.leaky_relu(x)
        #print('conv3 out:', x.size())
        
        z_mean = self.z_mean(x.view(-1, 64*2*2))
        z_log_var = self.z_log_var(x.view(-1, 64*2*2))
        encoded = self.reparameterize(z_mean, z_log_var)
        return z_mean, z_log_var, encoded
    
    def decoder(self, encoded, targets):
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, device)
        encoded = torch.cat((encoded, onehot_targets), dim=1)        
        
        x = self.dec_linear_1(encoded)
        x = x.view(-1, 64, 2, 2)
        
        x = self.dec_deconv_1(x)
        x = F.leaky_relu(x)
        #print('deconv1 out:', x.size())
        
        x = self.dec_deconv_2(x)
        x = F.leaky_relu(x)
        #print('deconv2 out:', x.size())
        
        x = self.dec_deconv_3(x)
        x = F.leaky_relu(x)
        #print('deconv3 out:', x.size())
        
        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)
        
        return z_mean, z_log_var, encoded, decoded

    
torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(num_features,
                                          num_latent,
                                          num_classes)
model = model.to(device)
    

##########################
### COST AND OPTIMIZER
##########################

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

Training

In [5]:
start_time = time.time()

for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
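        # For a Gaussian posterior N(z_mean, exp(z_log_var)) and a standard
        # normal prior, the KL term has the closed form
        # -0.5 * sum(1 + z_log_var - z_mean^2 - exp(z_log_var)),
        # which the expression below computes.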
        kl_divergence = (0.5 * (z_mean**2 + 
                                torch.exp(z_log_var) - z_log_var - 1)).sum()
        
        
        ### Add condition
        onehot_targets = to_onehot(targets, num_classes, device)
        onehot_targets = onehot_targets.view(-1, num_classes, 1, 1)
        
        ones = torch.ones(features.size()[0], 
                          num_classes,
                          features.size()[2], 
                          features.size()[3], 
                          dtype=features.dtype).to(device)
        ones = ones * onehot_targets
        x_con = torch.cat((features, ones), dim=1)
        
        
        ### Compute loss
        pixelwise_bce = F.binary_cross_entropy(decoded, x_con, reduction='sum')
        cost = kl_divergence + pixelwise_bce
        
        ### UPDATE MODEL PARAMETERS
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        
        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))
            
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/469 | Cost: 767976.6875
Epoch: 001/050 | Batch 050/469 | Cost: 596610.6250
Epoch: 001/050 | Batch 100/469 | Cost: 373679.2188
Epoch: 001/050 | Batch 150/469 | Cost: 369423.0312
Epoch: 001/050 | Batch 200/469 | Cost: 372999.6250
Epoch: 001/050 | Batch 250/469 | Cost: 369966.6562
Epoch: 001/050 | Batch 300/469 | Cost: 370309.6250
Epoch: 001/050 | Batch 350/469 | Cost: 368235.9375
Epoch: 001/050 | Batch 400/469 | Cost: 369051.5312
Epoch: 001/050 | Batch 450/469 | Cost: 369094.0000
Time elapsed: 0.16 min
Epoch: 002/050 | Batch 000/469 | Cost: 368791.1875
Epoch: 002/050 | Batch 050/469 | Cost: 367759.2500
Epoch: 002/050 | Batch 100/469 | Cost: 370258.9375
Epoch: 002/050 | Batch 150/469 | Cost: 367393.0938
Epoch: 002/050 | Batch 200/469 | Cost: 367319.2812
Epoch: 002/050 | Batch 250/469 | Cost: 366416.8125
Epoch: 002/050 | Batch 300/469 | Cost: 366461.8125
Epoch: 002/050 | Batch 350/469 | Cost: 369242.8125
Epoch: 002/050 | Batch 400/469 | Cost: 364756.9688
Epoch: 002/050 | Batch 450/469 | Cost: 366672.1562
Time elapsed: 0.31 min
Epoch: 003/050 | Batch 000/469 | Cost: 365496.8750
Epoch: 003/050 | Batch 050/469 | Cost: 364285.8125
Epoch: 003/050 | Batch 100/469 | Cost: 360785.4062
Epoch: 003/050 | Batch 150/469 | Cost: 359350.1562
Epoch: 003/050 | Batch 200/469 | Cost: 360445.7500
Epoch: 003/050 | Batch 250/469 | Cost: 358670.5938
Epoch: 003/050 | Batch 300/469 | Cost: 357442.0938
Epoch: 003/050 | Batch 350/469 | Cost: 355264.8438
Epoch: 003/050 | Batch 400/469 | Cost: 354287.0000
Epoch: 003/050 | Batch 450/469 | Cost: 355108.7500
Time elapsed: 0.47 min
Epoch: 004/050 | Batch 000/469 | Cost: 356378.7812
Epoch: 004/050 | Batch 050/469 | Cost: 353450.8125
Epoch: 004/050 | Batch 100/469 | Cost: 355565.7812
Epoch: 004/050 | Batch 150/469 | Cost: 352293.7188
Epoch: 004/050 | Batch 200/469 | Cost: 352507.3438
Epoch: 004/050 | Batch 250/469 | Cost: 349660.2812
Epoch: 004/050 | Batch 300/469 | Cost: 325723.5312
Epoch: 004/050 | Batch 350/469 | Cost: 321628.3438
Epoch: 004/050 | Batch 400/469 | Cost: 299635.0000
Epoch: 004/050 | Batch 450/469 | Cost: 254649.7344
Time elapsed: 0.63 min
Epoch: 005/050 | Batch 000/469 | Cost: 246473.9219
Epoch: 005/050 | Batch 050/469 | Cost: 253253.3438
Epoch: 005/050 | Batch 100/469 | Cost: 229721.2812
Epoch: 005/050 | Batch 150/469 | Cost: 194743.3750
Epoch: 005/050 | Batch 200/469 | Cost: 206577.7812
Epoch: 005/050 | Batch 250/469 | Cost: 190684.9531
Epoch: 005/050 | Batch 300/469 | Cost: 205158.5781
Epoch: 005/050 | Batch 350/469 | Cost: 191114.2812
Epoch: 005/050 | Batch 400/469 | Cost: 162304.3125
Epoch: 005/050 | Batch 450/469 | Cost: 121093.5156
Time elapsed: 0.78 min
Epoch: 006/050 | Batch 000/469 | Cost: 114467.9297
Epoch: 006/050 | Batch 050/469 | Cost: 96826.9844
Epoch: 006/050 | Batch 100/469 | Cost: 91189.8281
Epoch: 006/050 | Batch 150/469 | Cost: 69205.9062
Epoch: 006/050 | Batch 200/469 | Cost: 69573.7969
Epoch: 006/050 | Batch 250/469 | Cost: 58245.7578
Epoch: 006/050 | Batch 300/469 | Cost: 52482.0312
Epoch: 006/050 | Batch 350/469 | Cost: 49466.5234
Epoch: 006/050 | Batch 400/469 | Cost: 47806.3906
Epoch: 006/050 | Batch 450/469 | Cost: 47156.6328
Time elapsed: 0.94 min
Epoch: 007/050 | Batch 000/469 | Cost: 46200.7422
Epoch: 007/050 | Batch 050/469 | Cost: 43992.4375
Epoch: 007/050 | Batch 100/469 | Cost: 45221.8672
Epoch: 007/050 | Batch 150/469 | Cost: 44009.6562
Epoch: 007/050 | Batch 200/469 | Cost: 42143.8555
Epoch: 007/050 | Batch 250/469 | Cost: 41778.6367
Epoch: 007/050 | Batch 300/469 | Cost: 39909.2383
Epoch: 007/050 | Batch 350/469 | Cost: 116561.6953
Epoch: 007/050 | Batch 400/469 | Cost: 40465.0156
Epoch: 007/050 | Batch 450/469 | Cost: 39080.0469
Time elapsed: 1.10 min
Epoch: 008/050 | Batch 000/469 | Cost: 37986.4492
Epoch: 008/050 | Batch 050/469 | Cost: 38122.7812
Epoch: 008/050 | Batch 100/469 | Cost: 37492.6172
Epoch: 008/050 | Batch 150/469 | Cost: 36924.3398
Epoch: 008/050 | Batch 200/469 | Cost: 36508.1406
Epoch: 008/050 | Batch 250/469 | Cost: 36034.1172
Epoch: 008/050 | Batch 300/469 | Cost: 35373.1719
Epoch: 008/050 | Batch 350/469 | Cost: 36094.9648
Epoch: 008/050 | Batch 400/469 | Cost: 35194.5547
Epoch: 008/050 | Batch 450/469 | Cost: 35357.0078
Time elapsed: 1.25 min
Epoch: 009/050 | Batch 000/469 | Cost: 34139.3164
Epoch: 009/050 | Batch 050/469 | Cost: 32684.7852
Epoch: 009/050 | Batch 100/469 | Cost: 33087.8008
Epoch: 009/050 | Batch 150/469 | Cost: 34713.3516
Epoch: 009/050 | Batch 200/469 | Cost: 34058.9414
Epoch: 009/050 | Batch 250/469 | Cost: 34817.7734
Epoch: 009/050 | Batch 300/469 | Cost: 32921.8438
Epoch: 009/050 | Batch 350/469 | Cost: 32419.2344
Epoch: 009/050 | Batch 400/469 | Cost: 32464.9297
Epoch: 009/050 | Batch 450/469 | Cost: 33157.8555
Time elapsed: 1.41 min
Epoch: 010/050 | Batch 000/469 | Cost: 33069.7344
Epoch: 010/050 | Batch 050/469 | Cost: 32903.2422
Epoch: 010/050 | Batch 100/469 | Cost: 33106.3438
Epoch: 010/050 | Batch 150/469 | Cost: 31759.2578
Epoch: 010/050 | Batch 200/469 | Cost: 32032.4941
Epoch: 010/050 | Batch 250/469 | Cost: 30970.5547
Epoch: 010/050 | Batch 300/469 | Cost: 32380.6895
Epoch: 010/050 | Batch 350/469 | Cost: 31950.4258
Epoch: 010/050 | Batch 400/469 | Cost: 29926.0078
Epoch: 010/050 | Batch 450/469 | Cost: 30842.2812
Time elapsed: 1.57 min
Epoch: 011/050 | Batch 000/469 | Cost: 30815.4336
Epoch: 011/050 | Batch 050/469 | Cost: 33126.0547
Epoch: 011/050 | Batch 100/469 | Cost: 32505.0547
Epoch: 011/050 | Batch 150/469 | Cost: 30127.3398
Epoch: 011/050 | Batch 200/469 | Cost: 30061.9590
Epoch: 011/050 | Batch 250/469 | Cost: 30604.7930
Epoch: 011/050 | Batch 300/469 | Cost: 29668.6602
Epoch: 011/050 | Batch 350/469 | Cost: 30926.6719
Epoch: 011/050 | Batch 400/469 | Cost: 30350.8242
Epoch: 011/050 | Batch 450/469 | Cost: 30159.5273
Time elapsed: 1.73 min
Epoch: 012/050 | Batch 000/469 | Cost: 30400.7031
Epoch: 012/050 | Batch 050/469 | Cost: 30232.4375
Epoch: 012/050 | Batch 100/469 | Cost: 29833.6797
Epoch: 012/050 | Batch 150/469 | Cost: 30118.9746
Epoch: 012/050 | Batch 200/469 | Cost: 30293.8867
Epoch: 012/050 | Batch 250/469 | Cost: 29551.5000
Epoch: 012/050 | Batch 300/469 | Cost: 30271.3555
Epoch: 012/050 | Batch 350/469 | Cost: 29931.2773
Epoch: 012/050 | Batch 400/469 | Cost: 29358.0820
Epoch: 012/050 | Batch 450/469 | Cost: 30744.5586
Time elapsed: 1.88 min
Epoch: 013/050 | Batch 000/469 | Cost: 29265.8965
Epoch: 013/050 | Batch 050/469 | Cost: 27530.9746
Epoch: 013/050 | Batch 100/469 | Cost: 28721.8672
Epoch: 013/050 | Batch 150/469 | Cost: 30155.9512
Epoch: 013/050 | Batch 200/469 | Cost: 29030.8125
Epoch: 013/050 | Batch 250/469 | Cost: 29201.0938
Epoch: 013/050 | Batch 300/469 | Cost: 28295.7402
Epoch: 013/050 | Batch 350/469 | Cost: 28157.3828
Epoch: 013/050 | Batch 400/469 | Cost: 28565.1133
Epoch: 013/050 | Batch 450/469 | Cost: 28998.5449
Time elapsed: 2.04 min
Epoch: 014/050 | Batch 000/469 | Cost: 28962.7031
Epoch: 014/050 | Batch 050/469 | Cost: 28651.3438
Epoch: 014/050 | Batch 100/469 | Cost: 29397.2930
Epoch: 014/050 | Batch 150/469 | Cost: 29416.5078
Epoch: 014/050 | Batch 200/469 | Cost: 28977.9805
Epoch: 014/050 | Batch 250/469 | Cost: 29161.6523
Epoch: 014/050 | Batch 300/469 | Cost: 28904.8867
Epoch: 014/050 | Batch 350/469 | Cost: 26424.5078
Epoch: 014/050 | Batch 400/469 | Cost: 27135.6367
Epoch: 014/050 | Batch 450/469 | Cost: 27612.0020
Time elapsed: 2.20 min
Epoch: 015/050 | Batch 000/469 | Cost: 28140.8086
Epoch: 015/050 | Batch 050/469 | Cost: 29116.8887
Epoch: 015/050 | Batch 100/469 | Cost: 28442.5781
Epoch: 015/050 | Batch 150/469 | Cost: 28238.6250
Epoch: 015/050 | Batch 200/469 | Cost: 27482.3203
Epoch: 015/050 | Batch 250/469 | Cost: 28634.3496
Epoch: 015/050 | Batch 300/469 | Cost: 26978.4004
Epoch: 015/050 | Batch 350/469 | Cost: 29071.4707
Epoch: 015/050 | Batch 400/469 | Cost: 27510.0801
Epoch: 015/050 | Batch 450/469 | Cost: 27338.2227
Time elapsed: 2.36 min
Epoch: 016/050 | Batch 000/469 | Cost: 26819.6582
Epoch: 016/050 | Batch 050/469 | Cost: 28451.7012
Epoch: 016/050 | Batch 100/469 | Cost: 30336.4238
Epoch: 016/050 | Batch 150/469 | Cost: 27339.8887
Epoch: 016/050 | Batch 200/469 | Cost: 27272.3281
Epoch: 016/050 | Batch 250/469 | Cost: 26689.3438
Epoch: 016/050 | Batch 300/469 | Cost: 27881.7383
Epoch: 016/050 | Batch 350/469 | Cost: 27540.2598
Epoch: 016/050 | Batch 400/469 | Cost: 28315.9961
Epoch: 016/050 | Batch 450/469 | Cost: 27461.7090
Time elapsed: 2.51 min
Epoch: 017/050 | Batch 000/469 | Cost: 26608.0723
Epoch: 017/050 | Batch 050/469 | Cost: 28546.3730
Epoch: 017/050 | Batch 100/469 | Cost: 26661.4219
Epoch: 017/050 | Batch 150/469 | Cost: 27780.3164
Epoch: 017/050 | Batch 200/469 | Cost: 27309.1055
Epoch: 017/050 | Batch 250/469 | Cost: 27510.1016
Epoch: 017/050 | Batch 300/469 | Cost: 27377.1172
Epoch: 017/050 | Batch 350/469 | Cost: 27471.3945
Epoch: 017/050 | Batch 400/469 | Cost: 30324.5586
Epoch: 017/050 | Batch 450/469 | Cost: 73500.2812
Time elapsed: 2.67 min
Epoch: 018/050 | Batch 000/469 | Cost: 42388.4688
Epoch: 018/050 | Batch 050/469 | Cost: 27849.8965
Epoch: 018/050 | Batch 100/469 | Cost: 27593.3828
Epoch: 018/050 | Batch 150/469 | Cost: 27052.1992
Epoch: 018/050 | Batch 200/469 | Cost: 27993.3262
Epoch: 018/050 | Batch 250/469 | Cost: 27613.9766
Epoch: 018/050 | Batch 300/469 | Cost: 26225.6758
Epoch: 018/050 | Batch 350/469 | Cost: 26985.5664
Epoch: 018/050 | Batch 400/469 | Cost: 27585.4297
Epoch: 018/050 | Batch 450/469 | Cost: 27168.5215
Time elapsed: 2.83 min
Epoch: 019/050 | Batch 000/469 | Cost: 27300.5352
Epoch: 019/050 | Batch 050/469 | Cost: 26698.8320
Epoch: 019/050 | Batch 100/469 | Cost: 27803.5508
Epoch: 019/050 | Batch 150/469 | Cost: 27134.0371
Epoch: 019/050 | Batch 200/469 | Cost: 25812.2949
Epoch: 019/050 | Batch 250/469 | Cost: 25544.7246
Epoch: 019/050 | Batch 300/469 | Cost: 25584.1641
Epoch: 019/050 | Batch 350/469 | Cost: 27084.9492
Epoch: 019/050 | Batch 400/469 | Cost: 27120.7324
Epoch: 019/050 | Batch 450/469 | Cost: 27272.8809
Time elapsed: 2.98 min
Epoch: 020/050 | Batch 000/469 | Cost: 27053.4453
Epoch: 020/050 | Batch 050/469 | Cost: 27369.2656
Epoch: 020/050 | Batch 100/469 | Cost: 26794.1035
Epoch: 020/050 | Batch 150/469 | Cost: 26575.0957
Epoch: 020/050 | Batch 200/469 | Cost: 27841.3066
Epoch: 020/050 | Batch 250/469 | Cost: 26915.7656
Epoch: 020/050 | Batch 300/469 | Cost: 25307.2305
Epoch: 020/050 | Batch 350/469 | Cost: 26626.6621
Epoch: 020/050 | Batch 400/469 | Cost: 27154.6895
Epoch: 020/050 | Batch 450/469 | Cost: 27053.9160
Time elapsed: 3.14 min
Epoch: 021/050 | Batch 000/469 | Cost: 27955.6387
Epoch: 021/050 | Batch 050/469 | Cost: 27744.2793
Epoch: 021/050 | Batch 100/469 | Cost: 26278.8652
Epoch: 021/050 | Batch 150/469 | Cost: 26567.6895
Epoch: 021/050 | Batch 200/469 | Cost: 26069.6836
Epoch: 021/050 | Batch 250/469 | Cost: 26567.5332
Epoch: 021/050 | Batch 300/469 | Cost: 26253.0352
Epoch: 021/050 | Batch 350/469 | Cost: 27332.9531
Epoch: 021/050 | Batch 400/469 | Cost: 26668.5371
Epoch: 021/050 | Batch 450/469 | Cost: 25928.0664
Time elapsed: 3.30 min
Epoch: 022/050 | Batch 000/469 | Cost: 26090.3027
Epoch: 022/050 | Batch 050/469 | Cost: 26201.6992
Epoch: 022/050 | Batch 100/469 | Cost: 26070.0645
Epoch: 022/050 | Batch 150/469 | Cost: 27712.2090
Epoch: 022/050 | Batch 200/469 | Cost: 27064.6074
Epoch: 022/050 | Batch 250/469 | Cost: 25590.2988
Epoch: 022/050 | Batch 300/469 | Cost: 26056.4648
Epoch: 022/050 | Batch 350/469 | Cost: 25239.4082
Epoch: 022/050 | Batch 400/469 | Cost: 27711.3926
Epoch: 022/050 | Batch 450/469 | Cost: 26541.3496
Time elapsed: 3.46 min
Epoch: 023/050 | Batch 000/469 | Cost: 26866.8457
Epoch: 023/050 | Batch 050/469 | Cost: 26516.2324
Epoch: 023/050 | Batch 100/469 | Cost: 27288.9688
Epoch: 023/050 | Batch 150/469 | Cost: 26922.4766
Epoch: 023/050 | Batch 200/469 | Cost: 26217.9844
Epoch: 023/050 | Batch 250/469 | Cost: 26235.7891
Epoch: 023/050 | Batch 300/469 | Cost: 26025.4102
Epoch: 023/050 | Batch 350/469 | Cost: 26741.8125
Epoch: 023/050 | Batch 400/469 | Cost: 26891.1172
Epoch: 023/050 | Batch 450/469 | Cost: 25617.7754
Time elapsed: 3.61 min
Epoch: 024/050 | Batch 000/469 | Cost: 26568.8223
Epoch: 024/050 | Batch 050/469 | Cost: 25969.9336
Epoch: 024/050 | Batch 100/469 | Cost: 27559.0918
Epoch: 024/050 | Batch 150/469 | Cost: 27023.5234
Epoch: 024/050 | Batch 200/469 | Cost: 25339.0430
Epoch: 024/050 | Batch 250/469 | Cost: 26641.5664
Epoch: 024/050 | Batch 300/469 | Cost: 26808.2676
Epoch: 024/050 | Batch 350/469 | Cost: 25360.6113
Epoch: 024/050 | Batch 400/469 | Cost: 25561.7832
Epoch: 024/050 | Batch 450/469 | Cost: 25617.8594
Time elapsed: 3.77 min
Epoch: 025/050 | Batch 000/469 | Cost: 26287.0703
Epoch: 025/050 | Batch 050/469 | Cost: 25861.7852
Epoch: 025/050 | Batch 100/469 | Cost: 27369.2188
Epoch: 025/050 | Batch 150/469 | Cost: 26051.3398
Epoch: 025/050 | Batch 200/469 | Cost: 25579.8047
Epoch: 025/050 | Batch 250/469 | Cost: 25831.1133
Epoch: 025/050 | Batch 300/469 | Cost: 25919.6426
Epoch: 025/050 | Batch 350/469 | Cost: 26633.8320
Epoch: 025/050 | Batch 400/469 | Cost: 26041.8867
Epoch: 025/050 | Batch 450/469 | Cost: 24800.4590
Time elapsed: 3.93 min
Epoch: 026/050 | Batch 000/469 | Cost: 25229.7207
Epoch: 026/050 | Batch 050/469 | Cost: 25182.7578
Epoch: 026/050 | Batch 100/469 | Cost: 25985.1465
Epoch: 026/050 | Batch 150/469 | Cost: 25212.9902
Epoch: 026/050 | Batch 200/469 | Cost: 26574.4395
Epoch: 026/050 | Batch 250/469 | Cost: 25608.6055
Epoch: 026/050 | Batch 300/469 | Cost: 103743.6641
Epoch: 026/050 | Batch 350/469 | Cost: 29513.8047
Epoch: 026/050 | Batch 400/469 | Cost: 27109.1582
Epoch: 026/050 | Batch 450/469 | Cost: 26289.1582
Time elapsed: 4.08 min
Epoch: 027/050 | Batch 000/469 | Cost: 27039.6719
Epoch: 027/050 | Batch 050/469 | Cost: 26615.5312
Epoch: 027/050 | Batch 100/469 | Cost: 25575.8086
Epoch: 027/050 | Batch 150/469 | Cost: 27047.0000
Epoch: 027/050 | Batch 200/469 | Cost: 26100.5879
Epoch: 027/050 | Batch 250/469 | Cost: 26004.3633
Epoch: 027/050 | Batch 300/469 | Cost: 27638.7969
Epoch: 027/050 | Batch 350/469 | Cost: 26723.9863
Epoch: 027/050 | Batch 400/469 | Cost: 25743.5000
Epoch: 027/050 | Batch 450/469 | Cost: 25969.0234
Time elapsed: 4.24 min
Epoch: 028/050 | Batch 000/469 | Cost: 24330.8633
Epoch: 028/050 | Batch 050/469 | Cost: 25649.6270
Epoch: 028/050 | Batch 100/469 | Cost: 25380.6094
Epoch: 028/050 | Batch 150/469 | Cost: 25158.7676
Epoch: 028/050 | Batch 200/469 | Cost: 24840.2871
Epoch: 028/050 | Batch 250/469 | Cost: 25271.3105
Epoch: 028/050 | Batch 300/469 | Cost: 24288.1465
Epoch: 028/050 | Batch 350/469 | Cost: 24853.2852
Epoch: 028/050 | Batch 400/469 | Cost: 26212.7070
Epoch: 028/050 | Batch 450/469 | Cost: 26409.6934
Time elapsed: 4.40 min
Epoch: 029/050 | Batch 000/469 | Cost: 26000.0156
Epoch: 029/050 | Batch 050/469 | Cost: 25050.6719
Epoch: 029/050 | Batch 100/469 | Cost: 25016.0645
Epoch: 029/050 | Batch 150/469 | Cost: 25192.9238
Epoch: 029/050 | Batch 200/469 | Cost: 25538.5840
Epoch: 029/050 | Batch 250/469 | Cost: 26724.6504
Epoch: 029/050 | Batch 300/469 | Cost: 26444.7988
Epoch: 029/050 | Batch 350/469 | Cost: 25800.6934
Epoch: 029/050 | Batch 400/469 | Cost: 24840.4258
Epoch: 029/050 | Batch 450/469 | Cost: 26121.0117
Time elapsed: 4.55 min
Epoch: 030/050 | Batch 000/469 | Cost: 25889.4258
Epoch: 030/050 | Batch 050/469 | Cost: 24564.8105
Epoch: 030/050 | Batch 100/469 | Cost: 25197.3223
Epoch: 030/050 | Batch 150/469 | Cost: 25926.4180
Epoch: 030/050 | Batch 200/469 | Cost: 24823.6055
Epoch: 030/050 | Batch 250/469 | Cost: 24570.4961
Epoch: 030/050 | Batch 300/469 | Cost: 26450.1582
Epoch: 030/050 | Batch 350/469 | Cost: 25991.0820
Epoch: 030/050 | Batch 400/469 | Cost: 27038.8398
Epoch: 030/050 | Batch 450/469 | Cost: 25465.6094
Time elapsed: 4.71 min
Epoch: 031/050 | Batch 000/469 | Cost: 25552.3574
Epoch: 031/050 | Batch 050/469 | Cost: 25650.0840
Epoch: 031/050 | Batch 100/469 | Cost: 25189.4062
Epoch: 031/050 | Batch 150/469 | Cost: 26423.2188
Epoch: 031/050 | Batch 200/469 | Cost: 24858.3926
Epoch: 031/050 | Batch 250/469 | Cost: 26807.5215
Epoch: 031/050 | Batch 300/469 | Cost: 26289.6484
Epoch: 031/050 | Batch 350/469 | Cost: 26251.7109
Epoch: 031/050 | Batch 400/469 | Cost: 25341.6426
Epoch: 031/050 | Batch 450/469 | Cost: 25598.0586
Time elapsed: 4.87 min
Epoch: 032/050 | Batch 000/469 | Cost: 25554.5430
Epoch: 032/050 | Batch 050/469 | Cost: 27328.4414
Epoch: 032/050 | Batch 100/469 | Cost: 25416.3203
Epoch: 032/050 | Batch 150/469 | Cost: 26040.9531
Epoch: 032/050 | Batch 200/469 | Cost: 25655.6426
Epoch: 032/050 | Batch 250/469 | Cost: 26179.0469
Epoch: 032/050 | Batch 300/469 | Cost: 25275.5391
Epoch: 032/050 | Batch 350/469 | Cost: 24778.6836
Epoch: 032/050 | Batch 400/469 | Cost: 25070.4062
Epoch: 032/050 | Batch 450/469 | Cost: 25324.0215
Time elapsed: 5.02 min
Epoch: 033/050 | Batch 000/469 | Cost: 24642.3848
Epoch: 033/050 | Batch 050/469 | Cost: 24271.6816
Epoch: 033/050 | Batch 100/469 | Cost: 25492.1836
Epoch: 033/050 | Batch 150/469 | Cost: 25345.7363
Epoch: 033/050 | Batch 200/469 | Cost: 25483.8418
Epoch: 033/050 | Batch 250/469 | Cost: 25023.7363
Epoch: 033/050 | Batch 300/469 | Cost: 24663.6562
Epoch: 033/050 | Batch 350/469 | Cost: 25393.5957
Epoch: 033/050 | Batch 400/469 | Cost: 25995.7422
Epoch: 033/050 | Batch 450/469 | Cost: 25259.6289
Time elapsed: 5.18 min
Epoch: 034/050 | Batch 000/469 | Cost: 22940.1719
Epoch: 034/050 | Batch 050/469 | Cost: 24678.4434
Epoch: 034/050 | Batch 100/469 | Cost: 24350.1992
Epoch: 034/050 | Batch 150/469 | Cost: 24904.4062
Epoch: 034/050 | Batch 200/469 | Cost: 26084.8457
Epoch: 034/050 | Batch 250/469 | Cost: 25387.4746
Epoch: 034/050 | Batch 300/469 | Cost: 25201.3711
Epoch: 034/050 | Batch 350/469 | Cost: 26104.6172
Epoch: 034/050 | Batch 400/469 | Cost: 25046.2305
Epoch: 034/050 | Batch 450/469 | Cost: 24662.4102
Time elapsed: 5.34 min
Epoch: 035/050 | Batch 000/469 | Cost: 24960.4844
Epoch: 035/050 | Batch 050/469 | Cost: 24758.6230
Epoch: 035/050 | Batch 100/469 | Cost: 25567.4766
Epoch: 035/050 | Batch 150/469 | Cost: 26101.2344
Epoch: 035/050 | Batch 200/469 | Cost: 37525.5625
Epoch: 035/050 | Batch 250/469 | Cost: 25834.8398
Epoch: 035/050 | Batch 300/469 | Cost: 25407.0664
Epoch: 035/050 | Batch 350/469 | Cost: 24773.7715
Epoch: 035/050 | Batch 400/469 | Cost: 24848.2070
Epoch: 035/050 | Batch 450/469 | Cost: 25452.8770
Time elapsed: 5.50 min
Epoch: 036/050 | Batch 000/469 | Cost: 25342.3438
Epoch: 036/050 | Batch 050/469 | Cost: 24279.4141
Epoch: 036/050 | Batch 100/469 | Cost: 24467.3691
Epoch: 036/050 | Batch 150/469 | Cost: 24398.2832
Epoch: 036/050 | Batch 200/469 | Cost: 24147.9492
Epoch: 036/050 | Batch 250/469 | Cost: 24729.1992
Epoch: 036/050 | Batch 300/469 | Cost: 24632.2969
Epoch: 036/050 | Batch 350/469 | Cost: 24841.6211
Epoch: 036/050 | Batch 400/469 | Cost: 24582.1094
Epoch: 036/050 | Batch 450/469 | Cost: 24980.0625
Time elapsed: 5.65 min
Epoch: 037/050 | Batch 000/469 | Cost: 25203.0020
Epoch: 037/050 | Batch 050/469 | Cost: 23087.3906
Epoch: 037/050 | Batch 100/469 | Cost: 24115.1836
Epoch: 037/050 | Batch 150/469 | Cost: 24795.7891
Epoch: 037/050 | Batch 200/469 | Cost: 27220.8164
Epoch: 037/050 | Batch 250/469 | Cost: 24840.1055
Epoch: 037/050 | Batch 300/469 | Cost: 25391.5098
Epoch: 037/050 | Batch 350/469 | Cost: 25561.5938
Epoch: 037/050 | Batch 400/469 | Cost: 23791.0605
Epoch: 037/050 | Batch 450/469 | Cost: 24261.7539
Time elapsed: 5.81 min
Epoch: 038/050 | Batch 000/469 | Cost: 23855.5293
Epoch: 038/050 | Batch 050/469 | Cost: 25037.2031
Epoch: 038/050 | Batch 100/469 | Cost: 25081.6836
Epoch: 038/050 | Batch 150/469 | Cost: 24726.2656
Epoch: 038/050 | Batch 200/469 | Cost: 26345.6641
Epoch: 038/050 | Batch 250/469 | Cost: 24811.2539
Epoch: 038/050 | Batch 300/469 | Cost: 24353.3047
Epoch: 038/050 | Batch 350/469 | Cost: 25306.9180
Epoch: 038/050 | Batch 400/469 | Cost: 24490.6641
Epoch: 038/050 | Batch 450/469 | Cost: 25235.3613
Time elapsed: 5.97 min
Epoch: 039/050 | Batch 000/469 | Cost: 24276.7832
Epoch: 039/050 | Batch 050/469 | Cost: 24525.7070
Epoch: 039/050 | Batch 100/469 | Cost: 24906.1289
Epoch: 039/050 | Batch 150/469 | Cost: 24968.1094
Epoch: 039/050 | Batch 200/469 | Cost: 24574.9062
Epoch: 039/050 | Batch 250/469 | Cost: 24858.0703
Epoch: 039/050 | Batch 300/469 | Cost: 25797.6152
Epoch: 039/050 | Batch 350/469 | Cost: 23874.2402
Epoch: 039/050 | Batch 400/469 | Cost: 25120.7891
Epoch: 039/050 | Batch 450/469 | Cost: 23778.7520
Time elapsed: 6.13 min
Epoch: 040/050 | Batch 000/469 | Cost: 24705.1719
Epoch: 040/050 | Batch 050/469 | Cost: 24627.0195
Epoch: 040/050 | Batch 100/469 | Cost: 24295.2754
Epoch: 040/050 | Batch 150/469 | Cost: 24087.8906
Epoch: 040/050 | Batch 200/469 | Cost: 25491.7715
Epoch: 040/050 | Batch 250/469 | Cost: 24501.0703
Epoch: 040/050 | Batch 300/469 | Cost: 26422.9824
Epoch: 040/050 | Batch 350/469 | Cost: 25514.8086
Epoch: 040/050 | Batch 400/469 | Cost: 25690.4043
Epoch: 040/050 | Batch 450/469 | Cost: 24029.9238
Time elapsed: 6.28 min
Epoch: 041/050 | Batch 000/469 | Cost: 24140.9375
Epoch: 041/050 | Batch 050/469 | Cost: 24123.9629
Epoch: 041/050 | Batch 100/469 | Cost: 24918.5645
Epoch: 041/050 | Batch 150/469 | Cost: 24718.9492
Epoch: 041/050 | Batch 200/469 | Cost: 24464.7383
Epoch: 041/050 | Batch 250/469 | Cost: 23528.3867
Epoch: 041/050 | Batch 300/469 | Cost: 24874.0156
Epoch: 041/050 | Batch 350/469 | Cost: 24976.7266
Epoch: 041/050 | Batch 400/469 | Cost: 24297.3750
Epoch: 041/050 | Batch 450/469 | Cost: 24892.3906
Time elapsed: 6.44 min
Epoch: 042/050 | Batch 000/469 | Cost: 23911.4434
Epoch: 042/050 | Batch 050/469 | Cost: 24480.0840
Epoch: 042/050 | Batch 100/469 | Cost: 24132.7617
Epoch: 042/050 | Batch 150/469 | Cost: 26151.5430
Epoch: 042/050 | Batch 200/469 | Cost: 24691.4297
Epoch: 042/050 | Batch 250/469 | Cost: 32332.8027
Epoch: 042/050 | Batch 300/469 | Cost: 26727.7930
Epoch: 042/050 | Batch 350/469 | Cost: 24738.1543
Epoch: 042/050 | Batch 400/469 | Cost: 25356.9297
Epoch: 042/050 | Batch 450/469 | Cost: 25567.9824
Time elapsed: 6.60 min
Epoch: 043/050 | Batch 000/469 | Cost: 24626.9219
Epoch: 043/050 | Batch 050/469 | Cost: 24583.6211
Epoch: 043/050 | Batch 100/469 | Cost: 23771.1250
Epoch: 043/050 | Batch 150/469 | Cost: 23740.8223
Epoch: 043/050 | Batch 200/469 | Cost: 25974.3535
Epoch: 043/050 | Batch 250/469 | Cost: 24506.7715
Epoch: 043/050 | Batch 300/469 | Cost: 24850.9629
Epoch: 043/050 | Batch 350/469 | Cost: 23420.8887
Epoch: 043/050 | Batch 400/469 | Cost: 23890.6582
Epoch: 043/050 | Batch 450/469 | Cost: 24406.9375
Time elapsed: 6.76 min
Epoch: 044/050 | Batch 000/469 | Cost: 24515.1172
Epoch: 044/050 | Batch 050/469 | Cost: 24865.0742
Epoch: 044/050 | Batch 100/469 | Cost: 24439.4609
Epoch: 044/050 | Batch 150/469 | Cost: 24490.3047
Epoch: 044/050 | Batch 200/469 | Cost: 23753.9219
Epoch: 044/050 | Batch 250/469 | Cost: 23811.8926
Epoch: 044/050 | Batch 300/469 | Cost: 24070.1172
Epoch: 044/050 | Batch 350/469 | Cost: 24404.0664
Epoch: 044/050 | Batch 400/469 | Cost: 25219.6699
Epoch: 044/050 | Batch 450/469 | Cost: 23585.7500
Time elapsed: 6.91 min
Epoch: 045/050 | Batch 000/469 | Cost: 23822.3262
Epoch: 045/050 | Batch 050/469 | Cost: 23653.2695
Epoch: 045/050 | Batch 100/469 | Cost: 25814.4062
Epoch: 045/050 | Batch 150/469 | Cost: 23872.3867
Epoch: 045/050 | Batch 200/469 | Cost: 25231.3008
Epoch: 045/050 | Batch 250/469 | Cost: 24211.3652
Epoch: 045/050 | Batch 300/469 | Cost: 24564.8242
Epoch: 045/050 | Batch 350/469 | Cost: 23450.6211
Epoch: 045/050 | Batch 400/469 | Cost: 24501.6504
Epoch: 045/050 | Batch 450/469 | Cost: 26215.8633
Time elapsed: 7.07 min
Epoch: 046/050 | Batch 000/469 | Cost: 24400.6562
Epoch: 046/050 | Batch 050/469 | Cost: 24448.3691
Epoch: 046/050 | Batch 100/469 | Cost: 24466.0859
Epoch: 046/050 | Batch 150/469 | Cost: 24153.8711
Epoch: 046/050 | Batch 200/469 | Cost: 24351.0098
Epoch: 046/050 | Batch 250/469 | Cost: 23123.2500
Epoch: 046/050 | Batch 300/469 | Cost: 24734.2773
Epoch: 046/050 | Batch 350/469 | Cost: 23785.1875
Epoch: 046/050 | Batch 400/469 | Cost: 24901.5039
Epoch: 046/050 | Batch 450/469 | Cost: 23700.1133
Time elapsed: 7.22 min
Epoch: 047/050 | Batch 000/469 | Cost: 25294.2520
Epoch: 047/050 | Batch 050/469 | Cost: 24074.6992
Epoch: 047/050 | Batch 100/469 | Cost: 24112.8848
Epoch: 047/050 | Batch 150/469 | Cost: 24861.8926
Epoch: 047/050 | Batch 200/469 | Cost: 22852.8594
Epoch: 047/050 | Batch 250/469 | Cost: 23799.0703
Epoch: 047/050 | Batch 300/469 | Cost: 23758.0039
Epoch: 047/050 | Batch 350/469 | Cost: 23628.5391
Epoch: 047/050 | Batch 400/469 | Cost: 23933.1504
Epoch: 047/050 | Batch 450/469 | Cost: 22900.7715
Time elapsed: 7.38 min
Epoch: 048/050 | Batch 000/469 | Cost: 23949.8223
Epoch: 048/050 | Batch 050/469 | Cost: 24267.9609
Epoch: 048/050 | Batch 100/469 | Cost: 22838.5234
Epoch: 048/050 | Batch 150/469 | Cost: 24212.3223
Epoch: 048/050 | Batch 200/469 | Cost: 23809.5449
Epoch: 048/050 | Batch 250/469 | Cost: 23827.1680
Epoch: 048/050 | Batch 300/469 | Cost: 25127.4844
Epoch: 048/050 | Batch 350/469 | Cost: 23184.9473
Epoch: 048/050 | Batch 400/469 | Cost: 24065.0840
Epoch: 048/050 | Batch 450/469 | Cost: 23201.5645
Time elapsed: 7.53 min
Epoch: 049/050 | Batch 000/469 | Cost: 23682.0781
Epoch: 049/050 | Batch 050/469 | Cost: 23740.3887
Epoch: 049/050 | Batch 100/469 | Cost: 23290.7441
Epoch: 049/050 | Batch 150/469 | Cost: 23001.3262
Epoch: 049/050 | Batch 200/469 | Cost: 23265.8105
Epoch: 049/050 | Batch 250/469 | Cost: 22163.1328
Epoch: 049/050 | Batch 300/469 | Cost: 24283.0508
Epoch: 049/050 | Batch 350/469 | Cost: 23822.0098
Epoch: 049/050 | Batch 400/469 | Cost: 22784.8594
Epoch: 049/050 | Batch 450/469 | Cost: 24202.4961
Time elapsed: 7.69 min
Epoch: 050/050 | Batch 000/469 | Cost: 23966.5840
Epoch: 050/050 | Batch 050/469 | Cost: 24665.5449
Epoch: 050/050 | Batch 100/469 | Cost: 23895.6406
Epoch: 050/050 | Batch 150/469 | Cost: 24318.3926
Epoch: 050/050 | Batch 200/469 | Cost: 23685.9727
Epoch: 050/050 | Batch 250/469 | Cost: 23648.9336
Epoch: 050/050 | Batch 300/469 | Cost: 23634.2500
Epoch: 050/050 | Batch 350/469 | Cost: 27888.9062
Epoch: 050/050 | Batch 400/469 | Cost: 23649.8242
Epoch: 050/050 | Batch 450/469 | Cost: 22728.7891
Time elapsed: 7.85 min
Total Training Time: 7.85 min

Evaluation

Reconstruction

In [6]:
%matplotlib inline
import matplotlib.pyplot as plt

##########################
### VISUALIZATION
##########################

n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images, 
                         sharex=True, sharey=True, figsize=(20, 2.5))
orig_images = features[:n_images]
decoded_images = decoded[:n_images, 0]

for i in range(n_images):
    for ax, img in zip(axes, [orig_images, decoded_images]):
        ax[i].imshow(img[i].detach().to(torch.device('cpu')).reshape((image_width, image_width)), cmap='binary')

New random-conditional images

In [7]:
for i in range(10):

    ##########################
    ### RANDOM SAMPLE
    ##########################    
    
    labels = torch.tensor([i]*10).to(device)
    n_images = labels.size()[0]
    rand_features = torch.randn(n_images, num_latent).to(device)
    new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    image_width = 28

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    decoded_images = new_images[:n_images, 0]

    print('Class Label %d' % i)

    for ax, img in zip(axes, decoded_images):
        ax.imshow(img.detach().to(torch.device('cpu')).reshape((image_width, image_width)), cmap='binary')
        
    plt.show()
Class Label 0
Class Label 1
Class Label 2
Class Label 3
Class Label 4
Class Label 5
Class Label 6
Class Label 7
Class Label 8
Class Label 9
In [8]:
%watermark -iv
numpy       1.15.4
torch       1.0.0