Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.6.8 IPython 7.2.0 torch 1.0.0
A simple conditional variational autoencoder that compresses 784-pixel (28×28) MNIST images down to a 35-dimensional latent vector representation.
This implementation concatenates the inputs with the class labels when computing the reconstruction loss, as is commonly done in non-convolutional conditional variational autoencoders. This leads to slightly better results compared to the implementation that does NOT concatenate the labels with the inputs to compute the reconstruction loss. For reference, see the implementation ./autoencoder-cvae_no-out-concat.ipynb
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels so GPU runs are reproducible.
    torch.backends.cudnn.deterministic = True

##########################
### SETTINGS
##########################

# Device
# The reference run used the second GPU ("cuda:1"). Hard-coding that index
# makes the first .to(device) call fail with "invalid device ordinal" on a
# machine with a single GPU, so fall back to "cuda:0" in that case.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)

# Hyperparameters
random_seed = 0
learning_rate = 0.001
num_epochs = 50
batch_size = 128

# Architecture
num_classes = 10
num_features = 784  # 28 x 28 MNIST pixels, flattened
num_hidden_1 = 500
num_latent = 35
##########################
### MNIST DATASET
##########################

# transforms.ToTensor() turns each PIL image into a float tensor whose
# values lie in the 0-1 range.
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: fetch a single mini-batch and report its shapes.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
Device: cuda:1 Image batch dimensions: torch.Size([128, 1, 28, 28]) Image label dimensions: torch.Size([128])
##########################
### MODEL
##########################

def to_onehot(labels, num_classes, device):
    """Return a float one-hot encoding of a vector of integer labels.

    Args:
        labels: tensor of integer class indices with N entries (presumably
            1-D with values in [0, num_classes) -- TODO confirm at callers).
        num_classes: number of columns of the encoding.
        device: device the returned tensor is moved to.

    Returns:
        A (N, num_classes) float tensor holding a single 1 in each row.
    """
    n_rows = labels.size(0)
    encoding = torch.zeros(n_rows, num_classes).to(device)
    # scatter_ needs a column of indices: write 1 at column labels[r] of row r
    column_index = labels.view(-1, 1)
    encoding.scatter_(1, column_index, 1)
    return encoding
