Implement an image super-resolution technique in TensorFlow
Single image super-resolution (SR) is a classical computer vision problem that aims at recovering a high-resolution image from a lower resolution one. Extensive research has been conducted in this area, and with the advance of Deep Learning great results have been achieved.
In this post, we will examine one of the Deep Learning approaches to super-resolution, called Super-Resolution Convolutional Neural Network (SRCNN). This technique works end to end by extracting patches from the low resolution image and passing them through convolutional layers to finally map them to higher resolution output pixels, as depicted in the diagram below.
We will implement the SRCNN model in TensorFlow, train it and then test it on a low resolution image.
As an image dataset for training the model, we will use a Kaggle-hosted dataset called Dog and Cat Detection. We will use the Kaggle CLI to download it, so you need your Kaggle API key; alternatively, you can download the dataset manually from the website.
%%capture
%%bash
pip install -q kaggle
mkdir -p ~/.kaggle
echo '{"username":"KAGGLE_USERNAME","key":"KAGGLE_KEY"}' > ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json
%%capture
%%bash
kaggle datasets download andrewmvd/dog-and-cat-detection
unzip -q dog-and-cat-detection.zip
To save model checkpoints and other artifacts, we will mount Google Drive to this Colab container
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Import dependencies
import os
import pathlib
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import *
Set a random seed for reproducibility
SEED = 31
np.random.seed(SEED)
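Note that this seeds only NumPy's random generator, which we use below for sampling images and patches. If you also want TensorFlow's weight initialization to be repeatable between runs, you can seed TensorFlow as well; a minimal optional addition:
tf.random.set_seed(SEED)  # optional: also seed TensorFlow's global generator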
We need a function to resize images by a scale factor; this will be used later in the process to generate low resolution images from a given image
def resize_image(image_array, factor):
    # Convert the array back to a PIL image so we can use its resize method
    original_image = Image.fromarray(image_array)
    # PIL stores size as (width, height); scale both by the factor
    new_size = np.array(original_image.size) * factor
    new_size = new_size.astype(np.int32)
    new_size = tuple(new_size)
    resized = original_image.resize(new_size)
    resized = img_to_array(resized)
    resized = resized.astype(np.uint8)
    return resized
This function uses the resizing helper to generate a low resolution image by downsizing and then upsizing:
def downsize_upsize_image(image, scale):
    scaled = resize_image(image, 1.0 / scale)
    scaled = resize_image(scaled, scale)
    return scaled
When extracting patches, we will slide a window over the original image; for the window to fit evenly, we first need to crop the image with the following function
def tight_crop_image(image, scale):
    height, width = image.shape[:2]
    # Drop the remainder pixels so both dimensions are divisible by the scale
    width -= int(width % scale)
    height -= int(height % scale)
    return image[:height, :width]
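As a quick check of these two helpers: downsizing and then upsizing only restores the original dimensions when they are divisible by the scale, which is exactly what the tight crop guarantees. A minimal sketch (the file name is a placeholder for any image on disk):
# Hypothetical example; 'example.png' is a placeholder path
img = img_to_array(load_img('example.png')).astype(np.uint8)
img = tight_crop_image(img, 2.0)          # dimensions now divisible by 2
blurry = downsize_upsize_image(img, 2.0)  # low resolution version
print(img.shape, blurry.shape)            # shapes match exactly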
The following function extracts a patch with a sliding window from an input image. The INPUT_DIM parameter is the height and width of the patches the network expects as input
def crop_input(image, x, y):
    y_slice = slice(y, y + INPUT_DIM)
    x_slice = slice(x, x + INPUT_DIM)
    return image[y_slice, x_slice]
Similarly, we need to crop patches from the output images. LABEL_SIZE is the height and width of the network's output, and PAD offsets the crop so that each output patch is centered within its corresponding input patch
def crop_output(image, x, y):
    y_slice = slice(y + PAD, y + PAD + LABEL_SIZE)
    x_slice = slice(x + PAD, x + PAD + LABEL_SIZE)
    return image[y_slice, x_slice]
Now let's read all image paths
file_pattern = str(pathlib.Path('/content') / 'images' / '*.png')
dataset_paths = [*glob(file_pattern)]
We don't need the entire dataset, as training on all of it would take too long; instead, we will sample around 1,000 images from it
SUBSET_SIZE = 1000
dataset_paths = np.random.choice(dataset_paths, SUBSET_SIZE, replace=False)  # sample without duplicates
Here is an example image from the dataset
path = np.random.choice(dataset_paths)
img = plt.imread(path)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7fb5cde796d0>
Here we define some parameters: the scale for resizing, the input and output patch sizes, the amount of padding that needs to be added to output patches, and the stride, which is the number of pixels we slide in both the horizontal and vertical axes to extract patches.
SCALE = 2.0
INPUT_DIM = 33
LABEL_SIZE = 21
PAD = int((INPUT_DIM - LABEL_SIZE) / 2.0)
STRIDE = 14
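With these values, PAD works out to int((33 - 21) / 2) = 6, so for a window anchored at (x, y) the network sees a 33x33 low resolution patch while its training target is the 21x21 center of the same region in the ground-truth image. A small sketch to confirm the geometry (assuming image is a tight-cropped image and scaled is its downsized-upsized version, as produced in the loop below):
# Hypothetical sanity check of the patch geometry
patch = crop_input(scaled, 0, 0)      # low resolution input patch
label = crop_output(image, 0, 0)      # centered high resolution target
print(PAD, patch.shape, label.shape)  # 6 (33, 33, 3) (21, 21, 3)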
Now, let's build the dataset by reading the input images, generating a low resolution version of each, and sliding a window over both the low resolution image and the original image to generate training patches. We will save the patches to disk, and later build a training data generator that loads them from disk in batches.
%%bash
mkdir -p data
mkdir -p training
for image_path in tqdm(dataset_paths):
    filename = pathlib.Path(image_path).stem
    image = load_img(image_path)
    image = img_to_array(image)
    image = image.astype(np.uint8)
    image = tight_crop_image(image, SCALE)
    scaled = downsize_upsize_image(image, SCALE)
    height, width = image.shape[:2]
    for y in range(0, height - INPUT_DIM + 1, STRIDE):
        for x in range(0, width - INPUT_DIM + 1, STRIDE):
            crop = crop_input(scaled, x, y)
            target = crop_output(image, x, y)
            # np.save appends '.npy', so the files end up as *_input.np.npy
            np.save(f'data/{filename}_{x}_{y}_input.np', crop)
            np.save(f'data/{filename}_{x}_{y}_output.np', target)
100%|██████████| 1500/1500 [18:00<00:00, 1.39it/s]
We cannot hold all the patches in memory, hence we saved them to disk in the previous step. Now we need a dataset loader that loads each patch and its label and feeds them to the network in batches during training. This is achieved with the PatchesDataset class, a subclass of tf.keras.utils.Sequence, Keras' interface for batched data generators.
class PatchesDataset(tf.keras.utils.Sequence):
    def __init__(self, batch_size, *args, **kwargs):
        self.batch_size = batch_size
        # Pair input and output patches by sorting both file lists
        self.input = [*glob('data/*_input.np.npy')]
        self.output = [*glob('data/*_output.np.npy')]
        self.input.sort()
        self.output.sort()
        self.total_data = len(self.input)

    def __len__(self):
        # returns the number of batches
        return int(self.total_data / self.batch_size)

    def __getitem__(self, index):
        # returns one batch, sampled uniformly at random
        indices = self.random_indices()
        input = np.array([np.load(self.input[idx]) for idx in indices])
        output = np.array([np.load(self.output[idx]) for idx in indices])
        return input, output

    def random_indices(self):
        # uniform sampling (np.random.choice is uniform by default)
        return np.random.choice(self.total_data, self.batch_size)
Define a batch size based on how much memory is available on your GPU, and create an instance of the dataset generator.
BATCH_SIZE = 1024
train_ds = PatchesDataset(BATCH_SIZE)
len(train_ds)
888
You can see the shape of the training batches
input, output = train_ds[0]
input.shape, output.shape
((1024, 33, 33, 3), (1024, 21, 21, 3))
The architecture of the SRCNN model is very simple: it consists only of convolutional layers. The first extracts features from the low resolution input, the second (a 1x1 convolution) performs a non-linear mapping between feature spaces, and the last reconstructs the high resolution output. The following helper function creates an instance of the model.
def create_model(height, width, depth):
    input = Input(shape=(height, width, depth))
    # Feature extraction: 9x9 convolution over the low resolution patch
    x = Conv2D(filters=64, kernel_size=(9, 9), kernel_initializer='he_normal')(input)
    x = ReLU()(x)
    # Non-linear mapping: 1x1 convolution between feature spaces
    x = Conv2D(filters=32, kernel_size=(1, 1), kernel_initializer='he_normal')(x)
    x = ReLU()(x)
    # Reconstruction: 5x5 convolution producing the high resolution patch
    output = Conv2D(filters=depth, kernel_size=(5, 5), kernel_initializer='he_normal')(x)
    return Model(input, output)
To train the network we will use Adam as the optimizer, with learning rate decay. Since we are training the network for a regression problem (we want to predict high resolution pixel values), we pick MSE as the loss function; this will make the model learn filters that correctly map patches from low to high resolution.
EPOCHS = 12
optimizer = Adam(learning_rate=1e-3, decay=1e-3 / EPOCHS)
model = create_model(INPUT_DIM, INPUT_DIM, 3)
model.compile(loss='mse', optimizer=optimizer)
You can see that the model is small, yet astonishingly it will be able to achieve great results once trained for long enough. We will train it for 12 epochs.
model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 33, 33, 3)] 0 conv2d (Conv2D) (None, 25, 25, 64) 15616 re_lu (ReLU) (None, 25, 25, 64) 0 conv2d_1 (Conv2D) (None, 25, 25, 32) 2080 re_lu_1 (ReLU) (None, 25, 25, 32) 0 conv2d_2 (Conv2D) (None, 21, 21, 3) 2403 ================================================================= Total params: 20,099 Trainable params: 20,099 Non-trainable params: 0 _________________________________________________________________
tf.keras.utils.plot_model(model, show_shapes = True, rankdir='LR')
Create a callback that saves the model's weights
checkpoint_path = "training/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1)
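If the Colab runtime restarts mid-training, you can rebuild the model and restore the latest saved weights from this checkpoint before resuming; for example:
# Rebuild the model, then restore the saved weights (assumes the checkpoint files exist)
model = create_model(INPUT_DIM, INPUT_DIM, 3)
model.compile(loss='mse', optimizer=Adam(learning_rate=1e-3))
model.load_weights(checkpoint_path)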
Now finally, we can train the network
model.fit(train_ds, epochs=EPOCHS, callbacks=[cp_callback])
Epoch 1/12
888/888 [==============================] - ETA: 0s - loss: 258.0937
Epoch 00001: saving model to training/cp.ckpt
888/888 [==============================] - 1735s 2s/step - loss: 258.0937
Epoch 2/12
888/888 [==============================] - ETA: 0s - loss: 105.9775
Epoch 00002: saving model to training/cp.ckpt
888/888 [==============================] - 1428s 2s/step - loss: 105.9775
Epoch 3/12
888/888 [==============================] - ETA: 0s - loss: 102.4195
Epoch 00003: saving model to training/cp.ckpt
888/888 [==============================] - 1364s 2s/step - loss: 102.4195
Epoch 4/12
888/888 [==============================] - ETA: 0s - loss: 98.4859
Epoch 00004: saving model to training/cp.ckpt
888/888 [==============================] - 1347s 2s/step - loss: 98.4859
Epoch 5/12
888/888 [==============================] - ETA: 0s - loss: 97.5308
Epoch 00005: saving model to training/cp.ckpt
888/888 [==============================] - 1352s 2s/step - loss: 97.5308
Epoch 6/12
888/888 [==============================] - ETA: 0s - loss: 96.0889
Epoch 00006: saving model to training/cp.ckpt
888/888 [==============================] - 1347s 2s/step - loss: 96.0889
Epoch 7/12
888/888 [==============================] - ETA: 0s - loss: 94.7550
Epoch 00007: saving model to training/cp.ckpt
888/888 [==============================] - 1355s 2s/step - loss: 94.7550
Epoch 8/12
888/888 [==============================] - ETA: 0s - loss: 93.3618
Epoch 00008: saving model to training/cp.ckpt
888/888 [==============================] - 1332s 1s/step - loss: 93.3618
Epoch 9/12
888/888 [==============================] - ETA: 0s - loss: 93.5235
Epoch 00009: saving model to training/cp.ckpt
888/888 [==============================] - 1346s 2s/step - loss: 93.5235
Epoch 10/12
888/888 [==============================] - ETA: 0s - loss: 92.4781
Epoch 00010: saving model to training/cp.ckpt
888/888 [==============================] - 1356s 2s/step - loss: 92.4781
Epoch 11/12
888/888 [==============================] - ETA: 0s - loss: 91.5945
Epoch 00011: saving model to training/cp.ckpt
888/888 [==============================] - 1348s 2s/step - loss: 91.5945
Epoch 12/12
888/888 [==============================] - ETA: 0s - loss: 91.0127
Epoch 00012: saving model to training/cp.ckpt
888/888 [==============================] - 1336s 2s/step - loss: 91.0127
<keras.callbacks.History at 0x7fb5c9edecd0>
Make sure the super_resolution folder exists in Google Drive
%%bash
mkdir -p /content/drive/MyDrive/super_resolution
cp -r training/* /content/drive/MyDrive/super_resolution
Save and load the model
path = '/content/drive/MyDrive/super_resolution/model.h5'
model.save(path)
new_model = tf.keras.models.load_model(path)
After training the model for enough time, we can evaluate it. Let's pick a random image from the dataset (or you can use any other image) and transform it into a low resolution image that we can pass to the SRCNN model.
path = np.random.choice(dataset_paths)
image = load_img(path)
image = img_to_array(image)
image = image.astype(np.uint8)
image = tight_crop_image(image, SCALE)
scaled = downsize_upsize_image(image, SCALE)
We need a placeholder where we will put the output patches to create the final image
output = np.zeros(scaled.shape)
height, width = output.shape[:2]
Now we extract patches from the input image, pass them through the trained model to generate high resolution patches, and then put each patch in the right position in the placeholder. After processing every patch from the input image, we will have the final output image.
for y in range(0, height - INPUT_DIM + 1, LABEL_SIZE):
    for x in range(0, width - INPUT_DIM + 1, LABEL_SIZE):
        crop = crop_input(scaled, x, y)
        # Add a batch dimension so the model can process the single patch
        image_batch = np.expand_dims(crop, axis=0)
        prediction = model.predict(image_batch)
        new_shape = (LABEL_SIZE, LABEL_SIZE, 3)
        prediction = prediction.reshape(new_shape)
        # Place the predicted patch at its (padded) position in the output
        output_y_slice = slice(y + PAD, y + PAD + LABEL_SIZE)
        output_x_slice = slice(x + PAD, x + PAD + LABEL_SIZE)
        output[output_y_slice, output_x_slice] = prediction
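Calling model.predict once per patch is slow, since each call carries per-batch overhead. As a design alternative, you could collect all input patches first and run a single forward pass; a minimal sketch (the coords and patches names are introduced here for illustration):
# Hypothetical batched variant of the loop above
coords, patches = [], []
for y in range(0, height - INPUT_DIM + 1, LABEL_SIZE):
    for x in range(0, width - INPUT_DIM + 1, LABEL_SIZE):
        coords.append((x, y))
        patches.append(crop_input(scaled, x, y))
predictions = model.predict(np.array(patches))  # one forward pass for all patches
for (x, y), prediction in zip(coords, predictions):
    output[y + PAD:y + PAD + LABEL_SIZE, x + PAD:x + PAD + LABEL_SIZE] = prediction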
Now we can display, side by side, the low resolution image and the resulting output image, which is of higher resolution.
figure, axis = plt.subplots(1, 2, figsize=(15, 8))
axis[0].imshow(np.array(scaled,np.int32))
axis[0].set_title('Low resolution image (Downsize + Upsize)')
axis[0].axis('off')
axis[1].imshow(np.array(output,np.int32))
axis[1].set_title('Super resolution result (SRCNN output)')
axis[1].axis('off')
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(-0.5, 299.5, 197.5, -0.5)
A very impressive result considering the small model we trained; as you can see, it was able to considerably improve the resolution of the input image.
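To go beyond visual inspection, you could also quantify the improvement, for example with PSNR (peak signal-to-noise ratio) against the original image. A minimal sketch using tf.image.psnr; note that clipping to [0, 255] is an assumption to keep predictions in a valid pixel range, and that the PAD-wide border of output stays zero, which will slightly penalize the SRCNN score unless you crop the borders first:
# Higher PSNR means the image is closer to the original
original = tf.convert_to_tensor(image, dtype=tf.float32)
low_res = tf.convert_to_tensor(scaled, dtype=tf.float32)
srcnn = tf.convert_to_tensor(np.clip(output, 0, 255), dtype=tf.float32)
print('Low resolution PSNR:', tf.image.psnr(original, low_res, max_val=255).numpy())
print('SRCNN output PSNR:  ', tf.image.psnr(original, srcnn, max_val=255).numpy())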