If you are running this notebook on Google Colab, please make sure to do two things:
First, switch the runtime to a GPU instance. This can be done by clicking on the menu Runtime and selecting Change runtime type. In the following popup window, leave the Runtime type on Python 3, but under Hardware accelerator choose GPU.
Second, execute the following code cell to prepare the notebook environment and its dependencies.
import sys
if 'google.colab' in sys.modules:
    # Clone GitHub repository
    !git clone https://github.com/miykael/amld20_classification.git
    # Copy files required to run the code
    !cp -r 'amld20_classification/downloads' 'amld20_classification/utils.py' .
    # Install packages via pip
    !pip install -r "amld20_classification/colab-requirements.txt"
    # Restart the runtime by killing the kernel process, so the new packages are loaded
    import os
    os.kill(os.getpid(), 9)
The goal of this hands-on exercise is to provide a general overview of image classification and to show you how you can train a model and later use it to classify images according to your own categories.
The whole process is subdivided into 4 sections:
Data Preparation: Collect the data, clean it and prepare it for further analysis.
Data Exploration: Explore the dataset and modify it to your needs.
Modeling and Analysis: Model creation, optimization & evaluation.
Results Discussion & Exploration: Investigate model performance and understand results.
Note: The complete code to this hands-on exercise can be found under github.com/miykael/amld20_classification.
First things first, let's load a few libraries that we need during our analysis. The underlying mechanisms of these lines of code are not important here, but in short: we're preparing all the functions we need later on and making sure that the figures look nice.
import sys
if 'google.colab' in sys.modules:
    %tensorflow_version 2.x
%run utils.py
%matplotlib inline
The goal of this section is to (1) specify the target categories we want to predict, and to make sure that the data is (2) aggregated, (3) cleaned and (4) stored in an accessible data format.
The following list specifies the target categories that we want to use for the image classification. For the first run of this hands-on exercise, we recommend leaving the list as it is, as each new entry will take some time for the data collection.
At a later stage, feel free to change this list as you like. As a rule of thumb, I recommend keeping between 3 and 8 target classes.
# List of class labels to use for the image classification
class_labels = [
'brown bear',
'polar bear',
'giant panda',
'red panda',
'lion',
'tiger',
'racoon',
'red fox',
]
Ideally, you would collect the data yourself and make sure that it is of high quality and suitable for an image classification project. Within the scope of this hands-on exercise we don't have time for that, so we will take a few shortcuts.
First, we will use the Python package Google Images Download to download images of our target classes directly via Google's image search service. However, this library only allows the download of the first 100 entries per search term. To bypass this restriction, we will augment our initial search term with additional keywords. So, instead of just looking for images of a "brown bear", we will look for "brown bear close up" and "brown bear portrait" images. These additional search terms are specified by the parameter search_suffix.
# Define additional search terms to expand the number of searches
search_suffix = 'close up,portrait'
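To make the query expansion concrete, here is a sketch of how the class labels and search_suffix could be combined into search queries. expand_search_terms is a hypothetical helper, not part of utils.py or of Google Images Download, and it assumes that the bare label is kept as a query alongside the suffixed ones:

```python
def expand_search_terms(class_labels, suffix):
    """Map each class to its list of search queries (hypothetical helper).

    Assumption: the bare label is kept, and every comma-separated entry
    of `suffix` is appended to it as an additional query.
    """
    suffixes = [s.strip() for s in suffix.split(',') if s.strip()]
    queries = {}
    for label in class_labels:
        # Folder-style key, matching the 'brown_bear' naming in the log output
        key = label.replace(' ', '_')
        queries[key] = [label] + [f'{label} {s}' for s in suffixes]
    return queries


queries = expand_search_terms(['brown bear', 'polar bear'], 'close up,portrait')
print(queries['brown_bear'])
# ['brown bear', 'brown bear close up', 'brown bear portrait']
```

Each extra suffix adds one more query per class, which is why a longer search_suffix list also lengthens the download time.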
Once the class labels and the additional search terms are defined, we can go ahead and collect the images from the web.
# Download the dataset and store the images on disk
imgs_raw = collect_images(class_labels, suffix=search_suffix)
Collecting data about: brown_bear
Collecting data about: polar_bear
Collecting data about: giant_panda
Collecting data about: red_panda
Collecting data about: lion
Collecting data about: tiger
Collecting data about: racoon
Collecting data about: red_fox
---
A Total of N=3624 images were collected!
Once all images are downloaded and the filenames are stored in the imgs_raw variable, we can go ahead and take a look at a few of them.
# Let's take a look at the data we've collected
plot_images(imgs_raw, n_col=10, n_row=5)
Note: All the images above are square, i.e. have the same number of pixels in width and height. This is done on purpose, to make all photos "equal" and comparable. This restriction means that some images that were originally rectangular were cropped to a square shape, which explains why some of the images seem to cut off the important part of the image.
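The cropping step can be sketched as a crop to the shorter side. Whether the notebook crops around the center or elsewhere is not stated; center_crop_square is a hypothetical NumPy helper that works on an (H, W, C) image array:

```python
import numpy as np


def center_crop_square(img):
    """Crop an (H, W, C) image array to a centered square of side min(H, W)."""
    h, w = img.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return img[top:top + side, left:left + side]


# A landscape "photo" of 4 x 6 pixels loses one column on each side
img = np.arange(4 * 6 * 3).reshape(4, 6, 3)
square = center_crop_square(img)
print(square.shape)  # (4, 4, 3)
```

With a center crop, an animal standing near the left or right edge of a wide photo is exactly what gets cut off.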
Almost all datasets are to a certain degree "dirty" and need to be cleaned. Cleaning in the context of classification means to:
Remove duplicates
Remove outliers
Handle missing values
We don't have any missing values in this dataset (images are either downloaded or not), and will therefore only focus on the first two points.
# Remove duplicates from the dataset
imgs_unique = remove_duplicates(imgs_raw)
Total number of images in the dataset: 3624
Number of duplicates in the dataset: 148
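How remove_duplicates detects duplicates is not shown in this notebook. A minimal sketch, assuming exact duplicates are found by hashing each file's raw bytes with MD5 (find_unique is a hypothetical helper):

```python
import hashlib


def find_unique(blobs):
    """Return the indices of the first occurrence of each distinct byte string.

    Identical files share the same MD5 digest. Near-duplicates, such as
    re-encoded or resized copies of the same photo, are NOT detected.
    """
    seen = set()
    keep = []
    for i, blob in enumerate(blobs):
        digest = hashlib.md5(blob).hexdigest()
        if digest not in seen:
            seen.add(digest)
            keep.append(i)
    return keep


# Stand-ins for file contents; for real files use open(path, 'rb').read()
files = [b'bear-1', b'fox-1', b'bear-1', b'lion-1']
print(find_unique(files))  # [0, 1, 3]
```

Hashing makes the check linear in the number of images, instead of comparing every pair of images against each other.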
The removal of outliers is trickier when it comes to images. Ideally, you would do this by hand to make sure that you only keep meaningful images in your dataset, i.e. images that belong to one of your target categories. For this hands-on exercise, a manual check is not feasible, so let's try a quick shortcut by looking at the color profile of the images.
Each image has a particular color profile, i.e. a distribution of red, green and blue (RGB) pixel values. Looking at these RGB value distributions might allow us to differentiate photos of animals from heavily photoshopped images, grayscale images or graphical logos. Let's take a closer look at a few examples.
# Plot images together with their RGB color distributions
plot_images(imgs_unique, n_col=8, n_row=2, show_histogram=True)
To remove outlier images from the dataset, we will look at the RGB color profile of each image and remove images that are either grayscale or that have sudden extreme spikes in the RGB profile (i.e. images with a large number of pixels of exactly the same color value).
# Remove outliers from the dataset
imgs_clean, imgs_outlier = remove_outliers(imgs_unique)
Total number of images in the dataset: 3476
Number of outliers in the dataset: 239
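The two rules described above (grayscale images and spiky RGB profiles) can be sketched with NumPy. This is not the notebook's remove_outliers: the helper is_outlier and its thresholds gray_tol, spike_frac and bins are hypothetical values chosen for illustration:

```python
import numpy as np


def is_outlier(img, gray_tol=2.0, spike_frac=0.2, bins=64):
    """Flag an (H, W, 3) uint8 image as grayscale or as having a spiky RGB profile.

    Hypothetical thresholds: gray_tol is the mean absolute difference between
    channels below which an image counts as grayscale; spike_frac is the share
    of pixels in a single histogram bin above which a channel counts as spiky.
    """
    img = img.astype(float)
    r, g, b = img[..., 0], img[..., 1], img[..., 2]
    # Grayscale: the three channels are (almost) identical at every pixel
    is_gray = max(np.abs(r - g).mean(), np.abs(g - b).mean()) < gray_tol
    # Spike: one histogram bin holds a large share of a channel's pixels
    peak_share = max(
        np.histogram(c, bins=bins, range=(0, 256))[0].max() / c.size
        for c in (r, g, b)
    )
    return bool(is_gray or peak_share > spike_frac)


rng = np.random.default_rng(0)
photo = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
gray = np.repeat(photo[..., :1], 3, axis=2)  # copy one channel into all three
logo = photo.copy()
logo[:20] = (255, 0, 0)  # large flat red area, typical for graphics
print([is_outlier(im) for im in (photo, gray, logo)])  # [False, True, True]
```

Rules like these are cheap but blunt: a real photo of a polar bear in snow can also produce a strong histogram peak, so some valid images may end up among the outliers.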
Now that we have separated the dataset into clean and outlier images, let's take a closer look at them:
# Plot some outlier images
plot_images(imgs_outlier, n_col=10, n_row=5)
# Plot some clean images
plot_images(imgs_clean, n_col=10, n_row=5)
As a last step in the data preparation section, we need to finalize the actual dataset and store the relevant image information in useful variables, conventionally called X and y.
The image dimension of 32 in the next cell was chosen as a trade-off between speed and image detail. To better understand the effect of this trade-off, we recommend exploring different values for img_dim=32. What happens if you set this number to 16, 64 or even 128?
Note: Changing this parameter will also affect the computation time. For this reason, we recommend leaving this parameter at img_dim=32 during your first walk-through of this hands-on exercise.
# Create dataset
X_pixel, y_pixel, metainfo = create_dataset(imgs_clean, class_labels, img_dim=32)
Found 3237 images belonging to 8 classes.
Dataset was augmented to a total of N=6474 images through means of: Image rotation, flipping, shifting, zooming and brightness variation. Images resampled to a resolution of 32 x 32!
Note 1: In machine learning, the variable X is conventionally used for the feature matrix, i.e. the data that contains the detailed information about our samples. The variable y is used for the target variable, in this case the class label.
Here, we've added an additional variable called metainfo that contains relevant information about the dataset (e.g. size of dataset, class labels, image filenames, etc.), which is unique to this exercise and will be needed in the code further below.
Note 2: The function create_dataset above also contains the parameter img_dim=32, which specifies the pixel resolution to which the images are resampled. In this case, 32 means that all images are resized to a pixel resolution of 32 x 32. Downsampling images from their original size to 32 x 32 reduces the computation time and makes sure that all images have the same size.
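Resampling to img_dim x img_dim can be sketched with plain NumPy using nearest-neighbour sampling. How create_dataset interpolates is not shown here; resize_nearest is a hypothetical helper, and image libraries typically offer smoother interpolation than this. Assuming each image is then flattened into one row of a feature matrix (an assumption about how X_pixel is fed to the classifier), a 32 x 32 RGB image contributes 32 * 32 * 3 = 3072 values:

```python
import numpy as np


def resize_nearest(img, img_dim=32):
    """Nearest-neighbour resample of a square (S, S, C) image to (img_dim, img_dim, C)."""
    side = img.shape[0]
    # Map the centre of each output pixel back to the source pixel containing it
    idx = ((np.arange(img_dim) + 0.5) * side / img_dim).astype(int)
    return img[idx][:, idx]


rng = np.random.default_rng(0)
imgs = [rng.integers(0, 256, size=(s, s, 3), dtype=np.uint8) for s in (200, 480, 97)]
# One flattened row of 32 * 32 * 3 = 3072 pixel values per image
X = np.stack([resize_nearest(im, img_dim=32).reshape(-1) for im in imgs])
print(X.shape)  # (3, 3072)
```

Doubling img_dim to 64 quadruples the number of features per image (12288), which is one reason why larger values slow down the model fitting.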
Once the dataset is prepared and cleaned, it is important that we understand its properties. This can be done through an Exploratory Data Analysis (EDA). Each of our investigations during the EDA should lead to an observation, followed by a decision.
So, first things first, let's take a closer look at our dataset and describe its basic properties.
# How many images per class do we have?
plot_class_distribution(y_pixel, metainfo)
Observation: The classes are not equally distributed. Some classes have more than twice as many images as others.
Decision: The class imbalance needs to be considered during the data modeling process.
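One common way to take the imbalance into account is to weight each class inversely to its frequency. The formula below is the one scikit-learn uses for class_weight='balanced'; whether model_fit uses exactly this weighting is an assumption, and balanced_class_weights is a hypothetical helper:

```python
import numpy as np


def balanced_class_weights(y):
    """Weight each class by n_samples / (n_classes * n_samples_in_class)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))


# Toy labels: class 0 is three times as frequent as class 1
y = np.array([0, 0, 0, 0, 0, 0, 1, 1])
print(balanced_class_weights(y))  # {0: 0.6666666666666666, 1: 2.0}
```

Each class then contributes the same total weight to the training loss (6 * 2/3 = 2 * 2.0 = 4), so frequent classes cannot dominate the fit.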
# What does the average image per target category look like?
plot_class_average(X_pixel, y_pixel, metainfo)
Observation: Some classes, like giant_panda, polar_bear and red_panda, have very distinctive color characteristics.
Decision: Let's take a closer look at the RGB color distribution of these class average images.
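A class average image boils down to a per-pixel mean over all images that share the same label. The implementation of plot_class_average is not shown here; class_average is a hypothetical NumPy helper that sketches the idea on a toy dataset of 2 x 2 images:

```python
import numpy as np


def class_average(X, y):
    """Return {label: mean image} for images X of shape (N, H, W, C)."""
    return {label: X[y == label].mean(axis=0) for label in np.unique(y).tolist()}


# Four flat-colored toy images, two per class
X = np.stack([np.full((2, 2, 3), v) for v in (0, 100, 200, 50)]).astype(float)
y = np.array([0, 0, 1, 1])
avg = class_average(X, y)
print(avg[0][0, 0], avg[1][0, 0])  # [50. 50. 50.] [125. 125. 125.]
```

Averaging blurs away individual poses and backgrounds, so what remains visible is mostly the typical color of each class.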
# What does the average RGB color profile look like per class?
plot_class_RGB(X_pixel, y_pixel, metainfo)
Observation: The individual RGB color distributions of almost all classes look unique.
Decision: We should use the RGB color profile as a data feature for the image classification.
From the previous section, we know that the RGB color profile of each image might be useful to identify the image's class. For this reason, let's take the first dataset X_pixel and create a second dataset called X_rgb that consists only of these RGB color profiles.
# Extract RGB color profiles for each image individually
X_rgb, y_rgb = extract_RGB_features(y_pixel, metainfo)
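extract_RGB_features is defined in utils.py and its exact binning is not shown here. One common way to turn an image into an RGB color profile is to concatenate a normalized histogram per channel; the helper rgb_profile and its 32 bins per channel are assumptions, so the real X_rgb may have a different number of features:

```python
import numpy as np


def rgb_profile(img, bins=32):
    """Concatenate normalized per-channel histograms of an (H, W, 3) image."""
    features = []
    for c in range(3):
        counts, _ = np.histogram(img[..., c], bins=bins, range=(0, 256))
        features.append(counts / counts.sum())  # fraction of pixels per bin
    return np.concatenate(features)


rng = np.random.default_rng(0)
X_pixel_demo = rng.integers(0, 256, size=(5, 32, 32, 3))
X_rgb_demo = np.stack([rgb_profile(img) for img in X_pixel_demo])
print(X_rgb_demo.shape)  # (5, 96)
```

A histogram keeps how often each color occurs but discards where in the image it occurs, so a brown bear and a brown background look alike in this representation.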
In addition to the RGB color profile as a new data feature, let's also try to extract meaningful features from our images by using a state-of-the-art neural network called MobileNetV2. Note that the neural network we use for this feature extraction is downloaded directly from TensorFlow Hub.
# Extract features according to MobileNetV2 (Convolutional Neural Network)
X_nn, y_nn = extract_neural_network_features()
Building model. Found 3237 images belonging to 8 classes. Extracting features.
Dataset was augmented to a total of N=6474 images through means of: Image rotation, flipping, shifting, zooming and brightness variation.
Before we continue with the next section, let's quickly recap the three different ways in which we've represented our images:
X_pixel: This dataset contains the original data in its raw form. Each value is an RGB pixel value of the 32 x 32 resampled image.
X_rgb: This dataset contains the RGB color profile of each image.
X_nn: This dataset contains a 1280-dimensional feature vector per image, extracted with a convolutional neural network called MobileNetV2.
To better understand what this actually means, let's randomly select an image and plot these three representations next to each other:
plot_recap(X_pixel, X_rgb, X_nn)
Once the data is (1) collected, cleaned and prepared, and (2) meaningful features are extracted, we are ready for the data modeling. In the following section, we will do the model fitting and optimization once for each of these three datasets, i.e. X_pixel, X_rgb and X_nn.
More in-depth information about the classifier: The model used in this hands-on example is a Ridge Regression Classifier. Ridge regression has the advantage of being rather fast, while also allowing us to fine-tune a regularization parameter alpha for optimal data fitting. During model optimization, we use a cross-validation approach with 4 folds to fine-tune this hyperparameter. To test for generalization, the dataset is initially split into a training and a test set according to a 50/50 ratio. Because the class labels are unequally distributed throughout the dataset, a weighted balancing approach is applied during model optimization.
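The setup described above maps naturally onto scikit-learn's RidgeClassifier and GridSearchCV. The sketch below is not the notebook's model_fit (whose internals live in utils.py): it runs on synthetic data from make_classification, and the alpha grid np.logspace(-3, 3, 25) and the stratified split are assumptions, chosen so that 25 candidates times 4 folds give the 100 fits reported in the logs:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import RidgeClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

# Synthetic stand-in for X_pixel / X_rgb / X_nn: 8 imbalanced classes
X, y = make_classification(
    n_samples=800, n_features=40, n_informative=20, n_classes=8,
    weights=[0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], random_state=0,
)

# 50/50 train/test split, stratified so both halves keep the class ratios
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=0
)

# 25 candidate regularization strengths x 4-fold CV = 100 fits;
# class_weight='balanced' compensates for the unequal class sizes
grid = GridSearchCV(
    RidgeClassifier(class_weight='balanced'),
    param_grid={'alpha': np.logspace(-3, 3, 25)},
    cv=4,
)
grid.fit(X_train, y_train)
print('Best alpha:', grid.best_params_['alpha'])
print(f'Test accuracy: {grid.score(X_test, y_test):.2%}')
```

Only the training half is used to pick alpha; the test half stays untouched until the final score, which is what makes that score an estimate of generalization.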
X_pixel
# Train a model fit to the data in pixel format
model_pixel = model_fit(X_pixel, y_pixel)
Fitting 4 folds for each of 25 candidates, totalling 100 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done 34 tasks | elapsed: 9.2s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 17.5s finished
Model trained for 18.01s total and reached an accuracy of: 42.39%
# Investigation of the model performance
check_model_performance(model_pixel, metainfo)
X_rgb
# Train a model fit to the data in rgb format
model_rgb = model_fit(X_rgb, y_rgb)
Fitting 4 folds for each of 25 candidates, totalling 100 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done 34 tasks | elapsed: 1.1s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 2.6s finished
Model trained for 2.66s total and reached an accuracy of: 36.02%
# Investigation of the model performance
check_model_performance(model_rgb, metainfo)
X_nn
# Train a model fit to the data in neural network format
model_nn = model_fit(X_nn, y_nn)
Fitting 4 folds for each of 25 candidates, totalling 100 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done 34 tasks | elapsed: 1.3s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 3.3s finished
Model trained for 3.43s total and reached an accuracy of: 99.07%
# Investigation of the model performance
check_model_performance(model_nn, metainfo)
Once the model is optimized and its general performance investigated, we should take a closer look at the results: Why did the model predict a given class? With what certainty was this decision made? Which images did the model misclassify and why? Which images are easy or difficult for our model?
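A standard first answer to "which images did the model misclassify?" is a confusion matrix. Whether check_model_performance shows one is not stated here; the sketch below builds it with plain NumPy on toy labels (confusion_counts is a hypothetical helper; scikit-learn also provides sklearn.metrics.confusion_matrix):

```python
import numpy as np


def confusion_counts(y_true, y_pred, n_classes):
    """cm[i, j] counts images of true class i that were predicted as class j."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # increment one cell per image
    return cm


y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
cm = confusion_counts(y_true, y_pred, n_classes=3)
print(cm)
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]
```

The diagonal holds the correct predictions; every off-diagonal cell, such as cm[2, 0], names a specific pair of classes the model confuses.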
So, let's investigate the results and why a model chose one class over the others. For this, let's select one of the three models to investigate.
The model_nn reached by far the highest accuracy (99.07%, compared to 42.39% for model_pixel and 36.02% for model_rgb), so we should investigate its results first. Nonetheless, try to investigate the performance of model_pixel and model_rgb as well, by changing the variable name in the following cell and observing how the output of the two cells after it changes.
# Choose which model to investigate: 'model_pixel', 'model_rgb' or 'model_nn'
model = model_nn
Let's see which images the model classified correctly and how certain it was:
# Plot correct predictions
investigate_predictions(model, metainfo, show_correct=True, nimg=7)
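RidgeClassifier has no predict_proba method, so a certainty score has to be derived from its decision_function scores. How investigate_predictions computes its certainty is not shown here; a common heuristic, sketched below as an assumption, is a softmax over the per-class scores (softmax_confidence is a hypothetical helper). The result sums to 1 per image but is not a calibrated probability:

```python
import numpy as np


def softmax_confidence(scores):
    """Turn raw per-class scores (n_samples, n_classes) into rows that sum to 1."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# In practice: scores = model.decision_function(X_test) for a fitted classifier
scores = np.array([[2.0, 0.0, -1.0],
                   [0.1, 0.2, 0.15]])
probs = softmax_confidence(scores)
print(probs.argmax(axis=1))  # [0 1]
```

The first image gets a clear winner, while the second has three nearly equal scores; such near-ties are good candidates to look at when investigating misclassifications.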
Let's see which images the model classified incorrectly and why it was wrong:
# Plot wrong predictions
investigate_predictions(model, metainfo, show_correct=False, nimg=7)
The initial training of a model can be computationally very intensive. But once the model is trained, using it to predict the class of new images is very fast.
So, let's take an image from the web and see how well our trained model performs!
# Predict the class properties of an image from the web
img_url = 'https://c402277.ssl.cf1.rackcdn.com/photos/18134/images/hero_small/Medium_WW226365.jpg?1574452099'
predict_new_image(img_url, model_nn, metainfo)
Downloading image. Feature Extraction. Plotting report.
What about an image that is only slightly related to the target categories?
# Predict the class properties of an image from the web
img_url = 'https://live.staticflickr.com/7279/7017467025_8807cc82f6_b.jpg'
predict_new_image(img_url, model_nn, metainfo)
Downloading image. Feature Extraction. Plotting report.
And if you have always wondered what kind of animal you are, just plug in your image below:
# Predict the class properties of another image from the web
img_url = 'https://d33wubrfki0l68.cloudfront.net/067d6b185769f031404e927b1b70de6d5ece3e0a/d82f4/michael.1164620e.jpg'
predict_new_image(img_url, model_nn, metainfo)
Downloading image. Feature Extraction. Plotting report.
Choose your own image from the internet, store its URL in the variable img_url and run the following cell.
img_url = 'https://www.YOUR_OWN_IMAGE.jpg'
predict_new_image(img_url, model_nn, metainfo)