!pip install -Uq pip
!pip install -q torchsummary
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchsummary import summary
latent_dims = 128       # dimensionality of the latent vector z
num_epochs = 150
variational_beta = 1    # weight of the KL term in the loss (beta-VAE style)
batch_size = 256
capacity = 64           # base channel count for the conv stacks
learning_rate = 1e-4
image_channels = 3      # RGB input
def get_default_device():
"""Pick GPU if available, else CPU"""
if torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
device = get_default_device()
device
device(type='cuda')
def to_device(data, device):
    '''Move tensor(s) to the chosen device'''
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)
class DeviceDataLoader():
    '''Wrap a DataLoader so each batch is moved to `device` as it is yielded'''
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
    def __iter__(self):
        for b in self.dl:
            yield to_device(b, self.device)
    def __len__(self):
        return len(self.dl)
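# Quick sanity check (a throwaway sketch, not part of the pipeline; `_check_dl`
# and `_b` are illustrative names): wrap a tiny DataLoader over random tensors
# and confirm every batch comes out on `device`.
_check_dl = DeviceDataLoader(DataLoader(torch.randn(8, 3, 64, 64), batch_size=4), device)
for _b in _check_dl:
    assert _b.device.type == device.type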
import cv2
def visualize_reconstruction(model, batch_images, iter_num, save_dir="./images"):
    os.makedirs(save_dir, exist_ok=True)
    with torch.no_grad():
        recons_images, _, _ = model(batch_images)
        # 5x10 grid scaled to [0, 255]; cv2.imwrite expects uint8 in BGR order
        img_ = make_grid(recons_images[:50].cpu().clamp(0.0, 1.0) * 255.0, nrow=10).permute((1, 2, 0)).numpy()
        cv2.imwrite(os.path.join(save_dir, f"Image_{iter_num:04}.png"), cv2.cvtColor(img_.astype(np.uint8), cv2.COLOR_RGB2BGR))
def generate_images(model, latent_vectors, iter_num, save_dir="./generated"):
    os.makedirs(save_dir, exist_ok=True)
    with torch.no_grad():
        img_generated = model.decoder(latent_vectors)
        # 10x10 grid scaled to [0, 255]; cv2.imwrite expects uint8 in BGR order
        img_ = make_grid(img_generated[:100].cpu().clamp(0.0, 1.0) * 255.0, nrow=10).permute((1, 2, 0)).numpy()
        cv2.imwrite(os.path.join(save_dir, f"Image_{iter_num:04}.png"), cv2.cvtColor(img_.astype(np.uint8), cv2.COLOR_RGB2BGR))
dataset_path = r"../input/lfw-dataset/lfw-deepfunneled/lfw-deepfunneled"
def get_image_paths(IMG_DIR):
    all_filenames = []
    for root, _, files in os.walk(IMG_DIR):
        for file_name in files:
            all_filenames.append(os.path.join(root, file_name))
    total_files = len(all_filenames)
    print(f"Total Faces: {total_files}")
    # 90/10 split; derive val_size from train_size so the two always sum to the total
    train_size = int(total_files * 0.9)
    val_size = total_files - train_size
    train_set, val_set = random_split(all_filenames, [train_size, val_size])
    print(f"Training Set: {len(train_set)}")
    print(f"Validation Set: {len(val_set)}")
    return train_set, val_set
train_paths, validation_paths = get_image_paths(dataset_path)
Total Faces: 13233
Training Set: 11909
Validation Set: 1324
class LFW_Dataset(Dataset):
    '''
    Wrap a list of LFW image paths as a Dataset of 64x64 RGB face tensors in [0, 1].
    '''
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        img = img.resize((64, 64), Image.BICUBIC)
        img = np.asarray(img, dtype='float32') / 255.0
        if self.transform:
            img = self.transform(img)
        return img
transformation = T.ToTensor()
train_ds = LFW_Dataset(train_paths, transformation)
validation_ds = LFW_Dataset(validation_paths, transformation)
len(train_ds), len(validation_ds)
(11909, 1324)
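# Sanity-check a single sample (illustrative only; `sample` is a scratch name):
# each item should be a float tensor of shape (3, 64, 64) with values in [0, 1].
sample = train_ds[0]
print(sample.shape, sample.min().item(), sample.max().item())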
train_dl = DataLoader(train_ds, batch_size, shuffle=True,
num_workers=3, pin_memory=True)
valid_dl = DataLoader(validation_ds, batch_size*2,
num_workers=2, pin_memory=True)
# Move to device
train_loader = DeviceDataLoader(train_dl, device)
val_loader = DeviceDataLoader(valid_dl, device)
class Encoder(nn.Module):
def __init__(self):
super().__init__()
c = capacity
self.conv1 = nn.Conv2d(in_channels=image_channels, out_channels=c//2, kernel_size=4, stride=2, padding=1)
self.conv2 = nn.Conv2d(in_channels=c//2, out_channels=c, kernel_size=4, stride=2, padding=1)
self.conv3 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=4, stride=2, padding=1)
self.conv4 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1)
self.conv5 = nn.Conv2d(in_channels=c*2, out_channels=c*2, kernel_size=4, stride=2, padding=1)
self.fc_1 = nn.Linear(in_features=c*2*2*2, out_features=512)
self.fc_2 = nn.Linear(in_features=512, out_features=1024)
self.fc_3 = nn.Linear(in_features=1024, out_features=512)
self.softplus_operation = nn.Softplus()
self.fc_mu = nn.Linear(in_features=512, out_features=latent_dims)
self.fc_logvar = nn.Linear(in_features=512, out_features=latent_dims)
def forward(self, xb):
xb = F.relu(self.conv1(xb))
xb = F.relu(self.conv2(xb))
xb = F.relu(self.conv3(xb))
xb = F.relu(self.conv4(xb))
xb = F.relu(self.conv5(xb))
        xb = xb.view(xb.size(0), -1)
        xb_out = F.relu(self.fc_1(xb))
        xb = F.relu(self.fc_2(xb_out))
        # residual connection around the fc_2 -> fc_3 bottleneck
        xb = F.relu(self.fc_3(xb) + xb_out)
        x_mu = self.fc_mu(xb)
        # Softplus keeps the predicted log-variance non-negative (i.e. variance >= 1)
        x_logvar = self.softplus_operation(self.fc_logvar(xb))
        return x_mu, x_logvar
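# Shape check for the encoder (a disposable sketch; `_enc`, `_mu`, `_logvar` are
# scratch names): a batch of 64x64 RGB images should map to (batch, latent_dims)
# tensors for both mu and logvar.
_enc = Encoder()
_mu, _logvar = _enc(torch.randn(2, image_channels, 64, 64))
assert _mu.shape == _logvar.shape == (2, latent_dims)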
class Decoder(nn.Module):
def __init__(self):
super().__init__()
c = capacity
self.fc4 = nn.Linear(latent_dims, 512)
self.fc3 = nn.Linear(512, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc1 = nn.Linear(512, 512)
# self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
self.upsampler = nn.Upsample(scale_factor=2, mode='nearest')
self.conv5 = nn.Conv2d(in_channels=c*2, out_channels=c*2, kernel_size=3, stride=1,padding=1)
self.conv4 = nn.Conv2d(in_channels=c*2, out_channels=c,kernel_size=3,stride=1,padding=1)
self.conv3 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3,stride=1,padding=1)
self.conv2 = nn.Conv2d(in_channels=c, out_channels=c//2, kernel_size=3, stride=1,padding=1)
self.conv1 = nn.Conv2d(in_channels=c//2, out_channels=image_channels, kernel_size=3, stride=1,padding=1)
    def forward(self, xb):
        xb_out = F.relu(self.fc4(xb))
        xb = F.relu(self.fc3(xb_out))
        # residual connection mirroring the encoder's MLP skip
        xb = F.relu(self.fc2(xb) + xb_out)
        xb = F.relu(self.fc1(xb))
        # reshape to (batch, 2c, 2, 2), then upsample + conv back to 64x64
        xb = xb.view(xb.size(0), capacity*2, 2, 2)
        xb = self.upsampler(xb)
xb = F.relu(self.conv5(xb))
xb = self.upsampler(xb)
xb = F.relu(self.conv4(xb))
xb = self.upsampler(xb)
xb = F.relu(self.conv3(xb))
xb = self.upsampler(xb)
xb = F.relu(self.conv2(xb))
xb = self.upsampler(xb)
xb = torch.sigmoid(self.conv1(xb))
return xb
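# Mirror-image shape check for the decoder (disposable sketch; `_dec` and `_imgs`
# are scratch names): latent vectors of size latent_dims should decode back to
# (batch, 3, 64, 64) images, squashed into [0, 1] by the final sigmoid.
_dec = Decoder()
_imgs = _dec(torch.randn(2, latent_dims))
assert _imgs.shape == (2, image_channels, 64, 64)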
def vae_loss(reconstruct, og, mu, logvar):
    # reconstruction term: summed binary cross-entropy over all 3*64*64 = 12288 pixel values
    reconstruction_loss = F.binary_cross_entropy(reconstruct.view(-1, 12288), og.view(-1, 12288), reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + variational_beta * kl_divergence
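# The KL term above is the closed form of KL(N(mu, sigma^2) || N(0, I)), i.e.
# -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2). A quick numeric check
# (illustrative only; `_mu0`/`_lv0` are scratch names): mu = 0 and logvar = 0
# describe the prior itself, so the divergence should be exactly 0.
_mu0, _lv0 = torch.zeros(1, latent_dims), torch.zeros(1, latent_dims)
assert torch.isclose(-0.5 * torch.sum(1 + _lv0 - _mu0.pow(2) - _lv0.exp()), torch.tensor(0.0))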
class VariationalAutoEncoder(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
    def latent_sample(self, mu, logvar):
        # reparameterization trick: z = mu + sigma * eps, with sigma = exp(logvar / 2)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps
    def forward(self, xb):
        latent_mu, latent_logvar = self.encoder(xb)
        latent_ = self.latent_sample(latent_mu, latent_logvar)
        reconstruction = self.decoder(latent_)
        return reconstruction, latent_mu, latent_logvar
    def train_step(self, input_batch):
        input_batch_reconstruct, batch_mu, batch_logvar = self(input_batch)
        loss = vae_loss(input_batch_reconstruct, input_batch, batch_mu, batch_logvar)
        return {"loss": loss}
    def valid_step(self, val_batch):
        with torch.no_grad():
            val_batch_reconstruct, batch_mu, batch_logvar = self(val_batch)
            loss = vae_loss(val_batch_reconstruct, val_batch, batch_mu, batch_logvar)
        return {"val_loss": loss}
def get_metrics_epoch_end(self, outputs, validation=True):
if validation:
loss_ = 'val_loss'
else:
loss_ = 'loss'
batch_losses = [x[f'{loss_}'] for x in outputs]
epoch_loss = torch.stack(batch_losses).mean()
return {f'{loss_}': epoch_loss.item()}
def epoch_end(self, epoch, result):
print(f"Epoch [{epoch+1}] -> last_lr: {result['lrs'][-1]:.4f}, loss: {result['loss']:.4f}, val_loss: {result['val_loss']:.4f}")
# summary(to_device(Encoder(), device), (3, 64, 64))
# summary(to_device(Decoder(), device), (1, latent_dims))
summary(to_device(VariationalAutoEncoder(), device), (3, 64, 64))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 32, 32]           1,568
            Conv2d-2           [-1, 64, 16, 16]          32,832
            Conv2d-3             [-1, 64, 8, 8]          65,600
            Conv2d-4            [-1, 128, 4, 4]         131,200
            Conv2d-5            [-1, 128, 2, 2]         262,272
            Linear-6                  [-1, 512]         262,656
            Linear-7                 [-1, 1024]         525,312
            Linear-8                  [-1, 512]         524,800
            Linear-9                  [-1, 128]          65,664
           Linear-10                  [-1, 128]          65,664
         Softplus-11                  [-1, 128]               0
          Encoder-12     [[-1, 128], [-1, 128]]               0
           Linear-13                  [-1, 512]          66,048
           Linear-14                 [-1, 1024]         525,312
           Linear-15                  [-1, 512]         524,800
           Linear-16                  [-1, 512]         262,656
         Upsample-17            [-1, 128, 4, 4]               0
           Conv2d-18            [-1, 128, 4, 4]         147,584
         Upsample-19            [-1, 128, 8, 8]               0
           Conv2d-20             [-1, 64, 8, 8]          73,792
         Upsample-21           [-1, 64, 16, 16]               0
           Conv2d-22           [-1, 64, 16, 16]          36,928
         Upsample-23           [-1, 64, 32, 32]               0
           Conv2d-24           [-1, 32, 32, 32]          18,464
         Upsample-25           [-1, 32, 64, 64]               0
           Conv2d-26            [-1, 3, 64, 64]             867
          Decoder-27            [-1, 3, 64, 64]               0
================================================================
Total params: 3,594,019
Trainable params: 3,594,019
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 2.65
Params size (MB): 13.71
Estimated Total Size (MB): 16.41
----------------------------------------------------------------
!rm -rf ./images
!rm -rf ./generated
# grab one fixed validation batch for visualizing reconstructions across epochs
val_images = next(iter(val_loader))
# fixed latent vectors (drawn once) for generating new images across epochs
latents = torch.randn(128, latent_dims, device=device)
def evaluate(model, val_loader):
model.eval()
outputs = [model.valid_step(val_batch) for val_batch in val_loader]
return model.get_metrics_epoch_end(outputs, validation=True)
def get_lr(optimizer: object) -> float:
''' Returns current learning rate'''
for param_group in optimizer.param_groups:
return param_group['lr']
def fit(epochs, lr, model, train_loader, val_loader, optimizer=None):
    history = []
    # fall back to plain SGD when no optimizer instance is passed in
    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=0.1,
                                                    epochs=epochs,
                                                    steps_per_epoch=len(train_loader))
for epoch in range(epochs):
# Training Phase
train_history = []
lrs = []
# new image generation
generate_images(model, latents, epoch)
model.train()
for train_batch in train_loader:
info = model.train_step(train_batch)
loss = info['loss']
# contains batch loss for training phase
train_history.append(info)
loss.backward()
nn.utils.clip_grad_value_(model.parameters(), 0.1)
optimizer.step()
optimizer.zero_grad()
lrs.append(get_lr(optimizer))
scheduler.step()
train_result = model.get_metrics_epoch_end(train_history, validation=False)
val_result = evaluate(model, val_loader)
result = {**train_result, **val_result}
result['lrs'] = lrs
visualize_reconstruction(model, val_images, epoch+1)
model.epoch_end(epoch, result)
history.append(result)
return history
model = VariationalAutoEncoder()
to_device(model, device)
history = [evaluate(model, val_loader)]
history
[{'val_loss': 3792803.0}]
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-4)
history = fit(num_epochs, learning_rate, model, train_loader, val_loader, optimizer)
Epoch [1] -> last_lr: 0.0041, loss: 2110909.5000, val_loss: 3457486.0000
Epoch [2] -> last_lr: 0.0045, loss: 1983661.2500, val_loss: 3480033.0000
Epoch [3] -> last_lr: 0.0050, loss: 1980685.0000, val_loss: 3379997.0000
Epoch [4] -> last_lr: 0.0058, loss: 1995939.6250, val_loss: 3559743.5000
Epoch [5] -> last_lr: 0.0069, loss: 2001709.8750, val_loss: 3383755.0000
Epoch [6] -> last_lr: 0.0081, loss: 1960210.3750, val_loss: 3472971.0000
Epoch [7] -> last_lr: 0.0096, loss: 1970040.5000, val_loss: 3361466.5000
Epoch [8] -> last_lr: 0.0113, loss: 2031578.3750, val_loss: 3544373.5000
Epoch [9] -> last_lr: 0.0131, loss: 2044796.2500, val_loss: 3506155.0000
Epoch [10] -> last_lr: 0.0152, loss: 2006746.5000, val_loss: 3479946.7500
Epoch [11] -> last_lr: 0.0174, loss: 2034279.8750, val_loss: 3469672.7500
Epoch [12] -> last_lr: 0.0198, loss: 2007291.8750, val_loss: 3508217.5000
Epoch [13] -> last_lr: 0.0224, loss: 2023592.1250, val_loss: 3483720.5000
Epoch [14] -> last_lr: 0.0251, loss: 2004951.7500, val_loss: 3524578.7500
Epoch [15] -> last_lr: 0.0280, loss: 2026781.3750, val_loss: 3500912.0000
torch.save(model, f"./vae_{latent_dims}.pt")
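# Minimal reload sketch (an assumption-laden example: torch.save on a full
# nn.Module pickles the class by reference, so the class definitions above must
# be importable in the loading session; saving model.state_dict() instead is the
# more portable route). `reloaded` is an illustrative name.
reloaded = torch.load(f"./vae_{latent_dims}.pt", map_location=device)
reloaded.eval()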
import imageio
import glob
import IPython.display as disp
reconstruct_file = './reconstruction.gif'
filenames = sorted(glob.glob('./images/*.png'))
imgs = [np.asarray(Image.open(img)) for img in filenames]
imageio.mimsave(reconstruct_file, imgs)
with open(reconstruct_file, 'rb') as file:
    disp.display(disp.Image(file.read()))
generated_file = './generation.gif'
filenames = sorted(glob.glob('./generated/*.png'))
imgs = [np.asarray(Image.open(img)) for img in filenames]
imageio.mimsave(generated_file, imgs)
with open(generated_file, 'rb') as file:
    disp.display(disp.Image(file.read()))
model.eval()
with torch.no_grad():
    # sample latent vectors from the prior N(0, I)
    latent = torch.randn(128, latent_dims, device=device)
    # decode the latent vectors into images
    img_recon = model.decoder(latent)
fig, ax = plt.subplots(figsize=(5, 5))
plt.axis("off")
plt.imshow(make_grid(img_recon[:100].cpu().clamp(0.0, 1.0), nrow=10).permute((1, 2, 0)).numpy())
plt.show()
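# Optional follow-up (a hedged sketch, not part of the run above; `z0`, `z1`,
# `steps`, `path`, and `frames` are scratch names): linearly interpolate between
# two random latent vectors and decode the path, a common way to probe how
# smooth the learned latent space is.
with torch.no_grad():
    z0, z1 = torch.randn(2, latent_dims, device=device)
    steps = torch.linspace(0.0, 1.0, 10, device=device).unsqueeze(1)
    path = (1 - steps) * z0 + steps * z1   # (10, latent_dims)
    frames = model.decoder(path)           # (10, 3, 64, 64)
plt.figure(figsize=(10, 1))
plt.axis("off")
plt.imshow(make_grid(frames.cpu().clamp(0.0, 1.0), nrow=10).permute((1, 2, 0)).numpy())
plt.show()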