Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.6.8 IPython 7.2.0 torch 1.0.0
A simple convolutional conditional variational autoencoder that compresses 784-pixel (28x28) MNIST images down to a 50-dimensional latent vector representation.
This implementation concatenates the inputs with the class labels when computing the reconstruction loss, as is commonly done in non-convolutional conditional variational autoencoders. This leads to substantially poorer results compared to the implementation that does NOT concatenate the labels with the inputs to compute the reconstruction loss. For reference, see the implementation ./autoencoder-cnn-cvae_no-out-concat.ipynb
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
##########################
### SETTINGS
##########################

# Device: use the first GPU (with deterministic cuDNN kernels) when CUDA
# is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)

# Hyperparameters
random_seed, learning_rate = 0, 0.001
num_epochs, batch_size = 50, 128

# Architecture
num_classes, num_features, num_latent = 10, 784, 50
##########################
### MNIST DATASET
##########################

# transforms.ToTensor() converts the images to tensors whose
# values are scaled to the 0-1 range
train_dataset = datasets.MNIST('data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST('data',
                              train=False,
                              transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Checking the dataset: fetch a single batch and report its shapes
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
Device: cuda:0 Image batch dimensions: torch.Size([128, 1, 28, 28]) Image label dimensions: torch.Size([128])
##########################
### MODEL
##########################
def to_onehot(labels, num_classes, device):
    """Convert integer class labels into a one-hot encoded float matrix.

    Args:
        labels: Tensor of integer class indices. It is flattened so that
            every element becomes one row of the result.
        num_classes: Number of columns of the result; every label must be
            in the range [0, num_classes).
        device: Device (torch.device or string) on which the result is
            allocated.

    Returns:
        A float tensor of shape (labels.numel(), num_classes) located on
        `device`, with a 1 in the column given by each label and 0 elsewhere.
    """
    # scatter_ needs an int64 index that lives on the same device as the
    # tensor being written, so normalise both here instead of relying on
    # the caller.
    index = labels.reshape(-1, 1).to(device=device, dtype=torch.long)
    # Allocate directly on the target device (no CPU tensor + copy).
    labels_onehot = torch.zeros(index.size(0), num_classes, device=device)
    labels_onehot.scatter_(1, index, 1)
    return labels_onehot
class ConditionalVariationalAutoencoder(torch.nn.Module):
    """Convolutional conditional variational autoencoder.

    The class label conditions both halves of the model: the encoder
    receives it as `num_classes` extra constant-valued input channels, and
    the decoder receives it as a one-hot vector appended to the latent code.
    The decoder emits `1 + num_classes` channels, i.e. it reconstructs the
    image channel together with the label planes.

    The hard-coded 64*2*2 flatten size matches 28x28 inputs
    (28 -> 12 -> 5 -> 2 through the three encoder convolutions).

    Args:
        num_features: Number of input pixels. Accepted for interface
            compatibility; the layers below do not use it.
        num_latent: Dimensionality of the latent vector.
        num_classes: Number of class labels used for conditioning.
    """

    def __init__(self, num_features, num_latent, num_classes):
        super(ConditionalVariationalAutoencoder, self).__init__()

        self.num_classes = num_classes

        ###############
        # ENCODER
        ##############

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        # Input: 1 image channel + one constant plane per class
        self.enc_conv_1 = torch.nn.Conv2d(in_channels=1+self.num_classes,
                                          out_channels=16,
                                          kernel_size=(6, 6),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_2 = torch.nn.Conv2d(in_channels=16,
                                          out_channels=32,
                                          kernel_size=(4, 4),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_3 = torch.nn.Conv2d(in_channels=32,
                                          out_channels=64,
                                          kernel_size=(2, 2),
                                          stride=(2, 2),
                                          padding=0)

        self.z_mean = torch.nn.Linear(64*2*2, num_latent)
        # In the original paper (Kingma & Welling, "Auto-Encoding
        # Variational Bayes") the latent distribution is described by a
        # mean and a variance, but a linear layer could output a negative
        # variance, which would cause issues in the log later. Hence the
        # network predicts z_mean and z_log_var instead, and whenever the
        # regular variance or std_dev is needed we simply apply an
        # exponential function.
        self.z_log_var = torch.nn.Linear(64*2*2, num_latent)

        ###############
        # DECODER
        ##############

        # Input: latent vector + one-hot class label
        self.dec_linear_1 = torch.nn.Linear(num_latent+self.num_classes, 64*2*2)

        self.dec_deconv_1 = torch.nn.ConvTranspose2d(in_channels=64,
                                                     out_channels=32,
                                                     kernel_size=(2, 2),
                                                     stride=(2, 2),
                                                     padding=0)

        self.dec_deconv_2 = torch.nn.ConvTranspose2d(in_channels=32,
                                                     out_channels=16,
                                                     kernel_size=(4, 4),
                                                     stride=(3, 3),
                                                     padding=1)

        # Output: 1 image channel + one plane per class (was hard-coded
        # to 11, which is only correct for num_classes == 10)
        self.dec_deconv_3 = torch.nn.ConvTranspose2d(in_channels=16,
                                                     out_channels=1+self.num_classes,
                                                     kernel_size=(6, 6),
                                                     stride=(3, 3),
                                                     padding=4)

    def reparameterize(self, z_mu, z_log_var):
        """Draw a latent sample z = mu + eps * std with eps ~ N(0, I).

        Args:
            z_mu: Mean of the latent distribution, one row per example.
            z_log_var: Log-variance of the latent distribution, same shape.

        Returns:
            Tensor with the same shape, dtype and device as `z_mu`.
        """
        # Sample epsilon from a standard normal distribution directly on
        # the device of z_mu (no dependency on a module-level `device`)
        eps = torch.randn_like(z_mu)
        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev
        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)
        z = z_mu + eps * torch.exp(z_log_var/2.)
        return z

    def encoder(self, features, targets):
        """Encode a batch of images conditioned on their class labels.

        Args:
            features: Image batch; indexed as (batch, channel, height,
                width) below.
            targets: Integer class label for each image in the batch.

        Returns:
            Tuple (z_mean, z_log_var, encoded), where `encoded` is a sample
            drawn via `reparameterize`.
        """
        ### Add condition
        onehot_targets = to_onehot(targets.to(features.device),
                                   self.num_classes,
                                   features.device)
        onehot_targets = onehot_targets.view(-1, self.num_classes, 1, 1)

        # Broadcast every one-hot entry into a constant height x width
        # plane; expand creates a view, torch.cat below materialises it
        label_planes = onehot_targets.to(features.dtype).expand(
            -1, -1, features.size(2), features.size(3))

        x = torch.cat((features, label_planes), dim=1)

        x = self.enc_conv_1(x)
        x = F.leaky_relu(x)  # 28x28 input -> 12x12

        x = self.enc_conv_2(x)
        x = F.leaky_relu(x)  # -> 5x5

        x = self.enc_conv_3(x)
        x = F.leaky_relu(x)  # -> 2x2

        flat = x.view(-1, 64*2*2)
        z_mean = self.z_mean(flat)
        z_log_var = self.z_log_var(flat)
        encoded = self.reparameterize(z_mean, z_log_var)
        return z_mean, z_log_var, encoded

    def decoder(self, encoded, targets):
        """Decode latent vectors conditioned on class labels.

        Args:
            encoded: Latent vectors, one row of size `num_latent` per example.
            targets: Integer class label for each latent vector.

        Returns:
            Tensor with 1 + num_classes channels and values in (0, 1);
            channel 0 is the image reconstruction.
        """
        ### Add condition
        onehot_targets = to_onehot(targets.to(encoded.device),
                                   self.num_classes,
                                   encoded.device)
        encoded = torch.cat((encoded, onehot_targets.to(encoded.dtype)), dim=1)

        x = self.dec_linear_1(encoded)
        x = x.view(-1, 64, 2, 2)

        x = self.dec_deconv_1(x)
        x = F.leaky_relu(x)  # 2x2 -> 4x4

        x = self.dec_deconv_2(x)
        x = F.leaky_relu(x)  # -> 11x11

        x = self.dec_deconv_3(x)
        # NOTE(review): a leaky ReLU directly before the sigmoid squashes
        # negative pre-activations towards 0, so outputs well below 0.5
        # require very large negative logits. Kept as-is to preserve the
        # behaviour of the recorded training run; consider removing it.
        x = F.leaky_relu(x)  # -> 28x28

        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        """Run the encoder and decoder.

        Returns:
            Tuple (z_mean, z_log_var, encoded, decoded).
        """
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)

        return z_mean, z_log_var, encoded, decoded
# Seed the RNG before the layers are created so that the initial
# weights are reproducible, then place the model on the chosen device
torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(
    num_features, num_latent, num_classes).to(device)
##########################
### COST AND OPTIMIZER
##########################

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

start_time = time.time()

for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
        kl_divergence = (0.5 * (z_mean**2 +
                                torch.exp(z_log_var) - z_log_var - 1)).sum()

        ### Add condition
        # The reconstruction target is the image plus one constant plane
        # per class (all ones for the true class, zeros otherwise),
        # matching the 1 + num_classes channels produced by the decoder.
        onehot_targets = to_onehot(targets, num_classes, device)
        onehot_targets = onehot_targets.view(-1, num_classes, 1, 1)
        # expand broadcasts on the device without allocating a full
        # tensor on the CPU and copying it over on every step
        label_planes = onehot_targets.to(features.dtype).expand(
            -1, -1, features.size(2), features.size(3))
        x_con = torch.cat((features, label_planes), dim=1)

        ### Compute loss
        pixelwise_bce = F.binary_cross_entropy(decoded, x_con, reduction='sum')
        cost = kl_divergence + pixelwise_bce

        ### UPDATE MODEL PARAMETERS
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            # .item() extracts the Python float instead of formatting a tensor
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost.item()))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/050 | Batch 000/469 | Cost: 767976.6875 Epoch: 001/050 | Batch 050/469 | Cost: 596610.6250 Epoch: 001/050 | Batch 100/469 | Cost: 373679.2188 Epoch: 001/050 | Batch 150/469 | Cost: 369423.0312 Epoch: 001/050 | Batch 200/469 | Cost: 372999.6250 Epoch: 001/050 | Batch 250/469 | Cost: 369966.6562 Epoch: 001/050 | Batch 300/469 | Cost: 370309.6250 Epoch: 001/050 | Batch 350/469 | Cost: 368235.9375 Epoch: 001/050 | Batch 400/469 | Cost: 369051.5312 Epoch: 001/050 | Batch 450/469 | Cost: 369094.0000 Time elapsed: 0.16 min Epoch: 002/050 | Batch 000/469 | Cost: 368791.1875 Epoch: 002/050 | Batch 050/469 | Cost: 367759.2500 Epoch: 002/050 | Batch 100/469 | Cost: 370258.9375 Epoch: 002/050 | Batch 150/469 | Cost: 367393.0938 Epoch: 002/050 | Batch 200/469 | Cost: 367319.2812 Epoch: 002/050 | Batch 250/469 | Cost: 366416.8125 Epoch: 002/050 | Batch 300/469 | Cost: 366461.8125 Epoch: 002/050 | Batch 350/469 | Cost: 369242.8125 Epoch: 002/050 | Batch 400/469 | Cost: 364756.9688 Epoch: 002/050 | Batch 450/469 | Cost: 366672.1562 Time elapsed: 0.31 min Epoch: 003/050 | Batch 000/469 | Cost: 365496.8750 Epoch: 003/050 | Batch 050/469 | Cost: 364285.8125 Epoch: 003/050 | Batch 100/469 | Cost: 360785.4062 Epoch: 003/050 | Batch 150/469 | Cost: 359350.1562 Epoch: 003/050 | Batch 200/469 | Cost: 360445.7500 Epoch: 003/050 | Batch 250/469 | Cost: 358670.5938 Epoch: 003/050 | Batch 300/469 | Cost: 357442.0938 Epoch: 003/050 | Batch 350/469 | Cost: 355264.8438 Epoch: 003/050 | Batch 400/469 | Cost: 354287.0000 Epoch: 003/050 | Batch 450/469 | Cost: 355108.7500 Time elapsed: 0.47 min Epoch: 004/050 | Batch 000/469 | Cost: 356378.7812 Epoch: 004/050 | Batch 050/469 | Cost: 353450.8125 Epoch: 004/050 | Batch 100/469 | Cost: 355565.7812 Epoch: 004/050 | Batch 150/469 | Cost: 352293.7188 Epoch: 004/050 | Batch 200/469 | Cost: 352507.3438 Epoch: 004/050 | Batch 250/469 | Cost: 349660.2812 Epoch: 004/050 | Batch 300/469 | Cost: 325723.5312 Epoch: 004/050 | Batch 350/469 | Cost: 
321628.3438 Epoch: 004/050 | Batch 400/469 | Cost: 299635.0000 Epoch: 004/050 | Batch 450/469 | Cost: 254649.7344 Time elapsed: 0.63 min Epoch: 005/050 | Batch 000/469 | Cost: 246473.9219 Epoch: 005/050 | Batch 050/469 | Cost: 253253.3438 Epoch: 005/050 | Batch 100/469 | Cost: 229721.2812 Epoch: 005/050 | Batch 150/469 | Cost: 194743.3750 Epoch: 005/050 | Batch 200/469 | Cost: 206577.7812 Epoch: 005/050 | Batch 250/469 | Cost: 190684.9531 Epoch: 005/050 | Batch 300/469 | Cost: 205158.5781 Epoch: 005/050 | Batch 350/469 | Cost: 191114.2812 Epoch: 005/050 | Batch 400/469 | Cost: 162304.3125 Epoch: 005/050 | Batch 450/469 | Cost: 121093.5156 Time elapsed: 0.78 min Epoch: 006/050 | Batch 000/469 | Cost: 114467.9297 Epoch: 006/050 | Batch 050/469 | Cost: 96826.9844 Epoch: 006/050 | Batch 100/469 | Cost: 91189.8281 Epoch: 006/050 | Batch 150/469 | Cost: 69205.9062 Epoch: 006/050 | Batch 200/469 | Cost: 69573.7969 Epoch: 006/050 | Batch 250/469 | Cost: 58245.7578 Epoch: 006/050 | Batch 300/469 | Cost: 52482.0312 Epoch: 006/050 | Batch 350/469 | Cost: 49466.5234 Epoch: 006/050 | Batch 400/469 | Cost: 47806.3906 Epoch: 006/050 | Batch 450/469 | Cost: 47156.6328 Time elapsed: 0.94 min Epoch: 007/050 | Batch 000/469 | Cost: 46200.7422 Epoch: 007/050 | Batch 050/469 | Cost: 43992.4375 Epoch: 007/050 | Batch 100/469 | Cost: 45221.8672 Epoch: 007/050 | Batch 150/469 | Cost: 44009.6562 Epoch: 007/050 | Batch 200/469 | Cost: 42143.8555 Epoch: 007/050 | Batch 250/469 | Cost: 41778.6367 Epoch: 007/050 | Batch 300/469 | Cost: 39909.2383 Epoch: 007/050 | Batch 350/469 | Cost: 116561.6953 Epoch: 007/050 | Batch 400/469 | Cost: 40465.0156 Epoch: 007/050 | Batch 450/469 | Cost: 39080.0469 Time elapsed: 1.10 min Epoch: 008/050 | Batch 000/469 | Cost: 37986.4492 Epoch: 008/050 | Batch 050/469 | Cost: 38122.7812 Epoch: 008/050 | Batch 100/469 | Cost: 37492.6172 Epoch: 008/050 | Batch 150/469 | Cost: 36924.3398 Epoch: 008/050 | Batch 200/469 | Cost: 36508.1406 Epoch: 008/050 | Batch 250/469 
| Cost: 36034.1172 Epoch: 008/050 | Batch 300/469 | Cost: 35373.1719 Epoch: 008/050 | Batch 350/469 | Cost: 36094.9648 Epoch: 008/050 | Batch 400/469 | Cost: 35194.5547 Epoch: 008/050 | Batch 450/469 | Cost: 35357.0078 Time elapsed: 1.25 min Epoch: 009/050 | Batch 000/469 | Cost: 34139.3164 Epoch: 009/050 | Batch 050/469 | Cost: 32684.7852 Epoch: 009/050 | Batch 100/469 | Cost: 33087.8008 Epoch: 009/050 | Batch 150/469 | Cost: 34713.3516 Epoch: 009/050 | Batch 200/469 | Cost: 34058.9414 Epoch: 009/050 | Batch 250/469 | Cost: 34817.7734 Epoch: 009/050 | Batch 300/469 | Cost: 32921.8438 Epoch: 009/050 | Batch 350/469 | Cost: 32419.2344 Epoch: 009/050 | Batch 400/469 | Cost: 32464.9297 Epoch: 009/050 | Batch 450/469 | Cost: 33157.8555 Time elapsed: 1.41 min Epoch: 010/050 | Batch 000/469 | Cost: 33069.7344 Epoch: 010/050 | Batch 050/469 | Cost: 32903.2422 Epoch: 010/050 | Batch 100/469 | Cost: 33106.3438 Epoch: 010/050 | Batch 150/469 | Cost: 31759.2578 Epoch: 010/050 | Batch 200/469 | Cost: 32032.4941 Epoch: 010/050 | Batch 250/469 | Cost: 30970.5547 Epoch: 010/050 | Batch 300/469 | Cost: 32380.6895 Epoch: 010/050 | Batch 350/469 | Cost: 31950.4258 Epoch: 010/050 | Batch 400/469 | Cost: 29926.0078 Epoch: 010/050 | Batch 450/469 | Cost: 30842.2812 Time elapsed: 1.57 min Epoch: 011/050 | Batch 000/469 | Cost: 30815.4336 Epoch: 011/050 | Batch 050/469 | Cost: 33126.0547 Epoch: 011/050 | Batch 100/469 | Cost: 32505.0547 Epoch: 011/050 | Batch 150/469 | Cost: 30127.3398 Epoch: 011/050 | Batch 200/469 | Cost: 30061.9590 Epoch: 011/050 | Batch 250/469 | Cost: 30604.7930 Epoch: 011/050 | Batch 300/469 | Cost: 29668.6602 Epoch: 011/050 | Batch 350/469 | Cost: 30926.6719 Epoch: 011/050 | Batch 400/469 | Cost: 30350.8242 Epoch: 011/050 | Batch 450/469 | Cost: 30159.5273 Time elapsed: 1.73 min Epoch: 012/050 | Batch 000/469 | Cost: 30400.7031 Epoch: 012/050 | Batch 050/469 | Cost: 30232.4375 Epoch: 012/050 | Batch 100/469 | Cost: 29833.6797 Epoch: 012/050 | Batch 150/469 | Cost: 
30118.9746 Epoch: 012/050 | Batch 200/469 | Cost: 30293.8867 Epoch: 012/050 | Batch 250/469 | Cost: 29551.5000 Epoch: 012/050 | Batch 300/469 | Cost: 30271.3555 Epoch: 012/050 | Batch 350/469 | Cost: 29931.2773 Epoch: 012/050 | Batch 400/469 | Cost: 29358.0820 Epoch: 012/050 | Batch 450/469 | Cost: 30744.5586 Time elapsed: 1.88 min Epoch: 013/050 | Batch 000/469 | Cost: 29265.8965 Epoch: 013/050 | Batch 050/469 | Cost: 27530.9746 Epoch: 013/050 | Batch 100/469 | Cost: 28721.8672 Epoch: 013/050 | Batch 150/469 | Cost: 30155.9512 Epoch: 013/050 | Batch 200/469 | Cost: 29030.8125 Epoch: 013/050 | Batch 250/469 | Cost: 29201.0938 Epoch: 013/050 | Batch 300/469 | Cost: 28295.7402 Epoch: 013/050 | Batch 350/469 | Cost: 28157.3828 Epoch: 013/050 | Batch 400/469 | Cost: 28565.1133 Epoch: 013/050 | Batch 450/469 | Cost: 28998.5449 Time elapsed: 2.04 min Epoch: 014/050 | Batch 000/469 | Cost: 28962.7031 Epoch: 014/050 | Batch 050/469 | Cost: 28651.3438 Epoch: 014/050 | Batch 100/469 | Cost: 29397.2930 Epoch: 014/050 | Batch 150/469 | Cost: 29416.5078 Epoch: 014/050 | Batch 200/469 | Cost: 28977.9805 Epoch: 014/050 | Batch 250/469 | Cost: 29161.6523 Epoch: 014/050 | Batch 300/469 | Cost: 28904.8867 Epoch: 014/050 | Batch 350/469 | Cost: 26424.5078 Epoch: 014/050 | Batch 400/469 | Cost: 27135.6367 Epoch: 014/050 | Batch 450/469 | Cost: 27612.0020 Time elapsed: 2.20 min Epoch: 015/050 | Batch 000/469 | Cost: 28140.8086 Epoch: 015/050 | Batch 050/469 | Cost: 29116.8887 Epoch: 015/050 | Batch 100/469 | Cost: 28442.5781 Epoch: 015/050 | Batch 150/469 | Cost: 28238.6250 Epoch: 015/050 | Batch 200/469 | Cost: 27482.3203 Epoch: 015/050 | Batch 250/469 | Cost: 28634.3496 Epoch: 015/050 | Batch 300/469 | Cost: 26978.4004 Epoch: 015/050 | Batch 350/469 | Cost: 29071.4707 Epoch: 015/050 | Batch 400/469 | Cost: 27510.0801 Epoch: 015/050 | Batch 450/469 | Cost: 27338.2227 Time elapsed: 2.36 min Epoch: 016/050 | Batch 000/469 | Cost: 26819.6582 Epoch: 016/050 | Batch 050/469 | Cost: 
28451.7012 Epoch: 016/050 | Batch 100/469 | Cost: 30336.4238 Epoch: 016/050 | Batch 150/469 | Cost: 27339.8887 Epoch: 016/050 | Batch 200/469 | Cost: 27272.3281 Epoch: 016/050 | Batch 250/469 | Cost: 26689.3438 Epoch: 016/050 | Batch 300/469 | Cost: 27881.7383 Epoch: 016/050 | Batch 350/469 | Cost: 27540.2598 Epoch: 016/050 | Batch 400/469 | Cost: 28315.9961 Epoch: 016/050 | Batch 450/469 | Cost: 27461.7090 Time elapsed: 2.51 min Epoch: 017/050 | Batch 000/469 | Cost: 26608.0723 Epoch: 017/050 | Batch 050/469 | Cost: 28546.3730 Epoch: 017/050 | Batch 100/469 | Cost: 26661.4219 Epoch: 017/050 | Batch 150/469 | Cost: 27780.3164 Epoch: 017/050 | Batch 200/469 | Cost: 27309.1055 Epoch: 017/050 | Batch 250/469 | Cost: 27510.1016 Epoch: 017/050 | Batch 300/469 | Cost: 27377.1172 Epoch: 017/050 | Batch 350/469 | Cost: 27471.3945 Epoch: 017/050 | Batch 400/469 | Cost: 30324.5586 Epoch: 017/050 | Batch 450/469 | Cost: 73500.2812 Time elapsed: 2.67 min Epoch: 018/050 | Batch 000/469 | Cost: 42388.4688 Epoch: 018/050 | Batch 050/469 | Cost: 27849.8965 Epoch: 018/050 | Batch 100/469 | Cost: 27593.3828 Epoch: 018/050 | Batch 150/469 | Cost: 27052.1992 Epoch: 018/050 | Batch 200/469 | Cost: 27993.3262 Epoch: 018/050 | Batch 250/469 | Cost: 27613.9766 Epoch: 018/050 | Batch 300/469 | Cost: 26225.6758 Epoch: 018/050 | Batch 350/469 | Cost: 26985.5664 Epoch: 018/050 | Batch 400/469 | Cost: 27585.4297 Epoch: 018/050 | Batch 450/469 | Cost: 27168.5215 Time elapsed: 2.83 min Epoch: 019/050 | Batch 000/469 | Cost: 27300.5352 Epoch: 019/050 | Batch 050/469 | Cost: 26698.8320 Epoch: 019/050 | Batch 100/469 | Cost: 27803.5508 Epoch: 019/050 | Batch 150/469 | Cost: 27134.0371 Epoch: 019/050 | Batch 200/469 | Cost: 25812.2949 Epoch: 019/050 | Batch 250/469 | Cost: 25544.7246 Epoch: 019/050 | Batch 300/469 | Cost: 25584.1641 Epoch: 019/050 | Batch 350/469 | Cost: 27084.9492 Epoch: 019/050 | Batch 400/469 | Cost: 27120.7324 Epoch: 019/050 | Batch 450/469 | Cost: 27272.8809 Time elapsed: 2.98 
min Epoch: 020/050 | Batch 000/469 | Cost: 27053.4453 Epoch: 020/050 | Batch 050/469 | Cost: 27369.2656 Epoch: 020/050 | Batch 100/469 | Cost: 26794.1035 Epoch: 020/050 | Batch 150/469 | Cost: 26575.0957 Epoch: 020/050 | Batch 200/469 | Cost: 27841.3066 Epoch: 020/050 | Batch 250/469 | Cost: 26915.7656 Epoch: 020/050 | Batch 300/469 | Cost: 25307.2305 Epoch: 020/050 | Batch 350/469 | Cost: 26626.6621 Epoch: 020/050 | Batch 400/469 | Cost: 27154.6895 Epoch: 020/050 | Batch 450/469 | Cost: 27053.9160 Time elapsed: 3.14 min Epoch: 021/050 | Batch 000/469 | Cost: 27955.6387 Epoch: 021/050 | Batch 050/469 | Cost: 27744.2793 Epoch: 021/050 | Batch 100/469 | Cost: 26278.8652 Epoch: 021/050 | Batch 150/469 | Cost: 26567.6895 Epoch: 021/050 | Batch 200/469 | Cost: 26069.6836 Epoch: 021/050 | Batch 250/469 | Cost: 26567.5332 Epoch: 021/050 | Batch 300/469 | Cost: 26253.0352 Epoch: 021/050 | Batch 350/469 | Cost: 27332.9531 Epoch: 021/050 | Batch 400/469 | Cost: 26668.5371 Epoch: 021/050 | Batch 450/469 | Cost: 25928.0664 Time elapsed: 3.30 min Epoch: 022/050 | Batch 000/469 | Cost: 26090.3027 Epoch: 022/050 | Batch 050/469 | Cost: 26201.6992 Epoch: 022/050 | Batch 100/469 | Cost: 26070.0645 Epoch: 022/050 | Batch 150/469 | Cost: 27712.2090 Epoch: 022/050 | Batch 200/469 | Cost: 27064.6074 Epoch: 022/050 | Batch 250/469 | Cost: 25590.2988 Epoch: 022/050 | Batch 300/469 | Cost: 26056.4648 Epoch: 022/050 | Batch 350/469 | Cost: 25239.4082 Epoch: 022/050 | Batch 400/469 | Cost: 27711.3926 Epoch: 022/050 | Batch 450/469 | Cost: 26541.3496 Time elapsed: 3.46 min Epoch: 023/050 | Batch 000/469 | Cost: 26866.8457 Epoch: 023/050 | Batch 050/469 | Cost: 26516.2324 Epoch: 023/050 | Batch 100/469 | Cost: 27288.9688 Epoch: 023/050 | Batch 150/469 | Cost: 26922.4766 Epoch: 023/050 | Batch 200/469 | Cost: 26217.9844 Epoch: 023/050 | Batch 250/469 | Cost: 26235.7891 Epoch: 023/050 | Batch 300/469 | Cost: 26025.4102 Epoch: 023/050 | Batch 350/469 | Cost: 26741.8125 Epoch: 023/050 | Batch 
400/469 | Cost: 26891.1172 Epoch: 023/050 | Batch 450/469 | Cost: 25617.7754 Time elapsed: 3.61 min Epoch: 024/050 | Batch 000/469 | Cost: 26568.8223 Epoch: 024/050 | Batch 050/469 | Cost: 25969.9336 Epoch: 024/050 | Batch 100/469 | Cost: 27559.0918 Epoch: 024/050 | Batch 150/469 | Cost: 27023.5234 Epoch: 024/050 | Batch 200/469 | Cost: 25339.0430 Epoch: 024/050 | Batch 250/469 | Cost: 26641.5664 Epoch: 024/050 | Batch 300/469 | Cost: 26808.2676 Epoch: 024/050 | Batch 350/469 | Cost: 25360.6113 Epoch: 024/050 | Batch 400/469 | Cost: 25561.7832 Epoch: 024/050 | Batch 450/469 | Cost: 25617.8594 Time elapsed: 3.77 min Epoch: 025/050 | Batch 000/469 | Cost: 26287.0703 Epoch: 025/050 | Batch 050/469 | Cost: 25861.7852 Epoch: 025/050 | Batch 100/469 | Cost: 27369.2188 Epoch: 025/050 | Batch 150/469 | Cost: 26051.3398 Epoch: 025/050 | Batch 200/469 | Cost: 25579.8047 Epoch: 025/050 | Batch 250/469 | Cost: 25831.1133 Epoch: 025/050 | Batch 300/469 | Cost: 25919.6426 Epoch: 025/050 | Batch 350/469 | Cost: 26633.8320 Epoch: 025/050 | Batch 400/469 | Cost: 26041.8867 Epoch: 025/050 | Batch 450/469 | Cost: 24800.4590 Time elapsed: 3.93 min Epoch: 026/050 | Batch 000/469 | Cost: 25229.7207 Epoch: 026/050 | Batch 050/469 | Cost: 25182.7578 Epoch: 026/050 | Batch 100/469 | Cost: 25985.1465 Epoch: 026/050 | Batch 150/469 | Cost: 25212.9902 Epoch: 026/050 | Batch 200/469 | Cost: 26574.4395 Epoch: 026/050 | Batch 250/469 | Cost: 25608.6055 Epoch: 026/050 | Batch 300/469 | Cost: 103743.6641 Epoch: 026/050 | Batch 350/469 | Cost: 29513.8047 Epoch: 026/050 | Batch 400/469 | Cost: 27109.1582 Epoch: 026/050 | Batch 450/469 | Cost: 26289.1582 Time elapsed: 4.08 min Epoch: 027/050 | Batch 000/469 | Cost: 27039.6719 Epoch: 027/050 | Batch 050/469 | Cost: 26615.5312 Epoch: 027/050 | Batch 100/469 | Cost: 25575.8086 Epoch: 027/050 | Batch 150/469 | Cost: 27047.0000 Epoch: 027/050 | Batch 200/469 | Cost: 26100.5879 Epoch: 027/050 | Batch 250/469 | Cost: 26004.3633 Epoch: 027/050 | Batch 
300/469 | Cost: 27638.7969 Epoch: 027/050 | Batch 350/469 | Cost: 26723.9863 Epoch: 027/050 | Batch 400/469 | Cost: 25743.5000 Epoch: 027/050 | Batch 450/469 | Cost: 25969.0234 Time elapsed: 4.24 min Epoch: 028/050 | Batch 000/469 | Cost: 24330.8633 Epoch: 028/050 | Batch 050/469 | Cost: 25649.6270 Epoch: 028/050 | Batch 100/469 | Cost: 25380.6094 Epoch: 028/050 | Batch 150/469 | Cost: 25158.7676 Epoch: 028/050 | Batch 200/469 | Cost: 24840.2871 Epoch: 028/050 | Batch 250/469 | Cost: 25271.3105 Epoch: 028/050 | Batch 300/469 | Cost: 24288.1465 Epoch: 028/050 | Batch 350/469 | Cost: 24853.2852 Epoch: 028/050 | Batch 400/469 | Cost: 26212.7070 Epoch: 028/050 | Batch 450/469 | Cost: 26409.6934 Time elapsed: 4.40 min Epoch: 029/050 | Batch 000/469 | Cost: 26000.0156 Epoch: 029/050 | Batch 050/469 | Cost: 25050.6719 Epoch: 029/050 | Batch 100/469 | Cost: 25016.0645 Epoch: 029/050 | Batch 150/469 | Cost: 25192.9238 Epoch: 029/050 | Batch 200/469 | Cost: 25538.5840 Epoch: 029/050 | Batch 250/469 | Cost: 26724.6504 Epoch: 029/050 | Batch 300/469 | Cost: 26444.7988 Epoch: 029/050 | Batch 350/469 | Cost: 25800.6934 Epoch: 029/050 | Batch 400/469 | Cost: 24840.4258 Epoch: 029/050 | Batch 450/469 | Cost: 26121.0117 Time elapsed: 4.55 min Epoch: 030/050 | Batch 000/469 | Cost: 25889.4258 Epoch: 030/050 | Batch 050/469 | Cost: 24564.8105 Epoch: 030/050 | Batch 100/469 | Cost: 25197.3223 Epoch: 030/050 | Batch 150/469 | Cost: 25926.4180 Epoch: 030/050 | Batch 200/469 | Cost: 24823.6055 Epoch: 030/050 | Batch 250/469 | Cost: 24570.4961 Epoch: 030/050 | Batch 300/469 | Cost: 26450.1582 Epoch: 030/050 | Batch 350/469 | Cost: 25991.0820 Epoch: 030/050 | Batch 400/469 | Cost: 27038.8398 Epoch: 030/050 | Batch 450/469 | Cost: 25465.6094 Time elapsed: 4.71 min Epoch: 031/050 | Batch 000/469 | Cost: 25552.3574 Epoch: 031/050 | Batch 050/469 | Cost: 25650.0840 Epoch: 031/050 | Batch 100/469 | Cost: 25189.4062 Epoch: 031/050 | Batch 150/469 | Cost: 26423.2188 Epoch: 031/050 | Batch 200/469 
| Cost: 24858.3926 Epoch: 031/050 | Batch 250/469 | Cost: 26807.5215 Epoch: 031/050 | Batch 300/469 | Cost: 26289.6484 Epoch: 031/050 | Batch 350/469 | Cost: 26251.7109 Epoch: 031/050 | Batch 400/469 | Cost: 25341.6426 Epoch: 031/050 | Batch 450/469 | Cost: 25598.0586 Time elapsed: 4.87 min Epoch: 032/050 | Batch 000/469 | Cost: 25554.5430 Epoch: 032/050 | Batch 050/469 | Cost: 27328.4414 Epoch: 032/050 | Batch 100/469 | Cost: 25416.3203 Epoch: 032/050 | Batch 150/469 | Cost: 26040.9531 Epoch: 032/050 | Batch 200/469 | Cost: 25655.6426 Epoch: 032/050 | Batch 250/469 | Cost: 26179.0469 Epoch: 032/050 | Batch 300/469 | Cost: 25275.5391 Epoch: 032/050 | Batch 350/469 | Cost: 24778.6836 Epoch: 032/050 | Batch 400/469 | Cost: 25070.4062 Epoch: 032/050 | Batch 450/469 | Cost: 25324.0215 Time elapsed: 5.02 min Epoch: 033/050 | Batch 000/469 | Cost: 24642.3848 Epoch: 033/050 | Batch 050/469 | Cost: 24271.6816 Epoch: 033/050 | Batch 100/469 | Cost: 25492.1836 Epoch: 033/050 | Batch 150/469 | Cost: 25345.7363 Epoch: 033/050 | Batch 200/469 | Cost: 25483.8418 Epoch: 033/050 | Batch 250/469 | Cost: 25023.7363 Epoch: 033/050 | Batch 300/469 | Cost: 24663.6562 Epoch: 033/050 | Batch 350/469 | Cost: 25393.5957 Epoch: 033/050 | Batch 400/469 | Cost: 25995.7422 Epoch: 033/050 | Batch 450/469 | Cost: 25259.6289 Time elapsed: 5.18 min Epoch: 034/050 | Batch 000/469 | Cost: 22940.1719 Epoch: 034/050 | Batch 050/469 | Cost: 24678.4434 Epoch: 034/050 | Batch 100/469 | Cost: 24350.1992 Epoch: 034/050 | Batch 150/469 | Cost: 24904.4062 Epoch: 034/050 | Batch 200/469 | Cost: 26084.8457 Epoch: 034/050 | Batch 250/469 | Cost: 25387.4746 Epoch: 034/050 | Batch 300/469 | Cost: 25201.3711 Epoch: 034/050 | Batch 350/469 | Cost: 26104.6172 Epoch: 034/050 | Batch 400/469 | Cost: 25046.2305 Epoch: 034/050 | Batch 450/469 | Cost: 24662.4102 Time elapsed: 5.34 min Epoch: 035/050 | Batch 000/469 | Cost: 24960.4844 Epoch: 035/050 | Batch 050/469 | Cost: 24758.6230 Epoch: 035/050 | Batch 100/469 | Cost: 
25567.4766 Epoch: 035/050 | Batch 150/469 | Cost: 26101.2344 Epoch: 035/050 | Batch 200/469 | Cost: 37525.5625 Epoch: 035/050 | Batch 250/469 | Cost: 25834.8398 Epoch: 035/050 | Batch 300/469 | Cost: 25407.0664 Epoch: 035/050 | Batch 350/469 | Cost: 24773.7715 Epoch: 035/050 | Batch 400/469 | Cost: 24848.2070 Epoch: 035/050 | Batch 450/469 | Cost: 25452.8770 Time elapsed: 5.50 min Epoch: 036/050 | Batch 000/469 | Cost: 25342.3438 Epoch: 036/050 | Batch 050/469 | Cost: 24279.4141 Epoch: 036/050 | Batch 100/469 | Cost: 24467.3691 Epoch: 036/050 | Batch 150/469 | Cost: 24398.2832 Epoch: 036/050 | Batch 200/469 | Cost: 24147.9492 Epoch: 036/050 | Batch 250/469 | Cost: 24729.1992 Epoch: 036/050 | Batch 300/469 | Cost: 24632.2969 Epoch: 036/050 | Batch 350/469 | Cost: 24841.6211 Epoch: 036/050 | Batch 400/469 | Cost: 24582.1094 Epoch: 036/050 | Batch 450/469 | Cost: 24980.0625 Time elapsed: 5.65 min Epoch: 037/050 | Batch 000/469 | Cost: 25203.0020 Epoch: 037/050 | Batch 050/469 | Cost: 23087.3906 Epoch: 037/050 | Batch 100/469 | Cost: 24115.1836 Epoch: 037/050 | Batch 150/469 | Cost: 24795.7891 Epoch: 037/050 | Batch 200/469 | Cost: 27220.8164 Epoch: 037/050 | Batch 250/469 | Cost: 24840.1055 Epoch: 037/050 | Batch 300/469 | Cost: 25391.5098 Epoch: 037/050 | Batch 350/469 | Cost: 25561.5938 Epoch: 037/050 | Batch 400/469 | Cost: 23791.0605 Epoch: 037/050 | Batch 450/469 | Cost: 24261.7539 Time elapsed: 5.81 min Epoch: 038/050 | Batch 000/469 | Cost: 23855.5293 Epoch: 038/050 | Batch 050/469 | Cost: 25037.2031 Epoch: 038/050 | Batch 100/469 | Cost: 25081.6836 Epoch: 038/050 | Batch 150/469 | Cost: 24726.2656 Epoch: 038/050 | Batch 200/469 | Cost: 26345.6641 Epoch: 038/050 | Batch 250/469 | Cost: 24811.2539 Epoch: 038/050 | Batch 300/469 | Cost: 24353.3047 Epoch: 038/050 | Batch 350/469 | Cost: 25306.9180 Epoch: 038/050 | Batch 400/469 | Cost: 24490.6641 Epoch: 038/050 | Batch 450/469 | Cost: 25235.3613 Time elapsed: 5.97 min Epoch: 039/050 | Batch 000/469 | Cost: 
24276.7832 Epoch: 039/050 | Batch 050/469 | Cost: 24525.7070 Epoch: 039/050 | Batch 100/469 | Cost: 24906.1289 Epoch: 039/050 | Batch 150/469 | Cost: 24968.1094 Epoch: 039/050 | Batch 200/469 | Cost: 24574.9062 Epoch: 039/050 | Batch 250/469 | Cost: 24858.0703 Epoch: 039/050 | Batch 300/469 | Cost: 25797.6152 Epoch: 039/050 | Batch 350/469 | Cost: 23874.2402 Epoch: 039/050 | Batch 400/469 | Cost: 25120.7891 Epoch: 039/050 | Batch 450/469 | Cost: 23778.7520 Time elapsed: 6.13 min Epoch: 040/050 | Batch 000/469 | Cost: 24705.1719 Epoch: 040/050 | Batch 050/469 | Cost: 24627.0195 Epoch: 040/050 | Batch 100/469 | Cost: 24295.2754 Epoch: 040/050 | Batch 150/469 | Cost: 24087.8906 Epoch: 040/050 | Batch 200/469 | Cost: 25491.7715 Epoch: 040/050 | Batch 250/469 | Cost: 24501.0703 Epoch: 040/050 | Batch 300/469 | Cost: 26422.9824 Epoch: 040/050 | Batch 350/469 | Cost: 25514.8086 Epoch: 040/050 | Batch 400/469 | Cost: 25690.4043 Epoch: 040/050 | Batch 450/469 | Cost: 24029.9238 Time elapsed: 6.28 min Epoch: 041/050 | Batch 000/469 | Cost: 24140.9375 Epoch: 041/050 | Batch 050/469 | Cost: 24123.9629 Epoch: 041/050 | Batch 100/469 | Cost: 24918.5645 Epoch: 041/050 | Batch 150/469 | Cost: 24718.9492 Epoch: 041/050 | Batch 200/469 | Cost: 24464.7383 Epoch: 041/050 | Batch 250/469 | Cost: 23528.3867 Epoch: 041/050 | Batch 300/469 | Cost: 24874.0156 Epoch: 041/050 | Batch 350/469 | Cost: 24976.7266 Epoch: 041/050 | Batch 400/469 | Cost: 24297.3750 Epoch: 041/050 | Batch 450/469 | Cost: 24892.3906 Time elapsed: 6.44 min Epoch: 042/050 | Batch 000/469 | Cost: 23911.4434 Epoch: 042/050 | Batch 050/469 | Cost: 24480.0840 Epoch: 042/050 | Batch 100/469 | Cost: 24132.7617 Epoch: 042/050 | Batch 150/469 | Cost: 26151.5430 Epoch: 042/050 | Batch 200/469 | Cost: 24691.4297 Epoch: 042/050 | Batch 250/469 | Cost: 32332.8027 Epoch: 042/050 | Batch 300/469 | Cost: 26727.7930 Epoch: 042/050 | Batch 350/469 | Cost: 24738.1543 Epoch: 042/050 | Batch 400/469 | Cost: 25356.9297 Epoch: 042/050 | 
Batch 450/469 | Cost: 25567.9824 Time elapsed: 6.60 min Epoch: 043/050 | Batch 000/469 | Cost: 24626.9219 Epoch: 043/050 | Batch 050/469 | Cost: 24583.6211 Epoch: 043/050 | Batch 100/469 | Cost: 23771.1250 Epoch: 043/050 | Batch 150/469 | Cost: 23740.8223 Epoch: 043/050 | Batch 200/469 | Cost: 25974.3535 Epoch: 043/050 | Batch 250/469 | Cost: 24506.7715 Epoch: 043/050 | Batch 300/469 | Cost: 24850.9629 Epoch: 043/050 | Batch 350/469 | Cost: 23420.8887 Epoch: 043/050 | Batch 400/469 | Cost: 23890.6582 Epoch: 043/050 | Batch 450/469 | Cost: 24406.9375 Time elapsed: 6.76 min Epoch: 044/050 | Batch 000/469 | Cost: 24515.1172 Epoch: 044/050 | Batch 050/469 | Cost: 24865.0742 Epoch: 044/050 | Batch 100/469 | Cost: 24439.4609 Epoch: 044/050 | Batch 150/469 | Cost: 24490.3047 Epoch: 044/050 | Batch 200/469 | Cost: 23753.9219 Epoch: 044/050 | Batch 250/469 | Cost: 23811.8926 Epoch: 044/050 | Batch 300/469 | Cost: 24070.1172 Epoch: 044/050 | Batch 350/469 | Cost: 24404.0664 Epoch: 044/050 | Batch 400/469 | Cost: 25219.6699 Epoch: 044/050 | Batch 450/469 | Cost: 23585.7500 Time elapsed: 6.91 min Epoch: 045/050 | Batch 000/469 | Cost: 23822.3262 Epoch: 045/050 | Batch 050/469 | Cost: 23653.2695 Epoch: 045/050 | Batch 100/469 | Cost: 25814.4062 Epoch: 045/050 | Batch 150/469 | Cost: 23872.3867 Epoch: 045/050 | Batch 200/469 | Cost: 25231.3008 Epoch: 045/050 | Batch 250/469 | Cost: 24211.3652 Epoch: 045/050 | Batch 300/469 | Cost: 24564.8242 Epoch: 045/050 | Batch 350/469 | Cost: 23450.6211 Epoch: 045/050 | Batch 400/469 | Cost: 24501.6504 Epoch: 045/050 | Batch 450/469 | Cost: 26215.8633 Time elapsed: 7.07 min Epoch: 046/050 | Batch 000/469 | Cost: 24400.6562 Epoch: 046/050 | Batch 050/469 | Cost: 24448.3691 Epoch: 046/050 | Batch 100/469 | Cost: 24466.0859 Epoch: 046/050 | Batch 150/469 | Cost: 24153.8711 Epoch: 046/050 | Batch 200/469 | Cost: 24351.0098 Epoch: 046/050 | Batch 250/469 | Cost: 23123.2500 Epoch: 046/050 | Batch 300/469 | Cost: 24734.2773 Epoch: 046/050 | Batch 
350/469 | Cost: 23785.1875 Epoch: 046/050 | Batch 400/469 | Cost: 24901.5039 Epoch: 046/050 | Batch 450/469 | Cost: 23700.1133 Time elapsed: 7.22 min Epoch: 047/050 | Batch 000/469 | Cost: 25294.2520 Epoch: 047/050 | Batch 050/469 | Cost: 24074.6992 Epoch: 047/050 | Batch 100/469 | Cost: 24112.8848 Epoch: 047/050 | Batch 150/469 | Cost: 24861.8926 Epoch: 047/050 | Batch 200/469 | Cost: 22852.8594 Epoch: 047/050 | Batch 250/469 | Cost: 23799.0703 Epoch: 047/050 | Batch 300/469 | Cost: 23758.0039 Epoch: 047/050 | Batch 350/469 | Cost: 23628.5391 Epoch: 047/050 | Batch 400/469 | Cost: 23933.1504 Epoch: 047/050 | Batch 450/469 | Cost: 22900.7715 Time elapsed: 7.38 min Epoch: 048/050 | Batch 000/469 | Cost: 23949.8223 Epoch: 048/050 | Batch 050/469 | Cost: 24267.9609 Epoch: 048/050 | Batch 100/469 | Cost: 22838.5234 Epoch: 048/050 | Batch 150/469 | Cost: 24212.3223 Epoch: 048/050 | Batch 200/469 | Cost: 23809.5449 Epoch: 048/050 | Batch 250/469 | Cost: 23827.1680 Epoch: 048/050 | Batch 300/469 | Cost: 25127.4844 Epoch: 048/050 | Batch 350/469 | Cost: 23184.9473 Epoch: 048/050 | Batch 400/469 | Cost: 24065.0840 Epoch: 048/050 | Batch 450/469 | Cost: 23201.5645 Time elapsed: 7.53 min Epoch: 049/050 | Batch 000/469 | Cost: 23682.0781 Epoch: 049/050 | Batch 050/469 | Cost: 23740.3887 Epoch: 049/050 | Batch 100/469 | Cost: 23290.7441 Epoch: 049/050 | Batch 150/469 | Cost: 23001.3262 Epoch: 049/050 | Batch 200/469 | Cost: 23265.8105 Epoch: 049/050 | Batch 250/469 | Cost: 22163.1328 Epoch: 049/050 | Batch 300/469 | Cost: 24283.0508 Epoch: 049/050 | Batch 350/469 | Cost: 23822.0098 Epoch: 049/050 | Batch 400/469 | Cost: 22784.8594 Epoch: 049/050 | Batch 450/469 | Cost: 24202.4961 Time elapsed: 7.69 min Epoch: 050/050 | Batch 000/469 | Cost: 23966.5840 Epoch: 050/050 | Batch 050/469 | Cost: 24665.5449 Epoch: 050/050 | Batch 100/469 | Cost: 23895.6406 Epoch: 050/050 | Batch 150/469 | Cost: 24318.3926 Epoch: 050/050 | Batch 200/469 | Cost: 23685.9727 Epoch: 050/050 | Batch 250/469 
| Cost: 23648.9336 Epoch: 050/050 | Batch 300/469 | Cost: 23634.2500 Epoch: 050/050 | Batch 350/469 | Cost: 27888.9062 Epoch: 050/050 | Batch 400/469 | Cost: 23649.8242 Epoch: 050/050 | Batch 450/469 | Cost: 22728.7891 Time elapsed: 7.85 min Total Training Time: 7.85 min
%matplotlib inline
import matplotlib.pyplot as plt
##########################
### VISUALIZATION
##########################

