Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.6.8 IPython 7.2.0 torch 1.0.0
A simple convolutional conditional variational autoencoder that compresses 784-pixel (28x28) MNIST images down to a 50-dimensional latent vector representation.
This implementation DOES NOT concatenate the inputs with the class labels when computing the reconstruction loss, in contrast to how it is commonly done in non-convolutional conditional variational autoencoders. Not considering the class labels in the reconstruction loss leads to substantially better results than the implementation that does concatenate the labels with the inputs to compute the reconstruction loss. For reference, see the implementation ./autoencoder-cnn-cvae.ipynb
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
# Restrict cuDNN to deterministic algorithms so GPU runs are repeatable.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True


##########################
### SETTINGS
##########################

# Device
# Prefer the second GPU ("cuda:1") when the machine has one, but fall back
# to "cuda:0" on single-GPU machines, where "cuda:1" is an invalid device
# ordinal and every later `.to(device)` call would fail.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:%d" % gpu_index)
else:
    device = torch.device("cpu")
print('Device:', device)

# Hyperparameters
random_seed = 0
learning_rate = 0.001
num_epochs = 50
batch_size = 128

# Architecture
num_classes = 10    # digits 0-9, used for the one-hot condition
num_features = 784  # 28 * 28 pixels per MNIST image
num_latent = 50     # size of the latent code
##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(
    root='data', train=True, download=True,
    transform=transforms.ToTensor())
test_dataset = datasets.MNIST(
    root='data', train=False,
    transform=transforms.ToTensor())

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Checking the dataset: report the shapes of the first batch, then stop
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Device: cuda:1 Image batch dimensions: torch.Size([128, 1, 28, 28]) Image label dimensions: torch.Size([128])
##########################
### MODEL
##########################
def to_onehot(labels, num_classes, device):
    """Return a one-hot float matrix for a vector of integer class labels.

    Args:
        labels: Tensor of class indices in ``[0, num_classes)``. Any integer
            dtype and any device are accepted.
        num_classes: Width of the one-hot encoding.
        device: Device on which the result is allocated.

    Returns:
        Tensor of shape ``(number_of_labels, num_classes)`` on ``device``
        holding a single 1 per row, in the column given by that row's label.
    """
    # scatter_ requires an int64 index that lives on the same device as the
    # tensor being written, so normalise the labels here instead of relying
    # on every caller to do it.
    index = labels.to(device=device, dtype=torch.long).reshape(-1, 1)
    # Allocate directly on the target device rather than building the matrix
    # on the CPU and copying it over afterwards.
    labels_onehot = torch.zeros(index.size(0), num_classes, device=device)
    labels_onehot.scatter_(1, index, 1)
    return labels_onehot
