Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.7.3 IPython 7.9.0 torch 1.3.0
A convolutional autoencoder using deconvolutional layers that compresses 784-pixel (28x28) MNIST images down to a 7x7x8 (392-value) representation.
This convolutional autoencoder is trained with a continuous Jaccard distance as the reconstruction loss. That is, given two vectors x and y:

J(x, y) = 1 - [Σ_i min(x_i, y_i)] / [Σ_i max(x_i, y_i)]
import torch


def continuous_jaccard(x, y):
    """
    Implementation of the continuous version of the
    Jaccard distance:
    1 - [sum_i min(x_i, y_i)] / [sum_i max(x_i, y_i)]
    """
    c = torch.cat((x.view(-1).unsqueeze(1), y.view(-1).unsqueeze(1)), dim=1)
    numerator = torch.sum(torch.min(c, dim=1)[0])
    denominator = torch.sum(torch.max(c, dim=1)[0])
    return 1. - numerator/denominator
# Example
x = torch.tensor([7, 2, 3, 4, 5, 6]).float()
y = torch.tensor([1, 8, 9, 10, 11, 4]).float()
continuous_jaccard(x, y)
tensor(0.6275)
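As a quick check of this result (worked out here, not part of the original notebook): the element-wise minima of x and y are [1, 2, 3, 4, 5, 4], which sum to 19; the element-wise maxima are [7, 8, 9, 10, 11, 6], which sum to 51; so the distance is 1 - 19/51 = 32/51 ≈ 0.6275, matching the tensor above.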
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
##########################
### SETTINGS
##########################
# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)
# Hyperparameters
random_seed = 456
learning_rate = 0.005
num_epochs = 10
batch_size = 128
##########################
### MNIST DATASET
##########################
# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)
# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Device: cuda:0
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing... Done!
Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])
##########################
### MODEL
##########################
class ConvolutionalAutoencoder(torch.nn.Module):

    def __init__(self):
        super(ConvolutionalAutoencoder, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        ### ENCODER

        # 28x28x1 => 28x28x4
        self.conv_1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=4,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      # (1(28-1) - 28 + 3) / 2 = 1
                                      padding=1)
        # 28x28x4 => 14x14x4
        self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),
                                         stride=(2, 2),
                                         # (2(14-1) - 28 + 2) / 2 = 0
                                         padding=0)
        # 14x14x4 => 14x14x8
        self.conv_2 = torch.nn.Conv2d(in_channels=4,
                                      out_channels=8,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      # (1(14-1) - 14 + 3) / 2 = 1
                                      padding=1)
        # 14x14x8 => 7x7x8
        self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),
                                         stride=(2, 2),
                                         # (2(7-1) - 14 + 2) / 2 = 0
                                         padding=0)

        ### DECODER

        # 7x7x8 => 15x15x4
        self.deconv_1 = torch.nn.ConvTranspose2d(in_channels=8,
                                                 out_channels=4,
                                                 kernel_size=(3, 3),
                                                 stride=(2, 2),
                                                 padding=0)
        # 15x15x4 => 31x31x1
        self.deconv_2 = torch.nn.ConvTranspose2d(in_channels=4,
                                                 out_channels=1,
                                                 kernel_size=(3, 3),
                                                 stride=(2, 2),
                                                 padding=0)

    def forward(self, x):

        ### ENCODER
        x = self.conv_1(x)
        x = F.leaky_relu(x)
        x = self.pool_1(x)
        x = self.conv_2(x)
        x = F.leaky_relu(x)
        x = self.pool_2(x)

        ### DECODER
        x = self.deconv_1(x)
        x = F.leaky_relu(x)
        x = self.deconv_2(x)
        x = F.leaky_relu(x)
        # the transposed convolutions overshoot to 31x31;
        # crop the central 28x28 region to match the input size
        logits = x[:, :, 2:30, 2:30]
        probas = torch.sigmoid(logits)
        return logits, probas
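To see where the 7x7x8 (392-value) bottleneck and the 28x28 output crop come from, here is a minimal shape check on an untrained instance. This snippet is an illustration added here (not part of the original notebook); tmp_model and dummy are made-up names.

# Hypothetical shape check (illustration only)
tmp_model = ConvolutionalAutoencoder()
dummy = torch.zeros(1, 1, 28, 28)  # one fake MNIST-sized image
with torch.no_grad():
    # encoder path: 28x28x1 -> 28x28x4 -> 14x14x4 -> 14x14x8 -> 7x7x8
    code = tmp_model.pool_2(F.leaky_relu(tmp_model.conv_2(
        tmp_model.pool_1(F.leaky_relu(tmp_model.conv_1(dummy))))))
    # full forward pass: the decoder overshoots to 31x31, then crops to 28x28
    logits, probas = tmp_model(dummy)
print(code.shape)    # torch.Size([1, 8, 7, 7])
print(probas.shape)  # torch.Size([1, 1, 28, 28])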
torch.manual_seed(random_seed)
model = ConvolutionalAutoencoder()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
start_time = time.time()
for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):

        # don't need labels, only the images (features)
        features = features.to(device)

        ### FORWARD AND BACK PROP
        logits, decoded = model(features)
        #cost = F.binary_cross_entropy_with_logits(logits, features)
        cost = continuous_jaccard(features, decoded)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                   %(epoch+1, num_epochs, batch_idx,
                     len(train_dataset)//batch_size, cost))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/010 | Batch 000/468 | Cost: 0.8663
Epoch: 001/010 | Batch 050/468 | Cost: 0.8086
Epoch: 001/010 | Batch 100/468 | Cost: 0.7729
Epoch: 001/010 | Batch 150/468 | Cost: 0.7322
Epoch: 001/010 | Batch 200/468 | Cost: 0.3983
Epoch: 001/010 | Batch 250/468 | Cost: 0.2963
Epoch: 001/010 | Batch 300/468 | Cost: 0.2927
Epoch: 001/010 | Batch 350/468 | Cost: 0.2783
Epoch: 001/010 | Batch 400/468 | Cost: 0.2780
Epoch: 001/010 | Batch 450/468 | Cost: 0.2609
Time elapsed: 0.14 min
Epoch: 002/010 | Batch 000/468 | Cost: 0.2694
Epoch: 002/010 | Batch 050/468 | Cost: 0.2671
Epoch: 002/010 | Batch 100/468 | Cost: 0.2444
Epoch: 002/010 | Batch 150/468 | Cost: 0.2378
Epoch: 002/010 | Batch 200/468 | Cost: 0.2540
Epoch: 002/010 | Batch 250/468 | Cost: 0.2515
Epoch: 002/010 | Batch 300/468 | Cost: 0.2393
Epoch: 002/010 | Batch 350/468 | Cost: 0.2528
Epoch: 002/010 | Batch 400/468 | Cost: 0.2283
Epoch: 002/010 | Batch 450/468 | Cost: 0.2420
Time elapsed: 0.27 min
Epoch: 003/010 | Batch 000/468 | Cost: 0.2317
Epoch: 003/010 | Batch 050/468 | Cost: 0.2274
Epoch: 003/010 | Batch 100/468 | Cost: 0.2489
Epoch: 003/010 | Batch 150/468 | Cost: 0.2246
Epoch: 003/010 | Batch 200/468 | Cost: 0.2178
Epoch: 003/010 | Batch 250/468 | Cost: 0.2200
Epoch: 003/010 | Batch 300/468 | Cost: 0.2200
Epoch: 003/010 | Batch 350/468 | Cost: 0.2309
Epoch: 003/010 | Batch 400/468 | Cost: 0.2215
Epoch: 003/010 | Batch 450/468 | Cost: 0.2218
Time elapsed: 0.40 min
Epoch: 004/010 | Batch 000/468 | Cost: 0.2124
Epoch: 004/010 | Batch 050/468 | Cost: 0.2191
Epoch: 004/010 | Batch 100/468 | Cost: 0.2121
Epoch: 004/010 | Batch 150/468 | Cost: 0.2184
Epoch: 004/010 | Batch 200/468 | Cost: 0.2118
Epoch: 004/010 | Batch 250/468 | Cost: 0.2090
Epoch: 004/010 | Batch 300/468 | Cost: 0.2114
Epoch: 004/010 | Batch 350/468 | Cost: 0.2150
Epoch: 004/010 | Batch 400/468 | Cost: 0.2218
Epoch: 004/010 | Batch 450/468 | Cost: 0.2015
Time elapsed: 0.53 min
Epoch: 005/010 | Batch 000/468 | Cost: 0.1985
Epoch: 005/010 | Batch 050/468 | Cost: 0.2053
Epoch: 005/010 | Batch 100/468 | Cost: 0.2067
Epoch: 005/010 | Batch 150/468 | Cost: 0.2003
Epoch: 005/010 | Batch 200/468 | Cost: 0.2004
Epoch: 005/010 | Batch 250/468 | Cost: 0.2076
Epoch: 005/010 | Batch 300/468 | Cost: 0.2006
Epoch: 005/010 | Batch 350/468 | Cost: 0.2162
Epoch: 005/010 | Batch 400/468 | Cost: 0.2137
Epoch: 005/010 | Batch 450/468 | Cost: 0.2077
Time elapsed: 0.67 min
Epoch: 006/010 | Batch 000/468 | Cost: 0.1986
Epoch: 006/010 | Batch 050/468 | Cost: 0.2048
Epoch: 006/010 | Batch 100/468 | Cost: 0.2063
Epoch: 006/010 | Batch 150/468 | Cost: 0.2069
Epoch: 006/010 | Batch 200/468 | Cost: 0.2092
Epoch: 006/010 | Batch 250/468 | Cost: 0.1947
Epoch: 006/010 | Batch 300/468 | Cost: 0.2006
Epoch: 006/010 | Batch 350/468 | Cost: 0.1927
Epoch: 006/010 | Batch 400/468 | Cost: 0.2018
Epoch: 006/010 | Batch 450/468 | Cost: 0.1964
Time elapsed: 0.79 min
Epoch: 007/010 | Batch 000/468 | Cost: 0.1809
Epoch: 007/010 | Batch 050/468 | Cost: 0.1996
Epoch: 007/010 | Batch 100/468 | Cost: 0.1942
Epoch: 007/010 | Batch 150/468 | Cost: 0.1909
Epoch: 007/010 | Batch 200/468 | Cost: 0.1894
Epoch: 007/010 | Batch 250/468 | Cost: 0.1937
Epoch: 007/010 | Batch 300/468 | Cost: 0.1956
Epoch: 007/010 | Batch 350/468 | Cost: 0.1938
Epoch: 007/010 | Batch 400/468 | Cost: 0.1963
Epoch: 007/010 | Batch 450/468 | Cost: 0.2060
Time elapsed: 0.92 min
Epoch: 008/010 | Batch 000/468 | Cost: 0.1947
Epoch: 008/010 | Batch 050/468 | Cost: 0.2044
Epoch: 008/010 | Batch 100/468 | Cost: 0.1811
Epoch: 008/010 | Batch 150/468 | Cost: 0.1980
Epoch: 008/010 | Batch 200/468 | Cost: 0.1794
Epoch: 008/010 | Batch 250/468 | Cost: 0.2008
Epoch: 008/010 | Batch 300/468 | Cost: 0.1949
Epoch: 008/010 | Batch 350/468 | Cost: 0.1843
Epoch: 008/010 | Batch 400/468 | Cost: 0.1942
Epoch: 008/010 | Batch 450/468 | Cost: 0.1932
Time elapsed: 1.05 min
Epoch: 009/010 | Batch 000/468 | Cost: 0.1901
Epoch: 009/010 | Batch 050/468 | Cost: 0.1894
Epoch: 009/010 | Batch 100/468 | Cost: 0.1976
Epoch: 009/010 | Batch 150/468 | Cost: 0.1935
Epoch: 009/010 | Batch 200/468 | Cost: 0.1949
Epoch: 009/010 | Batch 250/468 | Cost: 0.1921
Epoch: 009/010 | Batch 300/468 | Cost: 0.1917
Epoch: 009/010 | Batch 350/468 | Cost: 0.1900
Epoch: 009/010 | Batch 400/468 | Cost: 0.1913
Epoch: 009/010 | Batch 450/468 | Cost: 0.1815
Time elapsed: 1.19 min
Epoch: 010/010 | Batch 000/468 | Cost: 0.1845
Epoch: 010/010 | Batch 050/468 | Cost: 0.1910
Epoch: 010/010 | Batch 100/468 | Cost: 0.1929
Epoch: 010/010 | Batch 150/468 | Cost: 0.1919
Epoch: 010/010 | Batch 200/468 | Cost: 0.1822
Epoch: 010/010 | Batch 250/468 | Cost: 0.1974
Epoch: 010/010 | Batch 300/468 | Cost: 0.1919
Epoch: 010/010 | Batch 350/468 | Cost: 0.1750
Epoch: 010/010 | Batch 400/468 | Cost: 0.1879
Epoch: 010/010 | Batch 450/468 | Cost: 0.1785
Time elapsed: 1.32 min
Total Training Time: 1.32 min
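Note that the test_loader defined above is not used during training. A minimal evaluation sketch (added here as an illustration, not part of the original notebook) could average the same continuous Jaccard cost over the test set to get a held-out number to compare runs against:

# Hypothetical evaluation pass (illustration only):
# average the continuous Jaccard reconstruction cost over the test set.
model.eval()
with torch.no_grad():
    total_cost, num_batches = 0., 0
    for test_features, _ in test_loader:
        test_features = test_features.to(device)
        _, test_decoded = model(test_features)
        total_cost += continuous_jaccard(test_features, test_decoded).item()
        num_batches += 1
print('Average test cost: %.4f' % (total_cost / num_batches))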
%matplotlib inline
import matplotlib.pyplot as plt
##########################
### VISUALIZATION
##########################
n_images = 15
image_width = 28
fig, axes = plt.subplots(nrows=2, ncols=n_images,
                         sharex=True, sharey=True, figsize=(20, 2.5))
orig_images = features[:n_images]
decoded_images = decoded[:n_images]

for i in range(n_images):
    for ax, img in zip(axes, [orig_images, decoded_images]):
        curr_img = img[i].detach().to(torch.device('cpu'))
        ax[i].imshow(curr_img.view((image_width, image_width)), cmap='binary')
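In the resulting figure (not reproduced here), the top row shows the original images from the last training batch and the bottom row shows the corresponding reconstructions, since the first axes row is paired with orig_images and the second with decoded_images in the loop above.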
%watermark -iv
torch 1.3.0 matplotlib 3.1.0 torchvision 0.4.1a0+d94043a numpy 1.17.2