class ConditionalVariationalAutoencoder(torch.nn.Module):
    """Fully-connected conditional variational autoencoder.

    The one-hot class label is concatenated to the input before encoding and
    to the latent code before decoding. The decoder reconstructs both the
    input features and the one-hot label, so its output has
    ``num_features + num_classes`` columns.

    Args:
        num_features: number of input features per example (784 for
            flattened 28x28 MNIST images in this notebook).
        num_hidden_1: width of the hidden layer in both encoder and decoder.
        num_latent: dimensionality of the latent code z.
        num_classes: number of class labels used as the condition.
    """

    def __init__(self, num_features, num_hidden_1, num_latent, num_classes):
        super().__init__()

        self.num_classes = num_classes

        ### ENCODER
        self.hidden_1 = torch.nn.Linear(num_features + num_classes, num_hidden_1)
        self.z_mean = torch.nn.Linear(num_hidden_1, num_latent)
        # Kingma & Welling ("Auto-Encoding Variational Bayes", 2013) describe
        # the approximate posterior by a mean and a variance. A linear layer
        # can output negative values, which is invalid for a variance and
        # breaks the log in the KL term, so this layer predicts log(var)
        # instead; exp() recovers the variance / std. dev. when needed.
        self.z_log_var = torch.nn.Linear(num_hidden_1, num_latent)

        ### DECODER
        self.linear_3 = torch.nn.Linear(num_latent + num_classes, num_hidden_1)
        self.linear_4 = torch.nn.Linear(num_hidden_1, num_features + num_classes)

    def reparameterize(self, z_mu, z_log_var):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick).

        eps is drawn with the same shape, dtype and device as ``z_mu``, so the
        model no longer depends on a module-level ``device`` variable.
        """
        eps = torch.randn_like(z_mu)
        # log(var) / 2 = log(std), hence exp(z_log_var / 2) is the std. dev.
        z = z_mu + eps * torch.exp(z_log_var / 2.)
        return z

    def encoder(self, features, targets):
        """Encode (features, label) pairs.

        Args:
            features: (N, num_features) tensor.
            targets: N integer class labels on the same device as ``features``.

        Returns:
            Tuple ``(z_mean, z_log_var, encoded)``, each of shape (N, num_latent);
            ``encoded`` is a reparameterized sample.
        """
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, features.device)
        x = torch.cat((features, onehot_targets), dim=1)

        ### ENCODER
        x = self.hidden_1(x)
        x = F.leaky_relu(x)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        encoded = self.reparameterize(z_mean, z_log_var)
        return z_mean, z_log_var, encoded

    def decoder(self, encoded, targets):
        """Decode latent codes conditioned on class labels.

        Args:
            encoded: (N, num_latent) latent codes.
            targets: N integer class labels on the same device as ``encoded``.

        Returns:
            (N, num_features + num_classes) tensor of sigmoid outputs in (0, 1):
            reconstructed features followed by the reconstructed one-hot label.
        """
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, encoded.device)
        encoded = torch.cat((encoded, onehot_targets), dim=1)

        ### DECODER
        x = self.linear_3(encoded)
        x = F.leaky_relu(x)
        x = self.linear_4(x)
        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        """Run encoder and decoder; return (z_mean, z_log_var, encoded, decoded)."""
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)
        return z_mean, z_log_var, encoded, decoded
# Seed the global RNG right before construction so the initial weights
# are reproducible across runs.
torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(
    num_features, num_hidden_1, num_latent, num_classes
).to(device)

##########################
### COST AND OPTIMIZER
##########################

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
start_time = time.time()
for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):

        # Flatten each 28x28 image into a 784-dim vector
        features = features.view(-1, 28*28).to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
        # Closed-form KL(N(mu, var) || N(0, I)), summed over latent dims and batch
        kl_divergence = (0.5 * (z_mean**2 +
                                torch.exp(z_log_var) - z_log_var - 1)).sum()

        # add condition: the decoder reconstructs [pixels | one-hot label],
        # so the BCE target must have the same layout
        x_con = torch.cat((features, to_onehot(targets, num_classes, device)), dim=1)
        pixelwise_bce = F.binary_cross_entropy(decoded, x_con, reduction='sum')
        cost = kl_divergence + pixelwise_bce

        optimizer.zero_grad()
        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            # len(train_loader) already counts mini-batches per epoch;
            # dividing it by batch_size again printed "/003" instead of "/469".
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost.item()))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/003 | Cost: 71038.5859 Epoch: 001/050 | Batch 050/003 | Cost: 27687.3066 Epoch: 001/050 | Batch 100/003 | Cost: 22471.4355 Epoch: 001/050 | Batch 150/003 | Cost: 20541.7598 Epoch: 001/050 | Batch 200/003 | Cost: 19670.0254 Epoch: 001/050 | Batch 250/003 | Cost: 18063.1172 Epoch: 001/050 | Batch 300/003 | Cost: 18197.8105 Epoch: 001/050 | Batch 350/003 | Cost: 17140.0234 Epoch: 001/050 | Batch 400/003 | Cost: 17088.2832 Epoch: 001/050 | Batch 450/003 | Cost: 16489.5078 Time elapsed: 0.12 min Epoch: 002/050 | Batch 000/003 | Cost: 16543.6680 Epoch: 002/050 | Batch 050/003 | Cost: 15903.9375 Epoch: 002/050 | Batch 100/003 | Cost: 15624.4824 Epoch: 002/050 | Batch 150/003 | Cost: 15635.6016 Epoch: 002/050 | Batch 200/003 | Cost: 15245.4121 Epoch: 002/050 | Batch 250/003 | Cost: 15592.2266 Epoch: 002/050 | Batch 300/003 | Cost: 14563.4980 Epoch: 002/050 | Batch 350/003 | Cost: 15090.5811 Epoch: 002/050 | Batch 400/003 | Cost: 14832.8682 Epoch: 002/050 | Batch 450/003 | Cost: 15250.2480 Time elapsed: 0.24 min Epoch: 003/050 | Batch 000/003 | Cost: 14894.0859 Epoch: 003/050 | Batch 050/003 | Cost: 14885.7168 Epoch: 003/050 | Batch 100/003 | Cost: 14998.2500 Epoch: 003/050 | Batch 150/003 | Cost: 14456.1309 Epoch: 003/050 | Batch 200/003 | Cost: 14737.8740 Epoch: 003/050 | Batch 250/003 | Cost: 14265.4639 Epoch: 003/050 | Batch 300/003 | Cost: 13826.1016 Epoch: 003/050 | Batch 350/003 | Cost: 14193.5059 Epoch: 003/050 | Batch 400/003 | Cost: 14308.9688 Epoch: 003/050 | Batch 450/003 | Cost: 13530.7305 Time elapsed: 0.37 min Epoch: 004/050 | Batch 000/003 | Cost: 14149.7422 Epoch: 004/050 | Batch 050/003 | Cost: 14044.5645 Epoch: 004/050 | Batch 100/003 | Cost: 13909.8926 Epoch: 004/050 | Batch 150/003 | Cost: 14146.3789 Epoch: 004/050 | Batch 200/003 | Cost: 13917.0820 Epoch: 004/050 | Batch 250/003 | Cost: 13873.8262 Epoch: 004/050 | Batch 300/003 | Cost: 13358.3555 Epoch: 004/050 | Batch 350/003 | Cost: 13450.4238 Epoch: 004/050 | Batch 400/003 
| Cost: 13677.4062 Epoch: 004/050 | Batch 450/003 | Cost: 14247.6426 Time elapsed: 0.49 min Epoch: 005/050 | Batch 000/003 | Cost: 13364.3730 Epoch: 005/050 | Batch 050/003 | Cost: 13535.6279 Epoch: 005/050 | Batch 100/003 | Cost: 13827.6201 Epoch: 005/050 | Batch 150/003 | Cost: 13421.4111 Epoch: 005/050 | Batch 200/003 | Cost: 13467.9238 Epoch: 005/050 | Batch 250/003 | Cost: 13812.4131 Epoch: 005/050 | Batch 300/003 | Cost: 13457.0234 Epoch: 005/050 | Batch 350/003 | Cost: 14082.8926 Epoch: 005/050 | Batch 400/003 | Cost: 13224.9209 Epoch: 005/050 | Batch 450/003 | Cost: 13151.8311 Time elapsed: 0.61 min Epoch: 006/050 | Batch 000/003 | Cost: 13850.8398 Epoch: 006/050 | Batch 050/003 | Cost: 13386.6289 Epoch: 006/050 | Batch 100/003 | Cost: 13522.1650 Epoch: 006/050 | Batch 150/003 | Cost: 12865.0898 Epoch: 006/050 | Batch 200/003 | Cost: 13949.3652 Epoch: 006/050 | Batch 250/003 | Cost: 12964.2607 Epoch: 006/050 | Batch 300/003 | Cost: 13692.9707 Epoch: 006/050 | Batch 350/003 | Cost: 13319.4219 Epoch: 006/050 | Batch 400/003 | Cost: 13278.6582 Epoch: 006/050 | Batch 450/003 | Cost: 13224.0107 Time elapsed: 0.73 min Epoch: 007/050 | Batch 000/003 | Cost: 13686.9590 Epoch: 007/050 | Batch 050/003 | Cost: 13188.3828 Epoch: 007/050 | Batch 100/003 | Cost: 13173.3457 Epoch: 007/050 | Batch 150/003 | Cost: 13309.3906 Epoch: 007/050 | Batch 200/003 | Cost: 13174.3359 Epoch: 007/050 | Batch 250/003 | Cost: 13169.7100 Epoch: 007/050 | Batch 300/003 | Cost: 13255.1074 Epoch: 007/050 | Batch 350/003 | Cost: 13153.2480 Epoch: 007/050 | Batch 400/003 | Cost: 13683.6426 Epoch: 007/050 | Batch 450/003 | Cost: 13518.1055 Time elapsed: 0.85 min Epoch: 008/050 | Batch 000/003 | Cost: 13531.5840 Epoch: 008/050 | Batch 050/003 | Cost: 12898.9746 Epoch: 008/050 | Batch 100/003 | Cost: 13434.8193 Epoch: 008/050 | Batch 150/003 | Cost: 13045.6172 Epoch: 008/050 | Batch 200/003 | Cost: 13175.8867 Epoch: 008/050 | Batch 250/003 | Cost: 12952.2139 Epoch: 008/050 | Batch 300/003 | Cost: 
13100.8184 Epoch: 008/050 | Batch 350/003 | Cost: 13419.0410 Epoch: 008/050 | Batch 400/003 | Cost: 13768.5234 Epoch: 008/050 | Batch 450/003 | Cost: 13300.7402 Time elapsed: 0.97 min Epoch: 009/050 | Batch 000/003 | Cost: 13035.2109 Epoch: 009/050 | Batch 050/003 | Cost: 13260.9902 Epoch: 009/050 | Batch 100/003 | Cost: 13446.3652 Epoch: 009/050 | Batch 150/003 | Cost: 13241.1523 Epoch: 009/050 | Batch 200/003 | Cost: 13139.6904 Epoch: 009/050 | Batch 250/003 | Cost: 13418.0098 Epoch: 009/050 | Batch 300/003 | Cost: 12743.2139 Epoch: 009/050 | Batch 350/003 | Cost: 13347.6465 Epoch: 009/050 | Batch 400/003 | Cost: 13353.1543 Epoch: 009/050 | Batch 450/003 | Cost: 13161.7549 Time elapsed: 1.09 min Epoch: 010/050 | Batch 000/003 | Cost: 13043.7656 Epoch: 010/050 | Batch 050/003 | Cost: 13368.1816 Epoch: 010/050 | Batch 100/003 | Cost: 12742.3281 Epoch: 010/050 | Batch 150/003 | Cost: 13022.2832 Epoch: 010/050 | Batch 200/003 | Cost: 13076.5967 Epoch: 010/050 | Batch 250/003 | Cost: 12765.7480 Epoch: 010/050 | Batch 300/003 | Cost: 12605.2461 Epoch: 010/050 | Batch 350/003 | Cost: 12561.1895 Epoch: 010/050 | Batch 400/003 | Cost: 13314.3623 Epoch: 010/050 | Batch 450/003 | Cost: 13048.6670 Time elapsed: 1.21 min Epoch: 011/050 | Batch 000/003 | Cost: 12769.8984 Epoch: 011/050 | Batch 050/003 | Cost: 13002.4824 Epoch: 011/050 | Batch 100/003 | Cost: 12989.9756 Epoch: 011/050 | Batch 150/003 | Cost: 12345.3877 Epoch: 011/050 | Batch 200/003 | Cost: 12599.8975 Epoch: 011/050 | Batch 250/003 | Cost: 13375.6543 Epoch: 011/050 | Batch 300/003 | Cost: 12829.1621 Epoch: 011/050 | Batch 350/003 | Cost: 12245.9316 Epoch: 011/050 | Batch 400/003 | Cost: 12745.7324 Epoch: 011/050 | Batch 450/003 | Cost: 12833.6787 Time elapsed: 1.33 min Epoch: 012/050 | Batch 000/003 | Cost: 13168.7441 Epoch: 012/050 | Batch 050/003 | Cost: 12641.6289 Epoch: 012/050 | Batch 100/003 | Cost: 12435.3809 Epoch: 012/050 | Batch 150/003 | Cost: 12635.2422 Epoch: 012/050 | Batch 200/003 | Cost: 
12657.0459 Epoch: 012/050 | Batch 250/003 | Cost: 12965.5098 Epoch: 012/050 | Batch 300/003 | Cost: 13022.1289 Epoch: 012/050 | Batch 350/003 | Cost: 13023.3535 Epoch: 012/050 | Batch 400/003 | Cost: 12723.5039 Epoch: 012/050 | Batch 450/003 | Cost: 12593.6680 Time elapsed: 1.45 min Epoch: 013/050 | Batch 000/003 | Cost: 13062.8193 Epoch: 013/050 | Batch 050/003 | Cost: 12942.4668 Epoch: 013/050 | Batch 100/003 | Cost: 13116.7461 Epoch: 013/050 | Batch 150/003 | Cost: 12458.0615 Epoch: 013/050 | Batch 200/003 | Cost: 12586.8535 Epoch: 013/050 | Batch 250/003 | Cost: 12546.0703 Epoch: 013/050 | Batch 300/003 | Cost: 12817.3428 Epoch: 013/050 | Batch 350/003 | Cost: 13081.0586 Epoch: 013/050 | Batch 400/003 | Cost: 13006.0645 Epoch: 013/050 | Batch 450/003 | Cost: 12553.9902 Time elapsed: 1.57 min Epoch: 014/050 | Batch 000/003 | Cost: 12912.1406 Epoch: 014/050 | Batch 050/003 | Cost: 12546.7812 Epoch: 014/050 | Batch 100/003 | Cost: 12928.7334 Epoch: 014/050 | Batch 150/003 | Cost: 12532.9268 Epoch: 014/050 | Batch 200/003 | Cost: 12385.5430 Epoch: 014/050 | Batch 250/003 | Cost: 12737.9053 Epoch: 014/050 | Batch 300/003 | Cost: 12179.8184 Epoch: 014/050 | Batch 350/003 | Cost: 12855.8301 Epoch: 014/050 | Batch 400/003 | Cost: 12984.9512 Epoch: 014/050 | Batch 450/003 | Cost: 12996.2832 Time elapsed: 1.69 min Epoch: 015/050 | Batch 000/003 | Cost: 12615.2100 Epoch: 015/050 | Batch 050/003 | Cost: 13215.7402 Epoch: 015/050 | Batch 100/003 | Cost: 12532.8447 Epoch: 015/050 | Batch 150/003 | Cost: 12863.5928 Epoch: 015/050 | Batch 200/003 | Cost: 12683.9209 Epoch: 015/050 | Batch 250/003 | Cost: 12564.1309 Epoch: 015/050 | Batch 300/003 | Cost: 12535.2500 Epoch: 015/050 | Batch 350/003 | Cost: 12853.1445 Epoch: 015/050 | Batch 400/003 | Cost: 12600.9902 Epoch: 015/050 | Batch 450/003 | Cost: 12867.0781 Time elapsed: 1.81 min Epoch: 016/050 | Batch 000/003 | Cost: 12429.7402 Epoch: 016/050 | Batch 050/003 | Cost: 13151.1377 Epoch: 016/050 | Batch 100/003 | Cost: 
12885.5371 Epoch: 016/050 | Batch 150/003 | Cost: 12601.3242 Epoch: 016/050 | Batch 200/003 | Cost: 13196.4834 Epoch: 016/050 | Batch 250/003 | Cost: 12570.6836 Epoch: 016/050 | Batch 300/003 | Cost: 12942.7861 Epoch: 016/050 | Batch 350/003 | Cost: 12389.7363 Epoch: 016/050 | Batch 400/003 | Cost: 12576.1445 Epoch: 016/050 | Batch 450/003 | Cost: 12140.7900 Time elapsed: 1.93 min Epoch: 017/050 | Batch 000/003 | Cost: 12978.2227 Epoch: 017/050 | Batch 050/003 | Cost: 12447.1230 Epoch: 017/050 | Batch 100/003 | Cost: 12980.0459 Epoch: 017/050 | Batch 150/003 | Cost: 12901.1045 Epoch: 017/050 | Batch 200/003 | Cost: 12381.2070 Epoch: 017/050 | Batch 250/003 | Cost: 12956.8857 Epoch: 017/050 | Batch 300/003 | Cost: 12341.9512 Epoch: 017/050 | Batch 350/003 | Cost: 12692.6270 Epoch: 017/050 | Batch 400/003 | Cost: 12316.9727 Epoch: 017/050 | Batch 450/003 | Cost: 12857.4844 Time elapsed: 2.05 min Epoch: 018/050 | Batch 000/003 | Cost: 12531.7217 Epoch: 018/050 | Batch 050/003 | Cost: 12500.7012 Epoch: 018/050 | Batch 100/003 | Cost: 12458.2969 Epoch: 018/050 | Batch 150/003 | Cost: 12838.6758 Epoch: 018/050 | Batch 200/003 | Cost: 12678.9072 Epoch: 018/050 | Batch 250/003 | Cost: 12199.8320 Epoch: 018/050 | Batch 300/003 | Cost: 12352.8457 Epoch: 018/050 | Batch 350/003 | Cost: 12980.6797 Epoch: 018/050 | Batch 400/003 | Cost: 11996.0254 Epoch: 018/050 | Batch 450/003 | Cost: 12993.7158 Time elapsed: 2.17 min Epoch: 019/050 | Batch 000/003 | Cost: 12559.5215 Epoch: 019/050 | Batch 050/003 | Cost: 12757.8662 Epoch: 019/050 | Batch 100/003 | Cost: 13118.6836 Epoch: 019/050 | Batch 150/003 | Cost: 13059.1777 Epoch: 019/050 | Batch 200/003 | Cost: 13048.2031 Epoch: 019/050 | Batch 250/003 | Cost: 12460.9688 Epoch: 019/050 | Batch 300/003 | Cost: 12757.1748 Epoch: 019/050 | Batch 350/003 | Cost: 12035.1006 Epoch: 019/050 | Batch 400/003 | Cost: 12802.9883 Epoch: 019/050 | Batch 450/003 | Cost: 12858.0488 Time elapsed: 2.29 min Epoch: 020/050 | Batch 000/003 | Cost: 
12934.9336 Epoch: 020/050 | Batch 050/003 | Cost: 12974.0488 Epoch: 020/050 | Batch 100/003 | Cost: 12512.7949 Epoch: 020/050 | Batch 150/003 | Cost: 12980.7275 Epoch: 020/050 | Batch 200/003 | Cost: 12930.8789 Epoch: 020/050 | Batch 250/003 | Cost: 12721.5723 Epoch: 020/050 | Batch 300/003 | Cost: 12933.5918 Epoch: 020/050 | Batch 350/003 | Cost: 12456.5488 Epoch: 020/050 | Batch 400/003 | Cost: 12849.0654 Epoch: 020/050 | Batch 450/003 | Cost: 12522.7480 Time elapsed: 2.41 min Epoch: 021/050 | Batch 000/003 | Cost: 12122.0742 Epoch: 021/050 | Batch 050/003 | Cost: 13007.4697 Epoch: 021/050 | Batch 100/003 | Cost: 12690.0635 Epoch: 021/050 | Batch 150/003 | Cost: 12353.1621 Epoch: 021/050 | Batch 200/003 | Cost: 13007.6445 Epoch: 021/050 | Batch 250/003 | Cost: 12202.5176 Epoch: 021/050 | Batch 300/003 | Cost: 12154.6660 Epoch: 021/050 | Batch 350/003 | Cost: 12723.7158 Epoch: 021/050 | Batch 400/003 | Cost: 12902.6582 Epoch: 021/050 | Batch 450/003 | Cost: 12786.1484 Time elapsed: 2.53 min Epoch: 022/050 | Batch 000/003 | Cost: 12648.4844 Epoch: 022/050 | Batch 050/003 | Cost: 12816.2891 Epoch: 022/050 | Batch 100/003 | Cost: 12544.2412 Epoch: 022/050 | Batch 150/003 | Cost: 12651.0215 Epoch: 022/050 | Batch 200/003 | Cost: 12831.6562 Epoch: 022/050 | Batch 250/003 | Cost: 12590.9326 Epoch: 022/050 | Batch 300/003 | Cost: 12396.2373 Epoch: 022/050 | Batch 350/003 | Cost: 12948.6094 Epoch: 022/050 | Batch 400/003 | Cost: 12553.6816 Epoch: 022/050 | Batch 450/003 | Cost: 12309.2637 Time elapsed: 2.65 min Epoch: 023/050 | Batch 000/003 | Cost: 12857.9453 Epoch: 023/050 | Batch 050/003 | Cost: 12910.1377 Epoch: 023/050 | Batch 100/003 | Cost: 12449.3242 Epoch: 023/050 | Batch 150/003 | Cost: 12278.5312 Epoch: 023/050 | Batch 200/003 | Cost: 12971.6885 Epoch: 023/050 | Batch 250/003 | Cost: 13084.1699 Epoch: 023/050 | Batch 300/003 | Cost: 12463.8232 Epoch: 023/050 | Batch 350/003 | Cost: 12589.3398 Epoch: 023/050 | Batch 400/003 | Cost: 12732.2168 Epoch: 023/050 | 
Batch 450/003 | Cost: 12196.4492 Time elapsed: 2.76 min Epoch: 024/050 | Batch 000/003 | Cost: 12342.8789 Epoch: 024/050 | Batch 050/003 | Cost: 12255.4883 Epoch: 024/050 | Batch 100/003 | Cost: 12158.4902 Epoch: 024/050 | Batch 150/003 | Cost: 12731.5742 Epoch: 024/050 | Batch 200/003 | Cost: 12789.7168 Epoch: 024/050 | Batch 250/003 | Cost: 12213.1104 Epoch: 024/050 | Batch 300/003 | Cost: 12613.8281 Epoch: 024/050 | Batch 350/003 | Cost: 12530.3096 Epoch: 024/050 | Batch 400/003 | Cost: 12475.6035 Epoch: 024/050 | Batch 450/003 | Cost: 12182.7178 Time elapsed: 2.88 min Epoch: 025/050 | Batch 000/003 | Cost: 12929.5430 Epoch: 025/050 | Batch 050/003 | Cost: 12472.9180 Epoch: 025/050 | Batch 100/003 | Cost: 11870.2754 Epoch: 025/050 | Batch 150/003 | Cost: 12619.1641 Epoch: 025/050 | Batch 200/003 | Cost: 12285.2344 Epoch: 025/050 | Batch 250/003 | Cost: 12557.6367 Epoch: 025/050 | Batch 300/003 | Cost: 12409.3574 Epoch: 025/050 | Batch 350/003 | Cost: 12889.7910 Epoch: 025/050 | Batch 400/003 | Cost: 12708.5605 Epoch: 025/050 | Batch 450/003 | Cost: 12577.1514 Time elapsed: 3.00 min Epoch: 026/050 | Batch 000/003 | Cost: 12301.7139 Epoch: 026/050 | Batch 050/003 | Cost: 12692.7188 Epoch: 026/050 | Batch 100/003 | Cost: 12601.7607 Epoch: 026/050 | Batch 150/003 | Cost: 12460.5254 Epoch: 026/050 | Batch 200/003 | Cost: 12769.4287 Epoch: 026/050 | Batch 250/003 | Cost: 12428.0000 Epoch: 026/050 | Batch 300/003 | Cost: 12987.5449 Epoch: 026/050 | Batch 350/003 | Cost: 12646.5908 Epoch: 026/050 | Batch 400/003 | Cost: 12335.6738 Epoch: 026/050 | Batch 450/003 | Cost: 12613.5449 Time elapsed: 3.12 min Epoch: 027/050 | Batch 000/003 | Cost: 12647.8809 Epoch: 027/050 | Batch 050/003 | Cost: 12970.0127 Epoch: 027/050 | Batch 100/003 | Cost: 12870.9219 Epoch: 027/050 | Batch 150/003 | Cost: 12435.1553 Epoch: 027/050 | Batch 200/003 | Cost: 12810.8418 Epoch: 027/050 | Batch 250/003 | Cost: 12727.6777 Epoch: 027/050 | Batch 300/003 | Cost: 12762.6055 Epoch: 027/050 | Batch 
350/003 | Cost: 12970.4414 Epoch: 027/050 | Batch 400/003 | Cost: 12745.8652 Epoch: 027/050 | Batch 450/003 | Cost: 12442.3232 Time elapsed: 3.24 min Epoch: 028/050 | Batch 000/003 | Cost: 12199.5078 Epoch: 028/050 | Batch 050/003 | Cost: 12707.5625 Epoch: 028/050 | Batch 100/003 | Cost: 12289.9277 Epoch: 028/050 | Batch 150/003 | Cost: 12375.3242 Epoch: 028/050 | Batch 200/003 | Cost: 12023.3887 Epoch: 028/050 | Batch 250/003 | Cost: 12776.7168 Epoch: 028/050 | Batch 300/003 | Cost: 12680.4668 Epoch: 028/050 | Batch 350/003 | Cost: 12701.3281 Epoch: 028/050 | Batch 400/003 | Cost: 12561.7227 Epoch: 028/050 | Batch 450/003 | Cost: 12763.8447 Time elapsed: 3.36 min Epoch: 029/050 | Batch 000/003 | Cost: 12721.6777 Epoch: 029/050 | Batch 050/003 | Cost: 12443.0645 Epoch: 029/050 | Batch 100/003 | Cost: 12057.7822 Epoch: 029/050 | Batch 150/003 | Cost: 12504.2529 Epoch: 029/050 | Batch 200/003 | Cost: 12310.3965 Epoch: 029/050 | Batch 250/003 | Cost: 13202.6211 Epoch: 029/050 | Batch 300/003 | Cost: 12117.8008 Epoch: 029/050 | Batch 350/003 | Cost: 12538.9092 Epoch: 029/050 | Batch 400/003 | Cost: 12451.9180 Epoch: 029/050 | Batch 450/003 | Cost: 12649.5537 Time elapsed: 3.48 min Epoch: 030/050 | Batch 000/003 | Cost: 13177.7520 Epoch: 030/050 | Batch 050/003 | Cost: 12634.2002 Epoch: 030/050 | Batch 100/003 | Cost: 12582.4863 Epoch: 030/050 | Batch 150/003 | Cost: 12516.8877 Epoch: 030/050 | Batch 200/003 | Cost: 12460.7139 Epoch: 030/050 | Batch 250/003 | Cost: 12385.2090 Epoch: 030/050 | Batch 300/003 | Cost: 12847.1113 Epoch: 030/050 | Batch 350/003 | Cost: 12123.1543 Epoch: 030/050 | Batch 400/003 | Cost: 12427.7227 Epoch: 030/050 | Batch 450/003 | Cost: 12904.1279 Time elapsed: 3.60 min Epoch: 031/050 | Batch 000/003 | Cost: 12551.2178 Epoch: 031/050 | Batch 050/003 | Cost: 13006.6875 Epoch: 031/050 | Batch 100/003 | Cost: 12672.4551 Epoch: 031/050 | Batch 150/003 | Cost: 12577.4131 Epoch: 031/050 | Batch 200/003 | Cost: 12595.9150 Epoch: 031/050 | Batch 250/003 
| Cost: 12294.5635 Epoch: 031/050 | Batch 300/003 | Cost: 12491.6406 Epoch: 031/050 | Batch 350/003 | Cost: 12726.5947 Epoch: 031/050 | Batch 400/003 | Cost: 12449.7246 Epoch: 031/050 | Batch 450/003 | Cost: 12771.0234 Time elapsed: 3.72 min Epoch: 032/050 | Batch 000/003 | Cost: 12567.2988 Epoch: 032/050 | Batch 050/003 | Cost: 11960.7676 Epoch: 032/050 | Batch 100/003 | Cost: 12276.4648 Epoch: 032/050 | Batch 150/003 | Cost: 13205.2402 Epoch: 032/050 | Batch 200/003 | Cost: 12931.1514 Epoch: 032/050 | Batch 250/003 | Cost: 12975.4473 Epoch: 032/050 | Batch 300/003 | Cost: 12364.8164 Epoch: 032/050 | Batch 350/003 | Cost: 13167.5020 Epoch: 032/050 | Batch 400/003 | Cost: 12439.9355 Epoch: 032/050 | Batch 450/003 | Cost: 12526.5000 Time elapsed: 3.84 min Epoch: 033/050 | Batch 000/003 | Cost: 12447.2412 Epoch: 033/050 | Batch 050/003 | Cost: 12603.1895 Epoch: 033/050 | Batch 100/003 | Cost: 11774.2539 Epoch: 033/050 | Batch 150/003 | Cost: 12859.1406 Epoch: 033/050 | Batch 200/003 | Cost: 12321.5195 Epoch: 033/050 | Batch 250/003 | Cost: 12458.5352 Epoch: 033/050 | Batch 300/003 | Cost: 12871.9336 Epoch: 033/050 | Batch 350/003 | Cost: 12373.0039 Epoch: 033/050 | Batch 400/003 | Cost: 12674.4531 Epoch: 033/050 | Batch 450/003 | Cost: 12425.8633 Time elapsed: 3.96 min Epoch: 034/050 | Batch 000/003 | Cost: 12650.1582 Epoch: 034/050 | Batch 050/003 | Cost: 12587.3613 Epoch: 034/050 | Batch 100/003 | Cost: 12675.3105 Epoch: 034/050 | Batch 150/003 | Cost: 12833.5391 Epoch: 034/050 | Batch 200/003 | Cost: 12441.7305 Epoch: 034/050 | Batch 250/003 | Cost: 12733.2959 Epoch: 034/050 | Batch 300/003 | Cost: 12329.9219 Epoch: 034/050 | Batch 350/003 | Cost: 12802.7354 Epoch: 034/050 | Batch 400/003 | Cost: 11619.9902 Epoch: 034/050 | Batch 450/003 | Cost: 12474.8330 Time elapsed: 4.08 min Epoch: 035/050 | Batch 000/003 | Cost: 12535.9199 Epoch: 035/050 | Batch 050/003 | Cost: 12450.3691 Epoch: 035/050 | Batch 100/003 | Cost: 12693.5137 Epoch: 035/050 | Batch 150/003 | Cost: 
12554.5850 Epoch: 035/050 | Batch 200/003 | Cost: 12123.2129 Epoch: 035/050 | Batch 250/003 | Cost: 12241.1787 Epoch: 035/050 | Batch 300/003 | Cost: 12227.1240 Epoch: 035/050 | Batch 350/003 | Cost: 12536.8457 Epoch: 035/050 | Batch 400/003 | Cost: 12318.9238 Epoch: 035/050 | Batch 450/003 | Cost: 12698.3252 Time elapsed: 4.20 min Epoch: 036/050 | Batch 000/003 | Cost: 12412.0586 Epoch: 036/050 | Batch 050/003 | Cost: 12438.2656 Epoch: 036/050 | Batch 100/003 | Cost: 12465.1973 Epoch: 036/050 | Batch 150/003 | Cost: 12164.7285 Epoch: 036/050 | Batch 200/003 | Cost: 12492.0488 Epoch: 036/050 | Batch 250/003 | Cost: 12558.4023 Epoch: 036/050 | Batch 300/003 | Cost: 12648.7227 Epoch: 036/050 | Batch 350/003 | Cost: 12408.2832 Epoch: 036/050 | Batch 400/003 | Cost: 12169.1201 Epoch: 036/050 | Batch 450/003 | Cost: 12748.4707 Time elapsed: 4.32 min Epoch: 037/050 | Batch 000/003 | Cost: 12361.0439 Epoch: 037/050 | Batch 050/003 | Cost: 12553.2051 Epoch: 037/050 | Batch 100/003 | Cost: 12548.9307 Epoch: 037/050 | Batch 150/003 | Cost: 12064.2344 Epoch: 037/050 | Batch 200/003 | Cost: 12219.1270 Epoch: 037/050 | Batch 250/003 | Cost: 12353.5186 Epoch: 037/050 | Batch 300/003 | Cost: 12224.3682 Epoch: 037/050 | Batch 350/003 | Cost: 12668.8193 Epoch: 037/050 | Batch 400/003 | Cost: 12274.8691 Epoch: 037/050 | Batch 450/003 | Cost: 12393.6182 Time elapsed: 4.44 min Epoch: 038/050 | Batch 000/003 | Cost: 12669.8281 Epoch: 038/050 | Batch 050/003 | Cost: 12271.3066 Epoch: 038/050 | Batch 100/003 | Cost: 12387.7930 Epoch: 038/050 | Batch 150/003 | Cost: 11961.1123 Epoch: 038/050 | Batch 200/003 | Cost: 11976.5498 Epoch: 038/050 | Batch 250/003 | Cost: 12573.2324 Epoch: 038/050 | Batch 300/003 | Cost: 12608.4854 Epoch: 038/050 | Batch 350/003 | Cost: 12112.7617 Epoch: 038/050 | Batch 400/003 | Cost: 12275.8418 Epoch: 038/050 | Batch 450/003 | Cost: 12417.7549 Time elapsed: 4.56 min Epoch: 039/050 | Batch 000/003 | Cost: 12874.1016 Epoch: 039/050 | Batch 050/003 | Cost: 
12372.2070 Epoch: 039/050 | Batch 100/003 | Cost: 12446.2695 Epoch: 039/050 | Batch 150/003 | Cost: 13140.2764 Epoch: 039/050 | Batch 200/003 | Cost: 12825.3037 Epoch: 039/050 | Batch 250/003 | Cost: 12165.7451 Epoch: 039/050 | Batch 300/003 | Cost: 12430.3340 Epoch: 039/050 | Batch 350/003 | Cost: 12702.8613 Epoch: 039/050 | Batch 400/003 | Cost: 12374.5752 Epoch: 039/050 | Batch 450/003 | Cost: 12414.6475 Time elapsed: 4.68 min Epoch: 040/050 | Batch 000/003 | Cost: 12147.0693 Epoch: 040/050 | Batch 050/003 | Cost: 12907.0684 Epoch: 040/050 | Batch 100/003 | Cost: 11664.0801 Epoch: 040/050 | Batch 150/003 | Cost: 12443.9512 Epoch: 040/050 | Batch 200/003 | Cost: 12112.6250 Epoch: 040/050 | Batch 250/003 | Cost: 12146.6133 Epoch: 040/050 | Batch 300/003 | Cost: 12050.8281 Epoch: 040/050 | Batch 350/003 | Cost: 12762.5020 Epoch: 040/050 | Batch 400/003 | Cost: 12517.0771 Epoch: 040/050 | Batch 450/003 | Cost: 12191.8916 Time elapsed: 4.80 min Epoch: 041/050 | Batch 000/003 | Cost: 12098.7090 Epoch: 041/050 | Batch 050/003 | Cost: 12562.2539 Epoch: 041/050 | Batch 100/003 | Cost: 12609.0088 Epoch: 041/050 | Batch 150/003 | Cost: 12270.9854 Epoch: 041/050 | Batch 200/003 | Cost: 12757.7578 Epoch: 041/050 | Batch 250/003 | Cost: 12277.8584 Epoch: 041/050 | Batch 300/003 | Cost: 12219.4121 Epoch: 041/050 | Batch 350/003 | Cost: 12215.5977 Epoch: 041/050 | Batch 400/003 | Cost: 12476.9668 Epoch: 041/050 | Batch 450/003 | Cost: 12220.5449 Time elapsed: 4.92 min Epoch: 042/050 | Batch 000/003 | Cost: 12391.3916 Epoch: 042/050 | Batch 050/003 | Cost: 12526.2002 Epoch: 042/050 | Batch 100/003 | Cost: 12929.2305 Epoch: 042/050 | Batch 150/003 | Cost: 12575.8652 Epoch: 042/050 | Batch 200/003 | Cost: 12588.2656 Epoch: 042/050 | Batch 250/003 | Cost: 12191.2520 Epoch: 042/050 | Batch 300/003 | Cost: 12416.6172 Epoch: 042/050 | Batch 350/003 | Cost: 12398.3096 Epoch: 042/050 | Batch 400/003 | Cost: 12093.6777 Epoch: 042/050 | Batch 450/003 | Cost: 12391.6504 Time elapsed: 5.04 
min Epoch: 043/050 | Batch 000/003 | Cost: 12748.3145 Epoch: 043/050 | Batch 050/003 | Cost: 12076.2344 Epoch: 043/050 | Batch 100/003 | Cost: 11953.4248 Epoch: 043/050 | Batch 150/003 | Cost: 12420.9707 Epoch: 043/050 | Batch 200/003 | Cost: 12428.2041 Epoch: 043/050 | Batch 250/003 | Cost: 12447.2324 Epoch: 043/050 | Batch 300/003 | Cost: 12404.4004 Epoch: 043/050 | Batch 350/003 | Cost: 13077.8926 Epoch: 043/050 | Batch 400/003 | Cost: 13071.2891 Epoch: 043/050 | Batch 450/003 | Cost: 12398.8311 Time elapsed: 5.16 min Epoch: 044/050 | Batch 000/003 | Cost: 12913.1631 Epoch: 044/050 | Batch 050/003 | Cost: 12169.1523 Epoch: 044/050 | Batch 100/003 | Cost: 11856.3672 Epoch: 044/050 | Batch 150/003 | Cost: 12280.6045 Epoch: 044/050 | Batch 200/003 | Cost: 12343.7998 Epoch: 044/050 | Batch 250/003 | Cost: 12746.3164 Epoch: 044/050 | Batch 300/003 | Cost: 12279.5156 Epoch: 044/050 | Batch 350/003 | Cost: 12548.7598 Epoch: 044/050 | Batch 400/003 | Cost: 12430.6104 Epoch: 044/050 | Batch 450/003 | Cost: 12775.3105 Time elapsed: 5.28 min Epoch: 045/050 | Batch 000/003 | Cost: 12245.9482 Epoch: 045/050 | Batch 050/003 | Cost: 12547.1270 Epoch: 045/050 | Batch 100/003 | Cost: 12214.5732 Epoch: 045/050 | Batch 150/003 | Cost: 12484.7715 Epoch: 045/050 | Batch 200/003 | Cost: 12552.6123 Epoch: 045/050 | Batch 250/003 | Cost: 12510.6064 Epoch: 045/050 | Batch 300/003 | Cost: 12465.5566 Epoch: 045/050 | Batch 350/003 | Cost: 12111.4629 Epoch: 045/050 | Batch 400/003 | Cost: 12435.2578 Epoch: 045/050 | Batch 450/003 | Cost: 12433.7461 Time elapsed: 5.40 min Epoch: 046/050 | Batch 000/003 | Cost: 12598.7734 Epoch: 046/050 | Batch 050/003 | Cost: 12568.2207 Epoch: 046/050 | Batch 100/003 | Cost: 12853.6328 Epoch: 046/050 | Batch 150/003 | Cost: 12230.1533 Epoch: 046/050 | Batch 200/003 | Cost: 12326.4727 Epoch: 046/050 | Batch 250/003 | Cost: 12770.4814 Epoch: 046/050 | Batch 300/003 | Cost: 12082.1006 Epoch: 046/050 | Batch 350/003 | Cost: 11948.0430 Epoch: 046/050 | Batch 
400/003 | Cost: 12462.7773 Epoch: 046/050 | Batch 450/003 | Cost: 12145.4609 Time elapsed: 5.52 min Epoch: 047/050 | Batch 000/003 | Cost: 12146.9893 Epoch: 047/050 | Batch 050/003 | Cost: 11917.3779 Epoch: 047/050 | Batch 100/003 | Cost: 12135.3457 Epoch: 047/050 | Batch 150/003 | Cost: 11984.8691 Epoch: 047/050 | Batch 200/003 | Cost: 12437.4248 Epoch: 047/050 | Batch 250/003 | Cost: 12444.1416 Epoch: 047/050 | Batch 300/003 | Cost: 12078.9043 Epoch: 047/050 | Batch 350/003 | Cost: 12138.3818 Epoch: 047/050 | Batch 400/003 | Cost: 12556.3086 Epoch: 047/050 | Batch 450/003 | Cost: 12726.3828 Time elapsed: 5.64 min Epoch: 048/050 | Batch 000/003 | Cost: 12927.1035 Epoch: 048/050 | Batch 050/003 | Cost: 12747.8994 Epoch: 048/050 | Batch 100/003 | Cost: 12002.6406 Epoch: 048/050 | Batch 150/003 | Cost: 12093.2988 Epoch: 048/050 | Batch 200/003 | Cost: 11541.1982 Epoch: 048/050 | Batch 250/003 | Cost: 12530.8398 Epoch: 048/050 | Batch 300/003 | Cost: 12200.4463 Epoch: 048/050 | Batch 350/003 | Cost: 12818.4082 Epoch: 048/050 | Batch 400/003 | Cost: 12760.9844 Epoch: 048/050 | Batch 450/003 | Cost: 12186.3496 Time elapsed: 5.76 min Epoch: 049/050 | Batch 000/003 | Cost: 12394.8848 Epoch: 049/050 | Batch 050/003 | Cost: 12406.0801 Epoch: 049/050 | Batch 100/003 | Cost: 12390.3223 Epoch: 049/050 | Batch 150/003 | Cost: 12788.7949 Epoch: 049/050 | Batch 200/003 | Cost: 12015.3936 Epoch: 049/050 | Batch 250/003 | Cost: 12259.9609 Epoch: 049/050 | Batch 300/003 | Cost: 12216.9688 Epoch: 049/050 | Batch 350/003 | Cost: 12795.6113 Epoch: 049/050 | Batch 400/003 | Cost: 11885.8955 Epoch: 049/050 | Batch 450/003 | Cost: 12445.2891 Time elapsed: 5.88 min Epoch: 050/050 | Batch 000/003 | Cost: 12782.9785 Epoch: 050/050 | Batch 050/003 | Cost: 12291.9814 Epoch: 050/050 | Batch 100/003 | Cost: 12721.1641 Epoch: 050/050 | Batch 150/003 | Cost: 12482.5762 Epoch: 050/050 | Batch 200/003 | Cost: 12581.8125 Epoch: 050/050 | Batch 250/003 | Cost: 12422.5527 Epoch: 050/050 | Batch 300/003 
| Cost: 12384.3047 Epoch: 050/050 | Batch 350/003 | Cost: 12031.4541 Epoch: 050/050 | Batch 400/003 | Cost: 12278.7617 Epoch: 050/050 | Batch 450/003 | Cost: 12190.8232 Time elapsed: 6.00 min Total Training Time: 6.00 min
%matplotlib inline
import matplotlib.pyplot as plt
##########################
### VISUALIZATION
##########################

n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images,
                         sharex=True, sharey=True, figsize=(20, 2.5))

# Top row: inputs from the last training mini-batch. Bottom row: their
# reconstructions, with the trailing one-hot label columns cut off.
orig_images = features[:n_images]
decoded_images = decoded[:n_images][:, :-num_classes]

for col in range(n_images):
    for row_axes, source in zip(axes, [orig_images, decoded_images]):
        pixels = source[col].detach().cpu()
        row_axes[col].imshow(pixels.view((image_width, image_width)), cmap='binary')
for digit in range(10):

    ##########################
    ### RANDOM SAMPLE
    ##########################

    # Condition every sample in this row on the same class label `digit`
    # and draw the latent codes from a standard normal distribution.
    labels = torch.tensor([digit] * 10).to(device)
    n_images = labels.size(0)
    rand_features = torch.randn(n_images, num_latent).to(device)
    new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    image_width = 28
    fig, axes = plt.subplots(nrows=1, ncols=n_images,
                             sharey=True, figsize=(10, 2.5))
    # The decoder output ends with num_classes one-hot columns; keep pixels only.
    decoded_images = new_images[:n_images][:, :-num_classes]

    print('Class Label %d' % digit)
    for col, ax in enumerate(axes):
        pixels = decoded_images[col].detach().cpu()
        ax.imshow(pixels.view((image_width, image_width)), cmap='binary')
    plt.show()
Class Label 0
Class Label 1
Class Label 2
Class Label 3
Class Label 4
Class Label 5
Class Label 6
Class Label 7
Class Label 8
Class Label 9
%watermark -iv
numpy 1.15.4 torch 1.0.0