Author: Andre Scholich
Last update: 2016-04-11
Both libraries are very well-documented and offer a lot of usage examples. For the purpose of this quick introduction I selected one example of each library and combined them to solve the task of handwritten digit recognition. The goal is not to provide any understanding of the underlying analysis tools but instead illustrate the high-level capabilities of modern computer-vision and machine-learning libraries.
In order to perform this task we first need to identify potential digits in a given image. The label image regions example of the skimage library serves as a basis. We now load the necessary modules.
import matplotlib.patches as mpatches
from pylab import plt, np
from scipy import ndimage as ndi
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from skimage.io import imread
from skimage.measure import label
from skimage.measure import regionprops
from skimage.morphology import closing, square
from skimage.segmentation import clear_border
from skimage.transform import downscale_local_mean
%matplotlib inline
--------------------------------------------------------------------------- ModuleNotFoundError Traceback (most recent call last) <ipython-input-1-46c23cba3a7c> in <module> 2 from pylab import plt, np 3 from scipy import ndimage as ndi ----> 4 from skimage.color import rgb2gray 5 from skimage.filters import threshold_otsu 6 from skimage.io import imread ModuleNotFoundError: No module named 'skimage'
Read and show the image we would like to analyse.
# Load the sample photograph of handwritten digits and display it unaltered.
input_path = 'my_digits/my_digits_3.jpg'
original_image = imread(input_path)
plt.imshow(original_image, interpolation='none')
We convert the color image to gray scale according to luminance and invert it. Furthermore, as the image is of high resolution we downscale it to reduce artifacts and increase computation speed.
# Convert to luminance grayscale, shrink 5x to suppress artifacts and speed
# things up, rescale intensities to [0, 256], then invert so ink is bright.
image = downscale_local_mean(rgb2gray(original_image), (5, 5))
peak = np.max(image)
image = image / peak * 256
image = np.max(image) - image
plt.imshow(image, cmap='gray_r', interpolation='none')
<matplotlib.image.AxesImage at 0xc03f4e0>
Next, we create a binarized version of the image using the Otsu method and fill the holes in it. We do this in order to create 'blobs' that we can easily recognize. In a last step we clear the border of the image in order to not run into trouble there.
# Binarize with Otsu's global threshold: foreground = pixels above threshold.
global_thresh = threshold_otsu(image)
binary_global = image > global_thresh
# Fill enclosed holes so each digit becomes one solid blob.
bw = ndi.binary_fill_holes(binary_global)
# clear_border returns a new array with border-touching blobs removed; the
# original code discarded that return value (relying on removed in-place
# behavior), so border blobs were never actually cleared.
cleared = clear_border(bw)
plt.imshow(cleared, cmap='gray_r', interpolation='none')
<matplotlib.image.AxesImage at 0x12349ba8>
This pre-processed image we now use to label the connected regions of the image. We do this and show the identified regions of interest as overlay. The properties of each region of the labeled image can easily be accessed with regionprops function (maybe familiar from MatLab).
label_image = label(cleared)
# Show the grayscale image with a red box around every connected component.
plt.figure(figsize=(15, 10))
plt.imshow(image, cmap='gray_r')
for props in regionprops(label_image):
    top, left, bottom, right = props.bbox
    # Expand the bounding box to a square (longest side) plus a 1px margin.
    side = np.max([bottom - top, right - left])
    row_mid = bottom + top
    col_mid = right + left
    top = int(0.5 * (row_mid - side)) - 1
    bottom = int(0.5 * (row_mid + side)) + 1
    left = int(0.5 * (col_mid - side)) - 1
    right = int(0.5 * (col_mid + side)) + 1
    box = mpatches.Rectangle((left, top), right - left, bottom - top,
                             fill=False, edgecolor='red', linewidth=2)
    plt.gca().add_patch(box)
