#!/usr/bin/env python # coding: utf-8 # Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks. # - Author: Sebastian Raschka # - GitHub Repository: https://github.com/rasbt/deeplearning-models # In[1]: get_ipython().run_line_magic('load_ext', 'watermark') get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch") # - Runs on CPU (not recommended here) or GPU (if available) # # Model Zoo -- Convolutional Autoencoder with Nearest-neighbor Interpolation (Trained on 10 categories of the Quickdraw dataset) # A convolutional autoencoder using nearest neighbor upscaling layers that compresses 768-pixel Quickdraw images down to a 7x7x8 (392 pixel) representation. # ## Imports # In[2]: import os import time import numpy as np import pandas as pd from PIL import Image import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset from torch.utils.data import DataLoader from torchvision import transforms if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True # ## Dataset # This notebook is based on Google's Quickdraw dataset (https://quickdraw.withgoogle.com). In particular we will be working with an arbitrary subset of 10 categories in png format: # # label_dict = { # "lollipop": 0, # "binoculars": 1, # "mouse": 2, # "basket": 3, # "penguin": 4, # "washing machine": 5, # "canoe": 6, # "eyeglasses": 7, # "beach": 8, # "screwdriver": 9, # } # # (The class labels 0-9 can be ignored in this notebook). # # For more details on obtaining and preparing the dataset, please see the # # - [custom-data-loader-quickdraw.ipynb](custom-data-loader-quickdraw.ipynb) # # notebook. # In[3]: df = pd.read_csv('quickdraw_png_set1_train.csv', index_col=0) df.head() main_dir = 'quickdraw-png_set1/' img = Image.open(os.path.join(main_dir, df.index[99])) img = np.asarray(img, dtype=np.uint8) print(img.shape) plt.imshow(np.array(img), cmap='binary') plt.show() # ### Create a Custom Data Loader # In[4]: class QuickdrawDataset(Dataset): """Custom Dataset for loading Quickdraw images""" def __init__(self, txt_path, img_dir, transform=None): df = pd.read_csv(txt_path, sep=",", index_col=0) self.img_dir = img_dir self.txt_path = txt_path self.img_names = df.index.values self.y = df['Label'].values self.transform = transform def __getitem__(self, index): img = Image.open(os.path.join(self.img_dir, self.img_names[index])) if self.transform is not None: img = self.transform(img) label = self.y[index] return img, label def __len__(self): return self.y.shape[0] # In[5]: # Note that transforms.ToTensor() # already divides pixels by 255. internally BATCH_SIZE = 128 custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.), transforms.ToTensor()]) train_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_train.csv', img_dir='quickdraw-png_set1/', transform=custom_transform) train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4) # In[6]: device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu") torch.manual_seed(0) num_epochs = 2 for epoch in range(num_epochs): for batch_idx, (x, y) in enumerate(train_loader): print('Epoch:', epoch+1, end='') print(' | Batch index:', batch_idx, end='') print(' | Batch size:', y.size()[0]) x = x.to(device) y = y.to(device) break # ## Settings # In[7]: ########################## ### SETTINGS ########################## # Device device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu") print('Device:', device) # Hyperparameters random_seed = 123 learning_rate = 0.0005 num_epochs = 50 # ### Model # In[8]: ########################## ### MODEL ########################## class Autoencoder(torch.nn.Module): def __init__(self): super(Autoencoder, self).__init__() # calculate same padding: # (w - k + 2*p)/s + 1 = o # => p = (s(o-1) - w + k)/2 ### ENCODER # 28x28x1 => 28x28x4 self.conv_1 = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(3, 3), stride=(1, 1), # (1(28-1) - 28 + 3) / 2 = 1 padding=1) # 28x28x4 => 14x14x4 self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), # (2(14-1) - 28 + 2) / 2 = 0 padding=0) # 14x14x4 => 14x14x8 self.conv_2 = torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(3, 3), stride=(1, 1), # (1(14-1) - 14 + 3) / 2 = 1 padding=1) # 14x14x8 => 7x7x8 self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), # (2(7-1) - 14 + 2) / 2 = 0 padding=0) ### DECODER # 7x7x8 => 14x14x8 ## interpolation # 14x14x8 => 14x14x8 self.conv_3 = torch.nn.Conv2d(in_channels=8, out_channels=4, kernel_size=(3, 3), stride=(1, 1), # (1(14-1) - 14 + 3) / 2 = 1 padding=1) # 14x14x4 => 28x28x4 ## interpolation # 28x28x4 => 28x28x1 self.conv_4 = torch.nn.Conv2d(in_channels=4, out_channels=1, kernel_size=(3, 3), stride=(1, 1), # (1(28-1) - 28 + 3) / 2 = 1 padding=1) def forward(self, x): ### ENCODER x = self.conv_1(x) x = F.leaky_relu(x) x = self.pool_1(x) x = self.conv_2(x) x = F.leaky_relu(x) x = self.pool_2(x) ### DECODER x = F.interpolate(x, scale_factor=2, mode='nearest') x = self.conv_3(x) x = F.leaky_relu(x) x = F.interpolate(x, scale_factor=2, mode='nearest') x = self.conv_4(x) x = F.leaky_relu(x) x = torch.sigmoid(x) return x torch.manual_seed(random_seed) model = Autoencoder() model = model.to(device) ########################## ### COST AND OPTIMIZER ########################## cost_fn = torch.nn.BCELoss() # torch.nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # ## Training # In[9]: ########################## ### TRAINING ########################## epoch_start = 1 torch.manual_seed(random_seed) model = Autoencoder() model = model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) ################## Load previous # the code saves the autoencoder # after each epoch so that in case # the training process gets interrupted, # we will not have to start training it # from scratch files = os.listdir() start_time = time.time() for epoch in range(epoch_start, num_epochs+1): for batch_idx, (x, y) in enumerate(train_loader): # don't need labels, only the images (features) features = x.to(device) ### FORWARD AND BACK PROP decoded = model(features) cost = F.mse_loss(decoded, features) optimizer.zero_grad() cost.backward() ### UPDATE MODEL PARAMETERS optimizer.step() ### LOGGING if not batch_idx % 500: print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' %(epoch, num_epochs, batch_idx, len(train_loader), cost)) print('Time elapsed: %.2f min' % ((time.time() - start_time)/60)) print('Total Training Time: %.2f min' % ((time.time() - start_time)/60)) # Save model if os.path.isfile('autoencoder_quickdraw-1_i_%d_%s.pt' % (epoch-1, device)): os.remove('autoencoder_quickdraw-1_i_%d_%s.pt' % (epoch-1, device)) torch.save(model.state_dict(), 'autoencoder_quickdraw-1_i_%d_%s.pt' % (epoch, device)) # ## Evaluation # In[10]: get_ipython().run_line_magic('matplotlib', 'inline') import matplotlib.pyplot as plt model = Autoencoder() model = model.to(device) model.load_state_dict(torch.load('autoencoder_quickdraw-1_i_%d_%s.pt' % (num_epochs, device))) model.eval() torch.manual_seed(random_seed) for batch_idx, (x, y) in enumerate(train_loader): features = x.to(device) decoded = model(features) break ########################## ### VISUALIZATION ########################## n_images = 5 fig, axes = plt.subplots(nrows=2, ncols=n_images, sharex=True, sharey=True, figsize=(18, 5)) orig_images = features.detach().cpu().numpy()[:n_images] orig_images = np.moveaxis(orig_images, 1, -1) decoded_images = decoded.detach().cpu().numpy()[:n_images] decoded_images = np.moveaxis(decoded_images, 1, -1) for i in range(n_images): for ax, img in zip(axes, [orig_images, decoded_images]): ax[i].axis('off') ax[i].imshow(img[i].reshape(28, 28), cmap='binary') # In[11]: get_ipython().run_line_magic('watermark', '-iv')