The following notebook is an example of using a metaheuristic algorithm, the Whale Optimization Algorithm (WOA), to perform image segmentation.
We use Otsu's method together with the Whale Optimization Algorithm to estimate the thresholds for segmentation.
Please note that this notebook and the corresponding wandb report are inspired by the following papers:
The Colab notebook and the report are only a reproduction of the papers mentioned above.
%%capture
!pip install wandb
import numpy as np
import math
import random
import os
import cv2
from scipy import ndimage
import matplotlib
import matplotlib.pyplot as plt
import wandb
%matplotlib inline
# Use a white background for matplotlib figures
matplotlib.rcParams['figure.facecolor'] = '#ffffff'
!wandb login
# mount google drive to access files
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive
Mounted at /gdrive /gdrive
%cd /content/
/content
img = cv2.imread('/gdrive/MyDrive/Colab Notebooks/images/a12.tif')
#img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
Since we will be working with two thresholds, i.e. three-region segmentation, it is good to use an image that contains approximately three distinct colors and not substantially more. The flip side of having many more colors or groups of pixels is that the algorithm's threshold estimates will vary widely and it will approximate several color bands as the same region.
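One quick way to check this assumption is to look at the image histogram and count its dominant intensity bands. A minimal sketch, using a synthetic three-mode grayscale array as a stand-in for the image loaded above (the array `img_gray`, the mode locations, and the bin choices are illustrative assumptions, not values from the notebook):

```python
import numpy as np

# Synthetic stand-in for the notebook's image: three Gaussian intensity
# modes around 40, 130 and 220 (illustrative assumption).
rng = np.random.default_rng(0)
img_gray = np.concatenate(
    [rng.normal(loc, 8, 5000) for loc in (40, 130, 220)]
).clip(0, 255).astype(np.uint8)

# Coarse 16-bin histogram; the left edges of the three most populated
# bins approximate the dominant color bands.
coarse, _ = np.histogram(img_gray, bins=16, range=(0, 256))
dominant = np.sort(np.argsort(coarse)[-3:]) * 16  # left edges of top-3 bins
```

If `dominant` comes back with three well-separated edges, a two-threshold (three-region) segmentation is a reasonable fit for the image.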
from google.colab.patches import cv2_imshow
cv2_imshow(img)
wandb.init(entity="ritwik", project="woa-segmentation")
/content/wandb/run-20210410_105419-3s7ehboy
seg_images = []
thresholds = []
The fitness function is the primary design factor for any metaheuristic algorithm. A good fitness function allows the algorithm to converge faster and perform the task in an efficient way. It is also prudent to keep the fitness function light on iterations: for images, a large number of pixel loops makes evaluating the fitness of even one member too expensive in time and compute resources.
Here we take a whale position (two values) and assign them to the two thresholds we need to estimate, T1 and T2. Finally, we calculate the variance of the image segmented with the thresholds T1 and T2.
Variance is the design choice because it essentially tells us how much each pixel varies from its neighbouring pixels (or the centre pixel), which helps classify pixels into different regions.
To know exactly when to use the mean and variance of an image, visit this link:
def CostCriteria(image, x):
    final_img = image.copy()
    T1 = x[0]
    T2 = x[1]
    h, w, c = image.shape
    for i in range(h):
        for j in range(w):
            for k in range(c):
                # a simple if-else ladder to check which band of
                # pixel values each pixel belongs to
                if (image[i][j][k] >= T2) and (image[i][j][k] < 255):
                    final_img[i][j][k] = T2
                elif (image[i][j][k] >= T1) and (image[i][j][k] < T2):
                    final_img[i][j][k] = T1
                elif (image[i][j][k] >= 0) and (image[i][j][k] < T1):
                    final_img[i][j][k] = 0
    # append the resultant image to a list of images
    # to be logged at a later point in the program
    seg_images.append(final_img)
    thresholds.append(x)
    # log the threshold values to wandb
    wandb.log({"threshold-1": x[0]})
    wandb.log({"threshold-2": x[1]})
    # calculate the variance of the segmented image
    return ndimage.variance(final_img)
fitni = lambda x: CostCriteria(img, x)  # constructing a function handle for the fitness function
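Because the triple pixel loop above is slow in pure Python, the same banding can be sketched as a vectorized NumPy variant (not used in the notebook; note it also maps pixels equal to 255 to T2, an edge case the loop leaves untouched, and it skips the wandb logging):

```python
import numpy as np

def cost_criteria_vectorized(image, x):
    # Map each pixel into one of three bands in a single pass:
    # [0, T1) -> 0, [T1, T2) -> T1, [T2, ...] -> T2.
    T1, T2 = x[0], x[1]
    final_img = np.select(
        [image < T1, image < T2],  # conditions checked in order
        [0, T1],                   # values for the first two bands
        default=T2,                # everything else falls in the top band
    ).astype(image.dtype)
    # variance of the segmented image, same quantity as ndimage.variance
    return final_img.var()
```

On a typical image this evaluates one fitness call in a few array operations instead of h × w × c Python iterations.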
The Whale Optimization Algorithm is a recent metaheuristic for solving optimization problems. It includes three operators that simulate the search for prey, the encircling of prey, and the bubble-net foraging behavior of humpback whales. This is the link to the paper:
import numpy as np
import matplotlib.pyplot as plt
class WOA:
    def __init__(self, n_agents, max_iter, lower_b, upper_b, dim, bench_f):
        # init args
        self.n_agents = n_agents
        self.max_iter = max_iter
        self.lower_b = lower_b
        self.upper_b = upper_b
        self.dim = dim
        self.bench_f = bench_f
        # init problem
        self.leader_pos = np.zeros(dim)
        self.leader_score = np.inf
        self.positions = self.initialize_pos(n_agents, dim, upper_b, lower_b)

    def initialize_pos(self, n_agents, dim, upper_b, lower_b):
        n_boundaries = len(upper_b) if isinstance(upper_b, list) else 1
        if n_boundaries == 1:
            positions = np.random.rand(n_agents, dim) * (upper_b - lower_b) + lower_b
        else:
            # per-dimension bounds: draw each column within its own range
            positions = np.zeros([n_agents, dim])
            for i in range(dim):
                positions[:, i] = np.random.rand(n_agents) * (upper_b[i] - lower_b[i]) + lower_b[i]
        return positions

    def forward(self):
        t = 0
        conv_curve = np.zeros(self.max_iter)
        while t < self.max_iter:
            fitness, self.positions, self.leader_score, self.leader_pos = self.get_fitness(
                self.positions, self.leader_score, self.leader_pos)
            # a_1 decreases linearly from 2 to 0, a_2 from -1 to -2
            a_1 = 2 - t * (2 / self.max_iter)
            a_2 = -1 + t * (-1 / self.max_iter)
            self.positions = self.update_search_pos(self.positions, self.leader_pos, a_1, a_2)
            conv_curve[t] = self.leader_score
            t += 1
        return self.leader_score, self.positions, conv_curve

    def get_fitness(self, positions, leader_score, leader_pos):
        for i in range(positions.shape[0]):
            # adjust agents surpassing bounds
            upper_flag = positions[i, :] > self.upper_b
            lower_flag = positions[i, :] < self.lower_b
            positions[i, :] = (positions[i, :] * ((upper_flag + lower_flag) < 1)
                               + self.upper_b * upper_flag + self.lower_b * lower_flag)
            # objective function
            fitness = self.bench_f(positions[i, :])
            # update leader
            if fitness < leader_score:  # change to > if maximizing
                leader_score = fitness
                leader_pos = positions[i, :].copy()  # copy so later moves don't mutate the leader
        return fitness, positions, leader_score, leader_pos

    def update_search_pos(self, positions_, leader_pos_, a_1, a_2):
        positions = positions_.copy()
        leader_pos = leader_pos_.copy()
        for i in range(positions.shape[0]):
            r_1 = np.random.rand()
            r_2 = np.random.rand()
            A = 2 * a_1 * r_1 - a_1  # Eq. (2.3)
            C = 2 * r_2  # Eq. (2.4)
            b = 1
            l = (a_2 - 1) * np.random.rand() + 1
            p = np.random.rand()  # p in Eq. (2.6)
            for j in range(positions.shape[1]):
                if p < 0.5:
                    if np.abs(A) >= 1:
                        # exploration: move relative to a random whale
                        rand_leader_idx = int(np.floor(self.n_agents * np.random.rand()))
                        x_rand = positions[rand_leader_idx, :]
                        d_x_rand = np.abs(C * x_rand[j] - positions[i, j])  # Eq. (2.7)
                        positions[i, j] = x_rand[j] - A * d_x_rand
                    else:
                        # exploitation: encircle the leader
                        d_leader = np.abs(C * leader_pos[j] - positions[i, j])  # Eq. (2.1)
                        positions[i, j] = leader_pos[j] - A * d_leader
                else:
                    # spiral (bubble-net) update around the leader
                    dist_to_leader = np.abs(leader_pos[j] - positions[i, j])  # Eq. (2.5)
                    positions[i, j] = dist_to_leader * np.exp(b * l) * np.cos(l * 2 * np.pi) + leader_pos[j]
        return positions
#population size =30 (number of whales)
#maximum iteration = 10 (number of times the whales will go skrrrrrrr)
#lower bound = 0 (pixel intensity)
#upper bound = 255 (pixel intensity)
#dimension = 2 (number of parameters to be approximated)
woa = WOA(30, 10, 0, 255, 2, fitni)  # instantiating WOA with the parameters above
best_score, best_pos, conv_curve = woa.forward()
The convergence curve of the WOA algorithm over 10 iterations is plotted to visualize the rate at which the algorithm converges toward the global minimum of the fitness function.
np.mean(best_pos,axis=0)
array([257.8579688 , 35.08947612])
fig, ax = plt.subplots()
ax.plot(conv_curve)
ax.ticklabel_format(axis='y', style='sci')
ax.set(xlabel='iter', ylabel='leader_score',
title='Objective Space')
plt.show()
wandb.log({"img": [wandb.Image(data, grouping=3)
for data in seg_images]})
table = wandb.Table(data=thresholds, columns = ["threshold-1", "threshold-2"])
#wandb.log({"my_custom_id" : wandb.plot.scatter(table, "threshold-1", "threshold-2")})
wandb.log({"convergence plot": fig})
wandb.finish()
Final wandb run summary: threshold-1 = 255.0, threshold-2 = 35.50853, _runtime = 660 s, _timestamp = 1618052719, _step = 601 (sparkline run history omitted).
A demo function with the same specifications as CostCriteria(), but without the wandb logging, is created under the name CostSegmentation(). It is passed the original image along with the mean of the final whale positions (best_pos) returned by WOA, which serves as the estimate of the optimized thresholds.
The original and segmented images are plotted side by side to show the segmentation achieved.
def CostSegmentation(image, x):
    final_img = image.copy()
    T1 = x[0]
    T2 = x[1]
    h, w, c = image.shape
    for i in range(h):
        for j in range(w):
            for k in range(c):
                if (image[i][j][k] >= T2) and (image[i][j][k] < 255):
                    final_img[i][j][k] = T2
                elif (image[i][j][k] >= T1) and (image[i][j][k] < T2):
                    final_img[i][j][k] = T1
                elif (image[i][j][k] >= 0) and (image[i][j][k] < T1):
                    final_img[i][j][k] = 0
    return final_img
img_seg = CostSegmentation(img, np.mean(best_pos, axis=0))
fig, (ax, ax2) = plt.subplots(ncols=2)
ax.imshow(img)
ax2.imshow(img_seg)
plt.show()
Although the authors of the original paper cherry-pick their results to show the algorithm as quite effective, here we can verify that it is quite ineffective at segmenting even a near-binary image. Let us take a look at why this happens: since leader_score starts at infinity and is replaced whenever fitness < leader_score, the WOA above minimizes the fitness, so it drives the thresholds toward values that minimize the variance of the segmented image, which tends to collapse most pixels into a single band rather than separate the regions.
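A cost criterion better matched to Otsu's method is the between-class variance, which is maximized (its negation would be fed to the minimizing WOA above). A minimal sketch for two thresholds on a 256-bin grayscale histogram; the function name and the `hist` input are assumptions, not code from this notebook:

```python
import numpy as np

def between_class_variance(hist, t1, t2):
    # Otsu's criterion for two thresholds t1 < t2: the weighted squared
    # distance of each class mean from the global mean.
    p = hist / hist.sum()              # intensity probabilities
    levels = np.arange(len(hist))
    total_mean = (p * levels).sum()    # global mean intensity
    sigma_b = 0.0
    for lo, hi in ((0, t1), (t1, t2), (t2, len(hist))):
        w = p[lo:hi].sum()             # class probability
        if w > 0:
            mu = (p[lo:hi] * levels[lo:hi]).sum() / w  # class mean
            sigma_b += w * (mu - total_mean) ** 2
    return sigma_b
```

Wrapping this as `lambda x: -between_class_variance(hist, int(min(x)), int(max(x)))` would plug it into the same minimizing WOA loop.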
Nonetheless, this was a good exercise in investigating the performance of the WOA algorithm on basic image-processing applications and the importance of choosing a proper cost criterion.
If there is a follow-up blog to this one, it will investigate various cost functions related to image segmentation and the WOA algorithm's ability to converge on them.
Thank you for reading and Happy Coding. ;)