# Top row: originals from the last training batch;
# bottom row: channel 0 of the corresponding reconstructions
n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2,
                         ncols=n_images,
                         sharex=True,
                         sharey=True,
                         figsize=(20, 2.5))
orig_images = features[:n_images]
decoded_images = decoded[:n_images, 0]

for row_axes, batch in zip(axes, (orig_images, decoded_images)):
    for col in range(n_images):
        picture = batch[col].detach().to(torch.device('cpu'))
        picture = picture.reshape((image_width, image_width))
        row_axes[col].imshow(picture, cmap='binary')
# Generate and plot 10 random samples for each of the class labels 0-9.
image_width = 28  # loop-invariant, set once

for i in range(10):

    ##########################
    ### RANDOM SAMPLE
    ##########################

    labels = torch.tensor([i]*10).to(device)
    n_images = labels.size()[0]
    rand_features = torch.randn(n_images, num_latent).to(device)
    # Pure inference: skip building the autograd graph
    with torch.no_grad():
        new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    # Channel 0 is the image; the remaining channels are the label planes
    decoded_images = new_images[:n_images, 0]

    print('Class Label %d' % i)

    for ax, img in zip(axes, decoded_images):
        ax.imshow(img.detach().to(torch.device('cpu')).reshape((image_width, image_width)), cmap='binary')

    plt.show()
Class Label 0
Class Label 1
Class Label 2
Class Label 3
Class Label 4
Class Label 5
Class Label 6
Class Label 7
Class Label 8
Class Label 9
%watermark -iv
numpy 1.15.4 torch 1.0.0