plt.show()
Having now identified the regions of interest within our image we now turn to the problem of recognizing the digit within them. We use an example dataset of sklearn that relies on the Pen-Based Recognition of Handwritten Digits Data Set. It is a digit database consisting of 250 samples from 44 (English) writers. Each sample is an 8x8 image of a number.
Following the Recognizing hand-written digits example from sklearn we train a SVM with the data set.
from sklearn import datasets
from sklearn import svm
from skimage.transform import resize
from find_digits_and_predict import show_prediction_result
# Train an RBF-kernel SVM on sklearn's built-in 8x8 digit images.
# SVC.fit returns the fitted estimator itself, so chaining is equivalent.
digits = datasets.load_digits()
clf = svm.SVC(gamma=0.001, C=100.).fit(digits.data, digits.target)
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False)
We rescale our identified regions to match the 8x8 input images of the training set and apply the classifier to them to get a prediction of the written digit.
def show_prediction_result(image, label_image, clf):
    """Classify each labeled region of *image* with *clf* and plot the result.

    Parameters
    ----------
    image : 2-D ndarray
        Inverted grayscale image (ink bright, background dark).
    label_image : 2-D integer ndarray
        Connected-component labels, e.g. from skimage.measure.label.
    clf : fitted classifier
        Must accept flattened 8x8 patches of integer intensities 0-16
        (the sklearn digits format).

    Returns
    -------
    candidates : list of 8x8 int ndarrays that were fed to the classifier.
    predictions : list of classifier outputs, one per candidate.
    """
    size = (8, 8)  # sklearn's digit samples are 8x8 images
    plt.figure(figsize=(15, 10))
    plt.imshow(image, cmap='gray_r')
    candidates = []
    predictions = []
    for region in regionprops(label_image):
        minr, minc, maxr, maxc = region.bbox
        # Expand the bounding box to a square (longest side) with a 3-pixel
        # margin so the digit sits centered, roughly like the training data.
        maxwidth = np.max([maxr - minr, maxc - minc])
        minr, maxr = int(0.5 * ((maxr + minr) - maxwidth)) - 3, int(0.5 * ((maxr + minr) + maxwidth)) + 3
        minc, maxc = int(0.5 * ((maxc + minc) - maxwidth)) - 3, int(0.5 * ((maxc + minc) + maxwidth)) + 3
        # Clamp to the image: a negative start index would wrap around in
        # NumPy slicing and silently extract the wrong patch for regions
        # near the top or left edge.
        minr, minc = max(minr, 0), max(minc, 0)
        rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                  fill=False, edgecolor='red', linewidth=2, alpha=0.2)
        plt.gca().add_patch(rect)
        # Rescale the patch to 8x8 and quantize to 0-16 integer levels,
        # matching the value range of the classifier's training data.
        candidate = image[minr:maxr, minc:maxc]
        candidate = np.array(resize(candidate, size), dtype=float)
        candidate = candidate - np.min(candidate)
        if np.max(candidate) == 0:
            # Uniform patch: nothing to classify, and normalizing would
            # divide by zero.
            continue
        candidate /= np.max(candidate)
        candidate[candidate < 0.2] = 0.0  # suppress faint background noise
        candidate *= 16
        candidate = np.array(candidate, dtype=int)
        prediction = clf.predict(candidate.reshape(1, -1))
        candidates.append(candidate)
        predictions.append(prediction)
        plt.text(minc - 10, minr - 10, "{}".format(prediction), fontsize=50)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.tight_layout()
    plt.show()
    return candidates, predictions
candidates, predictions = show_prediction_result(image, label_image, clf)
So in the end, we were able to identify image regions and recognize the written numbers with some accuracy by basically following the examples of the libraries. I also noticed that my handwritten 4 obviously differs from the usual way.
Finally, let us have a look at what the classifier actually "sees" to appreciate the difficulty of the task.
# Display each 8x8 candidate the classifier actually "saw", with its
# prediction. The original used the Python-2-only print statement; this
# format call prints the identical text on both Python 2 and 3.
for candidate, prediction in zip(candidates, predictions):
    plt.imshow(candidate, cmap='gray_r', interpolation='none')
    print("prediction: {}".format(prediction))
    plt.show()
prediction: [1]
prediction: [9]
prediction: [5]
prediction: [3]
prediction: [2]
prediction: [1]
prediction: [0]
prediction: [9]
prediction: [8]
prediction: [7]
prediction: [6]