class ConditionalVariationalAutoencoder(torch.nn.Module):
    """Convolutional conditional VAE sized for 1x28x28 inputs (MNIST).

    The class label conditions both halves of the network: the encoder
    receives it as ``num_classes`` extra constant input planes, the decoder
    as a one-hot vector appended to the latent code.

    Args:
        num_features: Number of input pixels. Accepted for interface
            compatibility; none of the layer sizes depend on it.
        num_latent: Dimensionality of the latent code.
        num_classes: Number of class labels used for conditioning.
    """

    def __init__(self, num_features, num_latent, num_classes):
        super().__init__()

        self.num_classes = num_classes

        ###############
        # ENCODER
        ##############

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        # Spatial size per layer for a 28x28 input: 28 -> 12 -> 5 -> 2
        self.enc_conv_1 = torch.nn.Conv2d(in_channels=1+self.num_classes,
                                          out_channels=16,
                                          kernel_size=(6, 6),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_2 = torch.nn.Conv2d(in_channels=16,
                                          out_channels=32,
                                          kernel_size=(4, 4),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_3 = torch.nn.Conv2d(in_channels=32,
                                          out_channels=64,
                                          kernel_size=(2, 2),
                                          stride=(2, 2),
                                          padding=0)

        self.z_mean = torch.nn.Linear(64*2*2, num_latent)
        # In the original paper (Kingma & Welling, 2013) the encoder
        # outputs z_mean and z_var, but the problem is that z_var can be
        # negative, which would cause issues in the log later. Hence we
        # assume that the latent vector has a z_mean and a z_log_var
        # component, and when we need the regular variance or std_dev,
        # we simply use an exponential function.
        self.z_log_var = torch.nn.Linear(64*2*2, num_latent)

        ###############
        # DECODER
        ##############

        self.dec_linear_1 = torch.nn.Linear(num_latent+self.num_classes, 64*2*2)

        # Spatial size per layer: 2 -> 4 -> 11 -> 28
        self.dec_deconv_1 = torch.nn.ConvTranspose2d(in_channels=64,
                                                     out_channels=32,
                                                     kernel_size=(2, 2),
                                                     stride=(2, 2),
                                                     padding=0)

        self.dec_deconv_2 = torch.nn.ConvTranspose2d(in_channels=32,
                                                     out_channels=16,
                                                     kernel_size=(4, 4),
                                                     stride=(3, 3),
                                                     padding=1)

        self.dec_deconv_3 = torch.nn.ConvTranspose2d(in_channels=16,
                                                     out_channels=1,
                                                     kernel_size=(6, 6),
                                                     stride=(3, 3),
                                                     padding=4)

    def reparameterize(self, z_mu, z_log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            z_mu: Mean of the latent distribution.
            z_log_var: Log-variance of the latent distribution, same shape
                as ``z_mu``.

        Returns:
            A sample with the shape, dtype and device of ``z_mu``.
        """
        # Draw the noise next to z_mu instead of on whatever the
        # module-level `device` happens to be.
        eps = torch.randn_like(z_mu)
        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev
        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)
        return z_mu + eps * torch.exp(z_log_var/2.)

    def encoder(self, features, targets):
        """Encode a batch of images conditioned on their class labels.

        Args:
            features: Image batch of shape ``(N, 1, H, W)``.
            targets: Integer class labels, one per image.

        Returns:
            Tuple ``(z_mean, z_log_var, encoded)`` where ``encoded`` is a
            latent sample drawn via ``reparameterize``.
        """
        ### Add condition
        # One constant plane per class: all ones for the true class and all
        # zeros otherwise. `expand` broadcasts the one-hot vector over the
        # image grid without allocating a dense ones tensor first.
        onehot_targets = to_onehot(targets.to(features.device),
                                   self.num_classes, features.device)
        label_planes = onehot_targets.to(features.dtype).view(
            -1, self.num_classes, 1, 1).expand(
            -1, -1, features.size(2), features.size(3))

        x = torch.cat((features, label_planes), dim=1)

        x = F.leaky_relu(self.enc_conv_1(x))
        x = F.leaky_relu(self.enc_conv_2(x))
        x = F.leaky_relu(self.enc_conv_3(x))

        # Flatten per sample. Keeping the batch dimension explicit makes a
        # wrongly sized input fail in the linear layers instead of being
        # silently regrouped across samples.
        x = x.view(x.size(0), -1)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        encoded = self.reparameterize(z_mean, z_log_var)

        return z_mean, z_log_var, encoded

    def decoder(self, encoded, targets):
        """Decode latent codes into images, conditioned on class labels.

        Args:
            encoded: Latent codes of shape ``(N, num_latent)``.
            targets: Integer class labels, one per latent code.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in ``(0, 1)``.
        """
        ### Add condition
        onehot_targets = to_onehot(targets.to(encoded.device),
                                   self.num_classes, encoded.device)
        encoded = torch.cat((encoded, onehot_targets.to(encoded.dtype)), dim=1)

        x = self.dec_linear_1(encoded)
        x = x.view(-1, 64, 2, 2)

        x = F.leaky_relu(self.dec_deconv_1(x))
        x = F.leaky_relu(self.dec_deconv_2(x))

        # NOTE(review): a leaky_relu directly in front of the sigmoid scales
        # negative pre-activations down (default negative slope 0.01), so
        # the last layer needs large negative outputs to produce values
        # near 0. Kept as-is so that results match the trained notebook.
        x = F.leaky_relu(self.dec_deconv_3(x))
        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        """Run the encoder and decoder on one batch.

        Returns:
            Tuple ``(z_mean, z_log_var, encoded, decoded)``.
        """
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)

        return z_mean, z_log_var, encoded, decoded
# Seed the global RNG right before the model is built.
torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(
    num_features=num_features,
    num_latent=num_latent,
    num_classes=num_classes,
).to(device)
##########################
### COST AND OPTIMIZER
##########################

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

start_time = time.time()
for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
        # The KL term is summed over the latent dimensions and the batch.
        kl_divergence = (0.5 * (z_mean**2 +
                                torch.exp(z_log_var) - z_log_var - 1)).sum()

        ### Compute loss
        # The reconstruction target is the raw image only. The one-hot
        # label planes are deliberately NOT concatenated onto the target:
        # including them in the reconstruction loss gave poor results.
        pixelwise_bce = F.binary_cross_entropy(decoded, features, reduction='sum')
        cost = kl_divergence + pixelwise_bce

        ### UPDATE MODEL PARAMETERS
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            # .item() hands the formatter a plain Python float instead of
            # a tensor.
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost.item()))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/469 | Cost: 69516.4531 Epoch: 001/050 | Batch 050/469 | Cost: 65598.3438 Epoch: 001/050 | Batch 100/469 | Cost: 38653.3398 Epoch: 001/050 | Batch 150/469 | Cost: 31647.0234 Epoch: 001/050 | Batch 200/469 | Cost: 29562.4902 Epoch: 001/050 | Batch 250/469 | Cost: 27448.7969 Epoch: 001/050 | Batch 300/469 | Cost: 27149.4609 Epoch: 001/050 | Batch 350/469 | Cost: 25922.6641 Epoch: 001/050 | Batch 400/469 | Cost: 25652.2539 Epoch: 001/050 | Batch 450/469 | Cost: 25503.6504 Time elapsed: 0.15 min Epoch: 002/050 | Batch 000/469 | Cost: 25003.4570 Epoch: 002/050 | Batch 050/469 | Cost: 24463.9062 Epoch: 002/050 | Batch 100/469 | Cost: 24639.4102 Epoch: 002/050 | Batch 150/469 | Cost: 25084.6973 Epoch: 002/050 | Batch 200/469 | Cost: 24361.1699 Epoch: 002/050 | Batch 250/469 | Cost: 23067.3418 Epoch: 002/050 | Batch 300/469 | Cost: 23340.0000 Epoch: 002/050 | Batch 350/469 | Cost: 22182.6523 Epoch: 002/050 | Batch 400/469 | Cost: 23848.3574 Epoch: 002/050 | Batch 450/469 | Cost: 22587.1523 Time elapsed: 0.29 min Epoch: 003/050 | Batch 000/469 | Cost: 23024.4629 Epoch: 003/050 | Batch 050/469 | Cost: 22733.5449 Epoch: 003/050 | Batch 100/469 | Cost: 21335.1270 Epoch: 003/050 | Batch 150/469 | Cost: 22196.0801 Epoch: 003/050 | Batch 200/469 | Cost: 21967.5898 Epoch: 003/050 | Batch 250/469 | Cost: 21690.1719 Epoch: 003/050 | Batch 300/469 | Cost: 21659.4297 Epoch: 003/050 | Batch 350/469 | Cost: 20510.7891 Epoch: 003/050 | Batch 400/469 | Cost: 21242.7695 Epoch: 003/050 | Batch 450/469 | Cost: 20785.4922 Time elapsed: 0.44 min Epoch: 004/050 | Batch 000/469 | Cost: 20328.3145 Epoch: 004/050 | Batch 050/469 | Cost: 20273.4453 Epoch: 004/050 | Batch 100/469 | Cost: 20539.0879 Epoch: 004/050 | Batch 150/469 | Cost: 20137.0156 Epoch: 004/050 | Batch 200/469 | Cost: 19641.2148 Epoch: 004/050 | Batch 250/469 | Cost: 20138.8418 Epoch: 004/050 | Batch 300/469 | Cost: 18882.3086 Epoch: 004/050 | Batch 350/469 | Cost: 19263.8516 Epoch: 004/050 | Batch 400/469 
| Cost: 19991.0703 Epoch: 004/050 | Batch 450/469 | Cost: 18806.1582 Time elapsed: 0.58 min Epoch: 005/050 | Batch 000/469 | Cost: 19555.3555 Epoch: 005/050 | Batch 050/469 | Cost: 19259.2910 Epoch: 005/050 | Batch 100/469 | Cost: 18361.7383 Epoch: 005/050 | Batch 150/469 | Cost: 18087.8281 Epoch: 005/050 | Batch 200/469 | Cost: 18091.0488 Epoch: 005/050 | Batch 250/469 | Cost: 19610.8906 Epoch: 005/050 | Batch 300/469 | Cost: 17971.9766 Epoch: 005/050 | Batch 350/469 | Cost: 18295.7988 Epoch: 005/050 | Batch 400/469 | Cost: 18726.7246 Epoch: 005/050 | Batch 450/469 | Cost: 17738.4590 Time elapsed: 0.73 min Epoch: 006/050 | Batch 000/469 | Cost: 18221.9043 Epoch: 006/050 | Batch 050/469 | Cost: 18209.2578 Epoch: 006/050 | Batch 100/469 | Cost: 17747.5586 Epoch: 006/050 | Batch 150/469 | Cost: 16824.4941 Epoch: 006/050 | Batch 200/469 | Cost: 17604.7949 Epoch: 006/050 | Batch 250/469 | Cost: 17186.1855 Epoch: 006/050 | Batch 300/469 | Cost: 17752.4570 Epoch: 006/050 | Batch 350/469 | Cost: 17146.1660 Epoch: 006/050 | Batch 400/469 | Cost: 17564.4121 Epoch: 006/050 | Batch 450/469 | Cost: 16982.0527 Time elapsed: 0.88 min Epoch: 007/050 | Batch 000/469 | Cost: 17229.5156 Epoch: 007/050 | Batch 050/469 | Cost: 17952.2988 Epoch: 007/050 | Batch 100/469 | Cost: 17679.9102 Epoch: 007/050 | Batch 150/469 | Cost: 16431.1602 Epoch: 007/050 | Batch 200/469 | Cost: 16707.6699 Epoch: 007/050 | Batch 250/469 | Cost: 16626.2344 Epoch: 007/050 | Batch 300/469 | Cost: 17091.8594 Epoch: 007/050 | Batch 350/469 | Cost: 16410.7461 Epoch: 007/050 | Batch 400/469 | Cost: 16464.0039 Epoch: 007/050 | Batch 450/469 | Cost: 17185.2910 Time elapsed: 1.03 min Epoch: 008/050 | Batch 000/469 | Cost: 16310.5146 Epoch: 008/050 | Batch 050/469 | Cost: 16510.9883 Epoch: 008/050 | Batch 100/469 | Cost: 16409.1504 Epoch: 008/050 | Batch 150/469 | Cost: 16645.9414 Epoch: 008/050 | Batch 200/469 | Cost: 16140.7637 Epoch: 008/050 | Batch 250/469 | Cost: 16261.8232 Epoch: 008/050 | Batch 300/469 | Cost: 
15731.7832 Epoch: 008/050 | Batch 350/469 | Cost: 16438.5391 Epoch: 008/050 | Batch 400/469 | Cost: 16522.8516 Epoch: 008/050 | Batch 450/469 | Cost: 16674.2656 Time elapsed: 1.18 min Epoch: 009/050 | Batch 000/469 | Cost: 15663.1904 Epoch: 009/050 | Batch 050/469 | Cost: 15857.3770 Epoch: 009/050 | Batch 100/469 | Cost: 16233.4707 Epoch: 009/050 | Batch 150/469 | Cost: 16635.9785 Epoch: 009/050 | Batch 200/469 | Cost: 16294.5547 Epoch: 009/050 | Batch 250/469 | Cost: 15947.5801 Epoch: 009/050 | Batch 300/469 | Cost: 16139.6113 Epoch: 009/050 | Batch 350/469 | Cost: 16081.8906 Epoch: 009/050 | Batch 400/469 | Cost: 16331.2500 Epoch: 009/050 | Batch 450/469 | Cost: 15352.7773 Time elapsed: 1.32 min Epoch: 010/050 | Batch 000/469 | Cost: 15708.1094 Epoch: 010/050 | Batch 050/469 | Cost: 16146.4141 Epoch: 010/050 | Batch 100/469 | Cost: 16003.0078 Epoch: 010/050 | Batch 150/469 | Cost: 14990.0430 Epoch: 010/050 | Batch 200/469 | Cost: 15628.8242 Epoch: 010/050 | Batch 250/469 | Cost: 15949.3691 Epoch: 010/050 | Batch 300/469 | Cost: 14936.1875 Epoch: 010/050 | Batch 350/469 | Cost: 15196.0625 Epoch: 010/050 | Batch 400/469 | Cost: 15594.7686 Epoch: 010/050 | Batch 450/469 | Cost: 16577.6230 Time elapsed: 1.47 min Epoch: 011/050 | Batch 000/469 | Cost: 15626.6035 Epoch: 011/050 | Batch 050/469 | Cost: 15766.5859 Epoch: 011/050 | Batch 100/469 | Cost: 16134.7734 Epoch: 011/050 | Batch 150/469 | Cost: 15299.3574 Epoch: 011/050 | Batch 200/469 | Cost: 15611.9248 Epoch: 011/050 | Batch 250/469 | Cost: 16024.9580 Epoch: 011/050 | Batch 300/469 | Cost: 15840.3047 Epoch: 011/050 | Batch 350/469 | Cost: 15214.9883 Epoch: 011/050 | Batch 400/469 | Cost: 15782.8574 Epoch: 011/050 | Batch 450/469 | Cost: 15268.2646 Time elapsed: 1.62 min Epoch: 012/050 | Batch 000/469 | Cost: 15584.0322 Epoch: 012/050 | Batch 050/469 | Cost: 15752.2871 Epoch: 012/050 | Batch 100/469 | Cost: 15630.2695 Epoch: 012/050 | Batch 150/469 | Cost: 15513.7822 Epoch: 012/050 | Batch 200/469 | Cost: 
15230.1543 Epoch: 012/050 | Batch 250/469 | Cost: 14979.2441 Epoch: 012/050 | Batch 300/469 | Cost: 15661.2822 Epoch: 012/050 | Batch 350/469 | Cost: 14583.6562 Epoch: 012/050 | Batch 400/469 | Cost: 15692.2637 Epoch: 012/050 | Batch 450/469 | Cost: 15444.3398 Time elapsed: 1.76 min Epoch: 013/050 | Batch 000/469 | Cost: 15403.8340 Epoch: 013/050 | Batch 050/469 | Cost: 15334.1270 Epoch: 013/050 | Batch 100/469 | Cost: 14869.9678 Epoch: 013/050 | Batch 150/469 | Cost: 14967.7734 Epoch: 013/050 | Batch 200/469 | Cost: 15038.3467 Epoch: 013/050 | Batch 250/469 | Cost: 14940.0566 Epoch: 013/050 | Batch 300/469 | Cost: 15861.9902 Epoch: 013/050 | Batch 350/469 | Cost: 15016.0215 Epoch: 013/050 | Batch 400/469 | Cost: 15020.5508 Epoch: 013/050 | Batch 450/469 | Cost: 14558.9678 Time elapsed: 1.91 min Epoch: 014/050 | Batch 000/469 | Cost: 15571.9609 Epoch: 014/050 | Batch 050/469 | Cost: 14986.8027 Epoch: 014/050 | Batch 100/469 | Cost: 14748.6660 Epoch: 014/050 | Batch 150/469 | Cost: 15177.5010 Epoch: 014/050 | Batch 200/469 | Cost: 15166.5283 Epoch: 014/050 | Batch 250/469 | Cost: 14866.0449 Epoch: 014/050 | Batch 300/469 | Cost: 15227.5977 Epoch: 014/050 | Batch 350/469 | Cost: 15148.1973 Epoch: 014/050 | Batch 400/469 | Cost: 15003.9395 Epoch: 014/050 | Batch 450/469 | Cost: 15571.9531 Time elapsed: 2.06 min Epoch: 015/050 | Batch 000/469 | Cost: 15100.7773 Epoch: 015/050 | Batch 050/469 | Cost: 14556.3730 Epoch: 015/050 | Batch 100/469 | Cost: 15114.8965 Epoch: 015/050 | Batch 150/469 | Cost: 15237.2412 Epoch: 015/050 | Batch 200/469 | Cost: 15173.2842 Epoch: 015/050 | Batch 250/469 | Cost: 15283.6016 Epoch: 015/050 | Batch 300/469 | Cost: 14717.9834 Epoch: 015/050 | Batch 350/469 | Cost: 15098.9512 Epoch: 015/050 | Batch 400/469 | Cost: 14483.3516 Epoch: 015/050 | Batch 450/469 | Cost: 14874.4346 Time elapsed: 2.21 min Epoch: 016/050 | Batch 000/469 | Cost: 14778.4883 Epoch: 016/050 | Batch 050/469 | Cost: 14571.5068 Epoch: 016/050 | Batch 100/469 | Cost: 
14361.2773 Epoch: 016/050 | Batch 150/469 | Cost: 14580.1055 Epoch: 016/050 | Batch 200/469 | Cost: 14950.4766 Epoch: 016/050 | Batch 250/469 | Cost: 14357.0742 Epoch: 016/050 | Batch 300/469 | Cost: 15067.2119 Epoch: 016/050 | Batch 350/469 | Cost: 14431.0293 Epoch: 016/050 | Batch 400/469 | Cost: 15010.4941 Epoch: 016/050 | Batch 450/469 | Cost: 14981.4385 Time elapsed: 2.35 min Epoch: 017/050 | Batch 000/469 | Cost: 14213.7207 Epoch: 017/050 | Batch 050/469 | Cost: 14254.8223 Epoch: 017/050 | Batch 100/469 | Cost: 14608.7031 Epoch: 017/050 | Batch 150/469 | Cost: 14804.6738 Epoch: 017/050 | Batch 200/469 | Cost: 15223.3574 Epoch: 017/050 | Batch 250/469 | Cost: 15073.8105 Epoch: 017/050 | Batch 300/469 | Cost: 14488.2256 Epoch: 017/050 | Batch 350/469 | Cost: 15285.3438 Epoch: 017/050 | Batch 400/469 | Cost: 14768.0410 Epoch: 017/050 | Batch 450/469 | Cost: 14246.4082 Time elapsed: 2.49 min Epoch: 018/050 | Batch 000/469 | Cost: 14446.7607 Epoch: 018/050 | Batch 050/469 | Cost: 14307.9512 Epoch: 018/050 | Batch 100/469 | Cost: 14979.2393 Epoch: 018/050 | Batch 150/469 | Cost: 14640.7529 Epoch: 018/050 | Batch 200/469 | Cost: 14336.5176 Epoch: 018/050 | Batch 250/469 | Cost: 14856.0244 Epoch: 018/050 | Batch 300/469 | Cost: 14236.4883 Epoch: 018/050 | Batch 350/469 | Cost: 14293.7129 Epoch: 018/050 | Batch 400/469 | Cost: 14989.7578 Epoch: 018/050 | Batch 450/469 | Cost: 14645.5918 Time elapsed: 2.63 min Epoch: 019/050 | Batch 000/469 | Cost: 14769.7305 Epoch: 019/050 | Batch 050/469 | Cost: 14644.3301 Epoch: 019/050 | Batch 100/469 | Cost: 14153.6289 Epoch: 019/050 | Batch 150/469 | Cost: 15014.8457 Epoch: 019/050 | Batch 200/469 | Cost: 14531.8291 Epoch: 019/050 | Batch 250/469 | Cost: 14103.4414 Epoch: 019/050 | Batch 300/469 | Cost: 14499.4141 Epoch: 019/050 | Batch 350/469 | Cost: 14517.2227 Epoch: 019/050 | Batch 400/469 | Cost: 14708.0664 Epoch: 019/050 | Batch 450/469 | Cost: 14042.7529 Time elapsed: 2.77 min Epoch: 020/050 | Batch 000/469 | Cost: 
15051.2266 Epoch: 020/050 | Batch 050/469 | Cost: 14537.1982 Epoch: 020/050 | Batch 100/469 | Cost: 13989.1104 Epoch: 020/050 | Batch 150/469 | Cost: 14822.6094 Epoch: 020/050 | Batch 200/469 | Cost: 15177.9668 Epoch: 020/050 | Batch 250/469 | Cost: 14710.3174 Epoch: 020/050 | Batch 300/469 | Cost: 13794.1641 Epoch: 020/050 | Batch 350/469 | Cost: 14262.4473 Epoch: 020/050 | Batch 400/469 | Cost: 14950.7432 Epoch: 020/050 | Batch 450/469 | Cost: 14864.3555 Time elapsed: 2.91 min Epoch: 021/050 | Batch 000/469 | Cost: 15020.9473 Epoch: 021/050 | Batch 050/469 | Cost: 14729.3340 Epoch: 021/050 | Batch 100/469 | Cost: 14100.7500 Epoch: 021/050 | Batch 150/469 | Cost: 14151.6641 Epoch: 021/050 | Batch 200/469 | Cost: 14153.0459 Epoch: 021/050 | Batch 250/469 | Cost: 14365.5645 Epoch: 021/050 | Batch 300/469 | Cost: 14539.5244 Epoch: 021/050 | Batch 350/469 | Cost: 14018.8398 Epoch: 021/050 | Batch 400/469 | Cost: 14032.9209 Epoch: 021/050 | Batch 450/469 | Cost: 13872.8320 Time elapsed: 3.06 min Epoch: 022/050 | Batch 000/469 | Cost: 14742.1719 Epoch: 022/050 | Batch 050/469 | Cost: 14320.2646 Epoch: 022/050 | Batch 100/469 | Cost: 14856.3320 Epoch: 022/050 | Batch 150/469 | Cost: 14376.0273 Epoch: 022/050 | Batch 200/469 | Cost: 14115.4121 Epoch: 022/050 | Batch 250/469 | Cost: 13767.6973 Epoch: 022/050 | Batch 300/469 | Cost: 13885.6768 Epoch: 022/050 | Batch 350/469 | Cost: 15135.5273 Epoch: 022/050 | Batch 400/469 | Cost: 14869.7598 Epoch: 022/050 | Batch 450/469 | Cost: 13792.0283 Time elapsed: 3.20 min Epoch: 023/050 | Batch 000/469 | Cost: 14404.2324 Epoch: 023/050 | Batch 050/469 | Cost: 14076.9844 Epoch: 023/050 | Batch 100/469 | Cost: 14239.1904 Epoch: 023/050 | Batch 150/469 | Cost: 14376.3242 Epoch: 023/050 | Batch 200/469 | Cost: 13941.0156 Epoch: 023/050 | Batch 250/469 | Cost: 13948.4395 Epoch: 023/050 | Batch 300/469 | Cost: 15119.5137 Epoch: 023/050 | Batch 350/469 | Cost: 14480.1211 Epoch: 023/050 | Batch 400/469 | Cost: 14310.3594 Epoch: 023/050 | 
Batch 450/469 | Cost: 14712.5039 Time elapsed: 3.34 min Epoch: 024/050 | Batch 000/469 | Cost: 14535.5488 Epoch: 024/050 | Batch 050/469 | Cost: 14241.1660 Epoch: 024/050 | Batch 100/469 | Cost: 14769.8477 Epoch: 024/050 | Batch 150/469 | Cost: 15056.7559 Epoch: 024/050 | Batch 200/469 | Cost: 14387.6484 Epoch: 024/050 | Batch 250/469 | Cost: 14316.7148 Epoch: 024/050 | Batch 300/469 | Cost: 14848.7793 Epoch: 024/050 | Batch 350/469 | Cost: 14909.2490 Epoch: 024/050 | Batch 400/469 | Cost: 14848.7090 Epoch: 024/050 | Batch 450/469 | Cost: 14461.7627 Time elapsed: 3.48 min Epoch: 025/050 | Batch 000/469 | Cost: 14212.7168 Epoch: 025/050 | Batch 050/469 | Cost: 14333.6973 Epoch: 025/050 | Batch 100/469 | Cost: 14074.0586 Epoch: 025/050 | Batch 150/469 | Cost: 14331.3789 Epoch: 025/050 | Batch 200/469 | Cost: 13657.7471 Epoch: 025/050 | Batch 250/469 | Cost: 14190.0117 Epoch: 025/050 | Batch 300/469 | Cost: 13733.5908 Epoch: 025/050 | Batch 350/469 | Cost: 14021.1602 Epoch: 025/050 | Batch 400/469 | Cost: 13840.4336 Epoch: 025/050 | Batch 450/469 | Cost: 14060.3848 Time elapsed: 3.63 min Epoch: 026/050 | Batch 000/469 | Cost: 15362.9629 Epoch: 026/050 | Batch 050/469 | Cost: 14140.0303 Epoch: 026/050 | Batch 100/469 | Cost: 13597.3838 Epoch: 026/050 | Batch 150/469 | Cost: 14821.4492 Epoch: 026/050 | Batch 200/469 | Cost: 14879.7930 Epoch: 026/050 | Batch 250/469 | Cost: 14080.9072 Epoch: 026/050 | Batch 300/469 | Cost: 14645.4023 Epoch: 026/050 | Batch 350/469 | Cost: 13696.6152 Epoch: 026/050 | Batch 400/469 | Cost: 14472.7656 Epoch: 026/050 | Batch 450/469 | Cost: 14059.6641 Time elapsed: 3.78 min Epoch: 027/050 | Batch 000/469 | Cost: 14369.2246 Epoch: 027/050 | Batch 050/469 | Cost: 13632.5137 Epoch: 027/050 | Batch 100/469 | Cost: 13472.9004 Epoch: 027/050 | Batch 150/469 | Cost: 13673.4121 Epoch: 027/050 | Batch 200/469 | Cost: 14124.0625 Epoch: 027/050 | Batch 250/469 | Cost: 13920.0332 Epoch: 027/050 | Batch 300/469 | Cost: 13909.5391 Epoch: 027/050 | Batch 
350/469 | Cost: 14398.0977 Epoch: 027/050 | Batch 400/469 | Cost: 14438.4854 Epoch: 027/050 | Batch 450/469 | Cost: 14019.9814 Time elapsed: 3.93 min Epoch: 028/050 | Batch 000/469 | Cost: 14063.9189 Epoch: 028/050 | Batch 050/469 | Cost: 14298.8477 Epoch: 028/050 | Batch 100/469 | Cost: 13534.4980 Epoch: 028/050 | Batch 150/469 | Cost: 13799.8779 Epoch: 028/050 | Batch 200/469 | Cost: 13730.7334 Epoch: 028/050 | Batch 250/469 | Cost: 13006.5938 Epoch: 028/050 | Batch 300/469 | Cost: 14268.8652 Epoch: 028/050 | Batch 350/469 | Cost: 13673.4648 Epoch: 028/050 | Batch 400/469 | Cost: 13597.6719 Epoch: 028/050 | Batch 450/469 | Cost: 13925.3242 Time elapsed: 4.08 min Epoch: 029/050 | Batch 000/469 | Cost: 14032.7266 Epoch: 029/050 | Batch 050/469 | Cost: 14527.6777 Epoch: 029/050 | Batch 100/469 | Cost: 14219.7266 Epoch: 029/050 | Batch 150/469 | Cost: 13933.3320 Epoch: 029/050 | Batch 200/469 | Cost: 14406.4668 Epoch: 029/050 | Batch 250/469 | Cost: 13692.3379 Epoch: 029/050 | Batch 300/469 | Cost: 13557.2705 Epoch: 029/050 | Batch 350/469 | Cost: 14528.8633 Epoch: 029/050 | Batch 400/469 | Cost: 14413.3438 Epoch: 029/050 | Batch 450/469 | Cost: 14293.6504 Time elapsed: 4.23 min Epoch: 030/050 | Batch 000/469 | Cost: 14673.0938 Epoch: 030/050 | Batch 050/469 | Cost: 14199.3184 Epoch: 030/050 | Batch 100/469 | Cost: 14027.1729 Epoch: 030/050 | Batch 150/469 | Cost: 14117.5713 Epoch: 030/050 | Batch 200/469 | Cost: 13543.0605 Epoch: 030/050 | Batch 250/469 | Cost: 14418.0820 Epoch: 030/050 | Batch 300/469 | Cost: 13932.8691 Epoch: 030/050 | Batch 350/469 | Cost: 13475.8350 Epoch: 030/050 | Batch 400/469 | Cost: 14393.7646 Epoch: 030/050 | Batch 450/469 | Cost: 14195.9902 Time elapsed: 4.37 min Epoch: 031/050 | Batch 000/469 | Cost: 13865.0762 Epoch: 031/050 | Batch 050/469 | Cost: 13816.7061 Epoch: 031/050 | Batch 100/469 | Cost: 13752.8525 Epoch: 031/050 | Batch 150/469 | Cost: 14141.7930 Epoch: 031/050 | Batch 200/469 | Cost: 14415.1172 Epoch: 031/050 | Batch 250/469 
| Cost: 13907.3770 Epoch: 031/050 | Batch 300/469 | Cost: 13910.6807 Epoch: 031/050 | Batch 350/469 | Cost: 13633.5596 Epoch: 031/050 | Batch 400/469 | Cost: 13621.3359 Epoch: 031/050 | Batch 450/469 | Cost: 13538.8291 Time elapsed: 4.52 min Epoch: 032/050 | Batch 000/469 | Cost: 14009.0742 Epoch: 032/050 | Batch 050/469 | Cost: 13491.7461 Epoch: 032/050 | Batch 100/469 | Cost: 13270.1104 Epoch: 032/050 | Batch 150/469 | Cost: 14276.8320 Epoch: 032/050 | Batch 200/469 | Cost: 13928.1875 Epoch: 032/050 | Batch 250/469 | Cost: 13973.2520 Epoch: 032/050 | Batch 300/469 | Cost: 14112.7969 Epoch: 032/050 | Batch 350/469 | Cost: 14247.1250 Epoch: 032/050 | Batch 400/469 | Cost: 14020.4355 Epoch: 032/050 | Batch 450/469 | Cost: 13671.0029 Time elapsed: 4.67 min Epoch: 033/050 | Batch 000/469 | Cost: 14114.7676 Epoch: 033/050 | Batch 050/469 | Cost: 14096.6172 Epoch: 033/050 | Batch 100/469 | Cost: 14510.5137 Epoch: 033/050 | Batch 150/469 | Cost: 14087.4746 Epoch: 033/050 | Batch 200/469 | Cost: 13874.9834 Epoch: 033/050 | Batch 250/469 | Cost: 14145.5840 Epoch: 033/050 | Batch 300/469 | Cost: 13861.3926 Epoch: 033/050 | Batch 350/469 | Cost: 14629.8486 Epoch: 033/050 | Batch 400/469 | Cost: 14538.3857 Epoch: 033/050 | Batch 450/469 | Cost: 13830.5381 Time elapsed: 4.82 min Epoch: 034/050 | Batch 000/469 | Cost: 13836.5195 Epoch: 034/050 | Batch 050/469 | Cost: 13860.2246 Epoch: 034/050 | Batch 100/469 | Cost: 14087.6016 Epoch: 034/050 | Batch 150/469 | Cost: 14019.4785 Epoch: 034/050 | Batch 200/469 | Cost: 13451.0508 Epoch: 034/050 | Batch 250/469 | Cost: 13142.4326 Epoch: 034/050 | Batch 300/469 | Cost: 14079.7734 Epoch: 034/050 | Batch 350/469 | Cost: 13413.0859 Epoch: 034/050 | Batch 400/469 | Cost: 14405.9668 Epoch: 034/050 | Batch 450/469 | Cost: 14408.2139 Time elapsed: 4.97 min Epoch: 035/050 | Batch 000/469 | Cost: 13902.5938 Epoch: 035/050 | Batch 050/469 | Cost: 13920.2412 Epoch: 035/050 | Batch 100/469 | Cost: 13912.0137 Epoch: 035/050 | Batch 150/469 | Cost: 
13720.4482 Epoch: 035/050 | Batch 200/469 | Cost: 13858.9121 Epoch: 035/050 | Batch 250/469 | Cost: 13355.0986 Epoch: 035/050 | Batch 300/469 | Cost: 13733.6855 Epoch: 035/050 | Batch 350/469 | Cost: 14387.2490 Epoch: 035/050 | Batch 400/469 | Cost: 14289.1094 Epoch: 035/050 | Batch 450/469 | Cost: 13157.4883 Time elapsed: 5.11 min Epoch: 036/050 | Batch 000/469 | Cost: 13923.4131 Epoch: 036/050 | Batch 050/469 | Cost: 13152.2998 Epoch: 036/050 | Batch 100/469 | Cost: 13996.1729 Epoch: 036/050 | Batch 150/469 | Cost: 13884.8965 Epoch: 036/050 | Batch 200/469 | Cost: 13887.7607 Epoch: 036/050 | Batch 250/469 | Cost: 13652.5996 Epoch: 036/050 | Batch 300/469 | Cost: 13951.4346 Epoch: 036/050 | Batch 350/469 | Cost: 13787.7617 Epoch: 036/050 | Batch 400/469 | Cost: 14097.5078 Epoch: 036/050 | Batch 450/469 | Cost: 13684.4854 Time elapsed: 5.26 min Epoch: 037/050 | Batch 000/469 | Cost: 14580.7109 Epoch: 037/050 | Batch 050/469 | Cost: 13706.5557 Epoch: 037/050 | Batch 100/469 | Cost: 14079.7070 Epoch: 037/050 | Batch 150/469 | Cost: 14231.3975 Epoch: 037/050 | Batch 200/469 | Cost: 13724.7275 Epoch: 037/050 | Batch 250/469 | Cost: 14127.0488 Epoch: 037/050 | Batch 300/469 | Cost: 14432.3828 Epoch: 037/050 | Batch 350/469 | Cost: 13770.9668 Epoch: 037/050 | Batch 400/469 | Cost: 14457.6172 Epoch: 037/050 | Batch 450/469 | Cost: 13425.8623 Time elapsed: 5.41 min Epoch: 038/050 | Batch 000/469 | Cost: 13763.5371 Epoch: 038/050 | Batch 050/469 | Cost: 13891.8945 Epoch: 038/050 | Batch 100/469 | Cost: 13626.1357 Epoch: 038/050 | Batch 150/469 | Cost: 14679.0449 Epoch: 038/050 | Batch 200/469 | Cost: 13221.4004 Epoch: 038/050 | Batch 250/469 | Cost: 13140.2148 Epoch: 038/050 | Batch 300/469 | Cost: 13809.6084 Epoch: 038/050 | Batch 350/469 | Cost: 13575.6592 Epoch: 038/050 | Batch 400/469 | Cost: 14249.9180 Epoch: 038/050 | Batch 450/469 | Cost: 14097.8291 Time elapsed: 5.56 min Epoch: 039/050 | Batch 000/469 | Cost: 14015.1768 Epoch: 039/050 | Batch 050/469 | Cost: 
13973.9795 Epoch: 039/050 | Batch 100/469 | Cost: 13633.8730 Epoch: 039/050 | Batch 150/469 | Cost: 14055.6895 Epoch: 039/050 | Batch 200/469 | Cost: 13871.2949 Epoch: 039/050 | Batch 250/469 | Cost: 13746.9258 Epoch: 039/050 | Batch 300/469 | Cost: 13203.3242 Epoch: 039/050 | Batch 350/469 | Cost: 13911.6846 Epoch: 039/050 | Batch 400/469 | Cost: 14241.5703 Epoch: 039/050 | Batch 450/469 | Cost: 13677.2559 Time elapsed: 5.70 min Epoch: 040/050 | Batch 000/469 | Cost: 14490.0547 Epoch: 040/050 | Batch 050/469 | Cost: 13689.6680 Epoch: 040/050 | Batch 100/469 | Cost: 14046.6895 Epoch: 040/050 | Batch 150/469 | Cost: 13632.8125 Epoch: 040/050 | Batch 200/469 | Cost: 13456.0918 Epoch: 040/050 | Batch 250/469 | Cost: 13832.4795 Epoch: 040/050 | Batch 300/469 | Cost: 13813.2939 Epoch: 040/050 | Batch 350/469 | Cost: 13484.2520 Epoch: 040/050 | Batch 400/469 | Cost: 13600.7803 Epoch: 040/050 | Batch 450/469 | Cost: 13492.7578 Time elapsed: 5.85 min Epoch: 041/050 | Batch 000/469 | Cost: 13993.5547 Epoch: 041/050 | Batch 050/469 | Cost: 13833.7031 Epoch: 041/050 | Batch 100/469 | Cost: 13798.5264 Epoch: 041/050 | Batch 150/469 | Cost: 14379.4717 Epoch: 041/050 | Batch 200/469 | Cost: 13919.1445 Epoch: 041/050 | Batch 250/469 | Cost: 13361.4160 Epoch: 041/050 | Batch 300/469 | Cost: 14154.9043 Epoch: 041/050 | Batch 350/469 | Cost: 13858.2715 Epoch: 041/050 | Batch 400/469 | Cost: 14078.7451 Epoch: 041/050 | Batch 450/469 | Cost: 13970.0488 Time elapsed: 6.00 min Epoch: 042/050 | Batch 000/469 | Cost: 14093.0371 Epoch: 042/050 | Batch 050/469 | Cost: 14073.4688 Epoch: 042/050 | Batch 100/469 | Cost: 13645.2754 Epoch: 042/050 | Batch 150/469 | Cost: 13464.0029 Epoch: 042/050 | Batch 200/469 | Cost: 13615.8643 Epoch: 042/050 | Batch 250/469 | Cost: 13301.9805 Epoch: 042/050 | Batch 300/469 | Cost: 13605.0020 Epoch: 042/050 | Batch 350/469 | Cost: 14035.0498 Epoch: 042/050 | Batch 400/469 | Cost: 13637.4297 Epoch: 042/050 | Batch 450/469 | Cost: 14165.7686 Time elapsed: 6.15 
min Epoch: 043/050 | Batch 000/469 | Cost: 13715.1055 Epoch: 043/050 | Batch 050/469 | Cost: 14122.5898 Epoch: 043/050 | Batch 100/469 | Cost: 14184.3633 Epoch: 043/050 | Batch 150/469 | Cost: 13745.1133 Epoch: 043/050 | Batch 200/469 | Cost: 13448.2559 Epoch: 043/050 | Batch 250/469 | Cost: 13323.3438 Epoch: 043/050 | Batch 300/469 | Cost: 13835.5723 Epoch: 043/050 | Batch 350/469 | Cost: 13462.5098 Epoch: 043/050 | Batch 400/469 | Cost: 14195.2227 Epoch: 043/050 | Batch 450/469 | Cost: 13253.4600 Time elapsed: 6.30 min Epoch: 044/050 | Batch 000/469 | Cost: 14028.9277 Epoch: 044/050 | Batch 050/469 | Cost: 13369.4111 Epoch: 044/050 | Batch 100/469 | Cost: 13645.9971 Epoch: 044/050 | Batch 150/469 | Cost: 13864.8613 Epoch: 044/050 | Batch 200/469 | Cost: 13508.7471 Epoch: 044/050 | Batch 250/469 | Cost: 14534.7754 Epoch: 044/050 | Batch 300/469 | Cost: 13565.7900 Epoch: 044/050 | Batch 350/469 | Cost: 13719.3438 Epoch: 044/050 | Batch 400/469 | Cost: 13678.1367 Epoch: 044/050 | Batch 450/469 | Cost: 14057.3779 Time elapsed: 6.44 min Epoch: 045/050 | Batch 000/469 | Cost: 13414.4121 Epoch: 045/050 | Batch 050/469 | Cost: 13531.8555 Epoch: 045/050 | Batch 100/469 | Cost: 13470.2266 Epoch: 045/050 | Batch 150/469 | Cost: 13866.7627 Epoch: 045/050 | Batch 200/469 | Cost: 13438.2832 Epoch: 045/050 | Batch 250/469 | Cost: 14194.3691 Epoch: 045/050 | Batch 300/469 | Cost: 14172.3320 Epoch: 045/050 | Batch 350/469 | Cost: 13798.1680 Epoch: 045/050 | Batch 400/469 | Cost: 13684.1064 Epoch: 045/050 | Batch 450/469 | Cost: 13255.7441 Time elapsed: 6.59 min Epoch: 046/050 | Batch 000/469 | Cost: 13833.5371 Epoch: 046/050 | Batch 050/469 | Cost: 13982.0898 Epoch: 046/050 | Batch 100/469 | Cost: 13699.0674 Epoch: 046/050 | Batch 150/469 | Cost: 13579.7803 Epoch: 046/050 | Batch 200/469 | Cost: 13611.3682 Epoch: 046/050 | Batch 250/469 | Cost: 14532.4092 Epoch: 046/050 | Batch 300/469 | Cost: 13690.0381 Epoch: 046/050 | Batch 350/469 | Cost: 13886.2227 Epoch: 046/050 | Batch 
400/469 | Cost: 13716.4883 Epoch: 046/050 | Batch 450/469 | Cost: 13887.5723 Time elapsed: 6.74 min Epoch: 047/050 | Batch 000/469 | Cost: 13460.0312 Epoch: 047/050 | Batch 050/469 | Cost: 13862.8320 Epoch: 047/050 | Batch 100/469 | Cost: 13045.7754 Epoch: 047/050 | Batch 150/469 | Cost: 13520.7910 Epoch: 047/050 | Batch 200/469 | Cost: 13966.8848 Epoch: 047/050 | Batch 250/469 | Cost: 14337.5615 Epoch: 047/050 | Batch 300/469 | Cost: 13835.9805 Epoch: 047/050 | Batch 350/469 | Cost: 13705.6699 Epoch: 047/050 | Batch 400/469 | Cost: 14085.0215 Epoch: 047/050 | Batch 450/469 | Cost: 13708.9961 Time elapsed: 6.89 min Epoch: 048/050 | Batch 000/469 | Cost: 13683.8477 Epoch: 048/050 | Batch 050/469 | Cost: 14290.2441 Epoch: 048/050 | Batch 100/469 | Cost: 13824.9033 Epoch: 048/050 | Batch 150/469 | Cost: 13902.4424 Epoch: 048/050 | Batch 200/469 | Cost: 13742.8066 Epoch: 048/050 | Batch 250/469 | Cost: 13804.6270 Epoch: 048/050 | Batch 300/469 | Cost: 14011.4414 Epoch: 048/050 | Batch 350/469 | Cost: 13902.3428 Epoch: 048/050 | Batch 400/469 | Cost: 13671.2607 Epoch: 048/050 | Batch 450/469 | Cost: 13533.4326 Time elapsed: 7.03 min Epoch: 049/050 | Batch 000/469 | Cost: 13808.8584 Epoch: 049/050 | Batch 050/469 | Cost: 14385.1328 Epoch: 049/050 | Batch 100/469 | Cost: 13595.7334 Epoch: 049/050 | Batch 150/469 | Cost: 13449.9658 Epoch: 049/050 | Batch 200/469 | Cost: 13782.0635 Epoch: 049/050 | Batch 250/469 | Cost: 13681.0293 Epoch: 049/050 | Batch 300/469 | Cost: 14259.2988 Epoch: 049/050 | Batch 350/469 | Cost: 13350.5176 Epoch: 049/050 | Batch 400/469 | Cost: 12788.5156 Epoch: 049/050 | Batch 450/469 | Cost: 13642.1787 Time elapsed: 7.18 min Epoch: 050/050 | Batch 000/469 | Cost: 13596.1172 Epoch: 050/050 | Batch 050/469 | Cost: 13988.5371 Epoch: 050/050 | Batch 100/469 | Cost: 14061.5742 Epoch: 050/050 | Batch 150/469 | Cost: 13996.9111 Epoch: 050/050 | Batch 200/469 | Cost: 13628.2070 Epoch: 050/050 | Batch 250/469 | Cost: 13667.3203 Epoch: 050/050 | Batch 300/469 
| Cost: 13978.5820 Epoch: 050/050 | Batch 350/469 | Cost: 13589.2910 Epoch: 050/050 | Batch 400/469 | Cost: 13307.0566 Epoch: 050/050 | Batch 450/469 | Cost: 13997.9141 Time elapsed: 7.33 min Total Training Time: 7.33 min
%matplotlib inline
import matplotlib.pyplot as plt
##########################
### VISUALIZATION
##########################

n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images,
                         sharex=True, sharey=True, figsize=(20, 2.5))
# `features` and `decoded` still hold the last batch processed by the
# training loop.
orig_images = features[:n_images]
decoded_images = decoded[:n_images, 0]

# First row of axes: input images; second row: their reconstructions.
for row_axes, batch in zip(axes, (orig_images, decoded_images)):
    for i in range(n_images):
        cpu_img = batch[i].detach().cpu()
        row_axes[i].imshow(cpu_img.view(image_width, image_width),
                           cmap='binary')
# Draw 10 random latent codes per class and decode them with the class
# label as the condition, producing one row of generated digits per class.
image_width = 28

for i in range(num_classes):

    ##########################
    ### RANDOM SAMPLE
    ##########################

    labels = torch.tensor([i]*10).to(device)
    n_images = labels.size()[0]
    # Pure inference: no gradients are needed, so skip autograd tracking
    # while sampling and decoding.
    with torch.no_grad():
        rand_features = torch.randn(n_images, num_latent).to(device)
        new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    decoded_images = new_images[:n_images, 0]

    print('Class Label %d' % i)

    for ax, img in zip(axes, decoded_images):
        cpu_img = img.detach().cpu()
        ax.imshow(cpu_img.view(image_width, image_width), cmap='binary')

    plt.show()
Class Label 0
Class Label 1
Class Label 2
Class Label 3
Class Label 4
Class Label 5
Class Label 6
Class Label 7
Class Label 8
Class Label 9
watermark -iv
numpy 1.15.4 torch 1.0.0