Introduction to Cloud Machine Learning with Flask API and CNTK

One of the best ways to operationalize a machine learning system is through an API. In this notebook we show how to deploy scalable CNTK models for image classification through an API. The frameworks used in the solution are:

  • CNTK: Microsoft's Cognitive Toolkit is the deep learning library we use to compute the Convolutional Neural Network (CNN) model that identifies images.
  • Flask is one of the most popular frameworks to develop APIs in Python.
  • CherryPy is a lightweight web framework for Python. We use it as a web server to host the machine learning application.

Here we present an overview of the application. The main procedure is executed by the CNTK CNN. The network is a pretrained ResNet with 152 layers, trained on the ImageNet dataset, which contains 1.2 million images divided into 1000 different classes. The CNN is accessible through the Flask API, which provides an endpoint /api/v1/classify_image that can be called to classify an image. CherryPy is the server framework where the application is hosted. It also balances the load, so that several concurrent queries can be executed. Externally, there is the client, which can be any desktop or mobile device that sends an image to the application for analysis and receives the response.

In [44]:
#load libraries
import os,sys
import pkg_resources
from flask import Flask, render_template, request, send_file
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import wget
import numpy as np
from PIL import Image, ImageOps
from urllib.request import urlretrieve
import requests
from cntk import load_model, combine
from io import BytesIO, StringIO
import base64
from IPython.core.display import display, HTML
import aiohttp
import asyncio
import json
import random

print("System version: {}".format(sys.version))
print("Flask version: {}".format(pkg_resources.get_distribution("flask").version))
print("CNTK version: {}".format(pkg_resources.get_distribution("cntk").version))
System version: 3.5.2 |Continuum Analytics, Inc.| (default, Jul  2 2016, 17:53:06) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
Flask version: 0.12
CNTK version: 2.0rc1

Image classification with a pretrained CNTK model

The first step is to download a pretrained model. CNTK has a wide range of different pretrained models that can be used for image classification.

In [2]:
def maybe_download_model(filename='ResNet_18.model'):
    if(os.path.isfile(filename)):
        print("Model %s already downloaded" % filename)
    else:
        model_name_to_url = {
        'AlexNet.model':   'https://www.cntk.ai/Models/AlexNet/AlexNet.model',
        'AlexNetBS.model': 'https://www.cntk.ai/Models/AlexNet/AlexNetBS.model',
        'VGG_16.model': 'https://www.cntk.ai/Models/Caffe_Converted/VGG16_ImageNet.model',
        'VGG_19.model': 'https://www.cntk.ai/Models/Caffe_Converted/VGG19_ImageNet.model',
        'InceptionBN.model': 'https://www.cntk.ai/Models/Caffe_Converted/BNInception_ImageNet.model',
        'ResNet_18.model': 'https://www.cntk.ai/Models/ResNet/ResNet_18.model',
        'ResNet_50.model': 'https://www.cntk.ai/Models/Caffe_Converted/ResNet50_ImageNet.model',
        'ResNet_101.model': 'https://www.cntk.ai/Models/Caffe_Converted/ResNet101_ImageNet.model',
        'ResNet_152.model': 'https://www.cntk.ai/Models/Caffe_Converted/ResNet152_ImageNet.model'
        }
        url = model_name_to_url[filename] 
        wget.download(url, out=filename)

For this example we are going to use ResNet with 152 layers, which has a top-5 error of 6.71% on ImageNet.

In [58]:
%%time
model_name = 'ResNet_152.model'
IMAGE_MEAN = 0 # in case the CNN subtracts the image mean
maybe_download_model(model_name)
Model ResNet_152.model already downloaded
CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 345 µs

Together with the model, we need the classification information. The synsets file maps the output of the network, a number between 0 and 999, to the corresponding class name.

In [4]:
def read_synsets(filename='synsets.txt'):
    with open(filename, 'r') as f:
        synsets = [l.rstrip() for l in f]
        labels = [" ".join(l.split(" ")[1:]) for l in synsets]
    return labels

labels = read_synsets()
print("Label length: ", len(labels))
print(labels[:5])
Label length:  1000
['tench, Tinca tinca', 'goldfish, Carassius auratus', 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 'tiger shark, Galeocerdo cuvieri', 'hammerhead, hammerhead shark']

Next we are going to prepare some helper functions to read images with PIL and plot them.

In [5]:
def read_image_from_file(filename):
    img = Image.open(filename)
    return img
def read_image_from_ioreader(image_request):
    img = Image.open(BytesIO(image_request.read())).convert('RGB')
    return img
def read_image_from_request_base64(image_base64):
    img = Image.open(BytesIO(base64.b64decode(image_base64)))
    return img
def read_image_from_url(url):
    img = Image.open(requests.get(url, stream=True).raw)
    return img
In [6]:
def plot_image(img):
    plt.imshow(img)
    plt.axis('off')
    plt.show()

Let's test the different input images and plot them.

In [7]:
imagepath = 'neko.jpg'
img_cat = read_image_from_file(imagepath)
plot_image(img_cat)
In [8]:
imagefile = open(imagepath, 'rb')
print(type(imagefile))
img = read_image_from_ioreader(imagefile)
plot_image(img)
<class '_io.BufferedReader'>
In [9]:
imagefile = open(imagepath, 'rb')
image_base64 = base64.b64encode(imagefile.read())
print("String of %d characters" % len(image_base64))
img = read_image_from_request_base64(image_base64)
plot_image(img)
String of 191484 characters
In [10]:
imageurl = 'https://pbs.twimg.com/profile_images/269279233/llama270977_smiling_llama_400x400.jpg'
img_llama = read_image_from_url(imageurl)
plot_image(img_llama)

Once we have the image, the model file and the synsets, the next step is to load the model and perform a prediction. We need to process the image to swap the RGB channels and resize it to the input size of ImageNet, which is 224x224.

In [11]:
%%time
z = load_model(model_name)
CPU times: user 388 ms, sys: 216 ms, total: 604 ms
Wall time: 648 ms
In [12]:
def softmax(vect):
    return np.exp(vect) / np.sum(np.exp(vect), axis=0)
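
As an aside, the softmax above can overflow for large logits (np.exp of a large number is inf). A common, numerically stable variant, shown here as an optional sketch, subtracts the maximum before exponentiating; the result is mathematically identical:

```python
import numpy as np

def softmax_stable(vect):
    # Subtracting the max does not change the result:
    # exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x))
    vect = np.asarray(vect, dtype=np.float64)
    exps = np.exp(vect - np.max(vect))
    return exps / np.sum(exps)

print(softmax_stable([1000.0, 1000.0]))  # [0.5 0.5]; the naive version returns [nan nan]
```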
In [55]:
def get_preprocessed_image(my_image, mean_image):
    #Crop and center the image
    my_image = ImageOps.fit(my_image, (224, 224), Image.ANTIALIAS)
    #Transform the image for CNTK format
    my_image = np.array(my_image, dtype=np.float32)
    # RGB -> BGR
    bgr_image = my_image[:, :, ::-1] 
    image_data = np.ascontiguousarray(np.transpose(bgr_image, (2, 0, 1)))
    image_data -= mean_image
    return image_data
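
The two array manipulations above are easy to get wrong, so here is a quick sanity check on a dummy image (NumPy only, no model needed): the channel axis is reversed (RGB to BGR) and moved to the front (HWC to CHW), which is the layout the CNTK model expects.

```python
import numpy as np

# Dummy 224x224 RGB image in HWC layout (height, width, channel)
dummy = np.zeros((224, 224, 3), dtype=np.float32)
dummy[..., 0] = 1.0  # R channel
dummy[..., 2] = 3.0  # B channel

bgr = dummy[:, :, ::-1]                                   # RGB -> BGR
chw = np.ascontiguousarray(np.transpose(bgr, (2, 0, 1)))  # HWC -> CHW

print(chw.shape)     # (3, 224, 224)
print(chw[0, 0, 0])  # 3.0 -> the B channel now comes first
```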
In [56]:
def predict(model, image, labels, number_results=5):
    img = get_preprocessed_image(image, IMAGE_MEAN)
    # Use last layer to make prediction
    arguments = {model.arguments[0]: [img]}
    result = model.eval(arguments)
    result = np.squeeze(result)
    prob = softmax(result)
    # Sort by score; softmax is monotonic, so logit order matches probability order
    prob_idx = np.argsort(prob)[::-1][:number_results]
    pred = [labels[i] for i in prob_idx]
    return pred
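
predict returns only the class names. If the caller also needs confidences, a small variant can pair each label with its softmax probability; top_k_with_probs below is a hypothetical helper, not part of the original notebook, working directly on a logits vector:

```python
import numpy as np

def top_k_with_probs(logits, labels, k=5):
    # Stable softmax over the raw network outputs
    exps = np.exp(logits - np.max(logits))
    probs = exps / np.sum(exps)
    idx = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in idx]

labels_demo = ['cat', 'dog', 'llama']
print(top_k_with_probs(np.array([1.0, 2.0, 0.5]), labels_demo, k=2))
# 'dog' (the largest logit) ranks first
```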
 

Now let's predict the classes of some of the images.

In [59]:
resp = predict(z, img_llama, labels, 2)
print(resp)
resp = predict(z, img_cat, labels, 3)
print(resp)
resp = predict(z, read_image_from_url('http://www.awf.org/sites/default/files/media/gallery/wildlife/Hippo/Hipp_joe.jpg'), labels, 5)
print(resp)
['llama', 'Arabian camel, dromedary, Camelus dromedarius']
['tabby, tabby cat', 'Egyptian cat', 'tiger cat']
['hippopotamus, hippo, river horse, Hippopotamus amphibius', 'hog, pig, grunter, squealer, Sus scrofa', 'warthog', 'wild boar, boar, Sus scrofa', 'piggy bank, penny bank']

Set up Flask API

Let's start the Flask server. The code can be found in the file cntk_api.py. To start it on localhost, first set DEVELOPMENT=True in the file config.py and execute it inside a CNTK environment:

source activate my-cntk-env
python cntk_api.py

You will get something like this:

* Running on http://0.0.0.0:5000/ (Press CTRL+C to quit)
* Restarting with stat
* Debugger is active!

First we will test that the API works locally. For that, we created a sample web page, hello.html, that is rendered when you call the API root.

The code for this operation is as simple as this:

@app.route('/')
def hello():
    return render_template('hello.html')

To execute the API from this notebook you can use the background magic function, running the file with the Python binary of the CNTK environment:

In [ ]:
%%bash --bg 
/home/my-user/anaconda3/envs/my-cntk-env/bin/python /home/my-user/sciblog_support/Intro_to_Machine_Learning_API/cntk_api.py
In [15]:
res = requests.get('http://127.0.0.1:5000/')
display(HTML(res.text))
CNTK API

The CNTK API works

Auto high 5!

Flask allows for simple routing of services. Let's create the main endpoint /api/v1/classify_image. The service accepts an image either as bytes or as a URL; both kinds of request are converted to a PIL image. If the request is malformed, the API returns a bad request error. The image is analyzed using the predict method, which returns the top-5 classification results given the CNN model and the labels. Finally, everything is formatted as JSON.

@app.route('/api/v1/classify_image', methods=['POST'])
def classify_image():
    if 'image' in request.files:
        cherrypy.log("Image request")
        image_request = request.files['image']
        img = read_image_from_ioreader(image_request)
    elif 'url' in request.json: 
        cherrypy.log("JSON request: {}".format(request.json['url']))
        image_url = request.json['url']
        img = read_image_from_url(image_url)
    else:
        cherrypy.log("Bad request")
        abort(BAD_REQUEST)
    resp = predict(model, img, labels, 5)
    return make_response(jsonify({'message': resp}), STATUS_OK)
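
The abort(BAD_REQUEST) call only produces a useful response when paired with an error handler that formats the body. A minimal, self-contained sketch of such a handler follows; it assumes BAD_REQUEST = 400 and uses a stub route standing in for the real classify_image, so it can be exercised with Flask's test client and no running server:

```python
import json
from flask import Flask, abort, jsonify, make_response

app = Flask(__name__)
BAD_REQUEST = 400  # assumption: a constant like this is defined in cntk_api.py

@app.errorhandler(BAD_REQUEST)
def bad_request(error):
    # Turn abort(BAD_REQUEST) into the JSON error body the API returns
    return make_response(jsonify({'error': 'Bad request'}), BAD_REQUEST)

# Stub route so the handler can be exercised without the model
@app.route('/api/v1/classify_image', methods=['POST'])
def classify_image_stub():
    abort(BAD_REQUEST)  # simulate a request with no 'image' file and no 'url' key

# Flask's test client lets us check the behavior without starting a server
with app.test_client() as client:
    res = client.post('/api/v1/classify_image')
    print(res.status_code, json.loads(res.data))  # 400 {'error': 'Bad request'}
```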

Let's first force a bad request:

In [16]:
headers = {'Content-type':'application/json'}
data = {'param':'1'}
res = requests.post('http://127.0.0.1:5000/api/v1/classify_image', data=json.dumps(data), headers=headers)
print(res.text)
{
  "error": "Bad request"
}

Now we are going to use the endpoint with an image from a URL.

In [17]:
%%time
imageurl = 'https://pbs.twimg.com/profile_images/269279233/llama270977_smiling_llama_400x400.jpg'
data = {'url':imageurl}
res = requests.post('http://127.0.0.1:5000/api/v1/classify_image', data=json.dumps(data), headers=headers)
print(res.text)
{
  "message": [
    "llama", 
    "Arabian camel, dromedary, Camelus dromedarius", 
    "chimpanzee, chimp, Pan troglodytes", 
    "ostrich, Struthio camelus", 
    "patas, hussar monkey, Erythrocebus patas"
  ]
}

CPU times: user 8 ms, sys: 4 ms, total: 12 ms
Wall time: 6.04 s

Finally, we are going to test the API with an image loaded from disk.

In [62]:
%%time
imagepath = 'neko.jpg'
image_request = open(imagepath, 'rb')
files_local = {'image': image_request}
res = requests.post('http://127.0.0.1:5000/api/v1/classify_image', files=files_local)
print(res.text)
{
  "message": [
    "tabby, tabby cat", 
    "Egyptian cat", 
    "tiger cat", 
    "lynx, catamount", 
    "bath towel"
  ]
}

CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 771 ms

CNTK API with CherryPy

There are multiple solutions to set up a production API with Flask. One option is to use Nginx as a reverse proxy and Gunicorn as the WSGI server; here you can find a great example combining these two technologies to serve machine learning models. Another way is to use Apache, like in this example, where a simple Flask application is set up using the Apache server and Gunicorn.

In this notebook we are going to use CherryPy. It has the following features, as announced by its authors on the project web page:

  • A reliable, HTTP/1.1-compliant, WSGI thread-pooled webserver. WSGI stands for Web Server Gateway Interface. It is a specification that describes how a web server communicates with web applications, and how web applications can be chained together to process one request.
  • CherryPy is now more than ten years old and it has proven to be very fast and stable. In this benchmark CherryPy is tested against several other solutions.
  • Built-in profiling, coverage, and testing support.
  • Powerful and easy-to-use configuration system.

All these features make CherryPy a good solution for quickly developing production APIs.
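
The WSGI interface mentioned in the first bullet point is small enough to show in full: a WSGI application is just a callable taking an environ dict and a start_response function. Here is a minimal, framework-free sketch (conceptually, this is the kind of callable CherryPy hosts when the Flask app is grafted onto it), exercised with the standard library's wsgiref helpers:

```python
from wsgiref.util import setup_testing_defaults

def simple_app(environ, start_response):
    # environ: CGI-style dict describing the request
    # start_response: callback that sets the status line and headers
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'The CNTK API works\n']

# Exercise the app without a real server
environ = {}
setup_testing_defaults(environ)  # fills in a minimal fake request
collected = {}
def start_response(status, headers):
    collected['status'] = status
body = b''.join(simple_app(environ, start_response))
print(collected['status'], body)  # 200 OK b'The CNTK API works\n'
```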

The code to set up the server is fairly simple:

def run_server():
    # Enable WSGI access logging via Paste
    app_logged = TransLogger(app)

    # Mount the WSGI callable object (app) on the root directory
    cherrypy.tree.graft(app_logged, '/')

    # Set the configuration of the web server
    cherrypy.config.update({
        'engine.autoreload_on': True,
        'log.screen': True,
        'log.error_file': "cherrypy.log",
        'server.socket_port': PORT,
        'server.socket_host': '0.0.0.0',
        'server.thread_pool': 50, # 10 is default
    })

    # Start the CherryPy WSGI web server
    cherrypy.engine.start()
    cherrypy.engine.block()

To start the server, we need to set DEVELOPMENT=False in the file config.py and execute it inside a CNTK environment:

source activate my-cntk-env
python cntk_api.py

You will get something like this:

[16/Apr/2017:17:52:43] ENGINE Bus STARTING
[16/Apr/2017:17:52:43] ENGINE Started monitor thread 'Autoreloader'.
[16/Apr/2017:17:52:43] ENGINE Started monitor thread '_TimeoutMonitor'.
[16/Apr/2017:17:52:43] ENGINE Serving on http://0.0.0.0:5000
[16/Apr/2017:17:52:43] ENGINE Bus STARTED

The first step is to test whether the root endpoint is working:

In [19]:
server_name = 'http://the-name-of-your-server'
port = 5000
In [20]:
root_url = '{}:{}'.format(server_name, port)
In [22]:
res = requests.get(root_url)
display(HTML(res.text))
CNTK API

The CNTK API works

Auto high 5!

Now, as we did before, let's test the classification API with an image from a URL and an image from bytes.

In [23]:
end_point = root_url + '/api/v1/classify_image' 
#print(end_point)
In [24]:
%%time
imageurl = 'https://pbs.twimg.com/profile_images/269279233/llama270977_smiling_llama_400x400.jpg'
data = {'url':imageurl}
headers = {'Content-type':'application/json'}
res = requests.post(end_point, data=json.dumps(data), headers=headers)
print(res.text)
{
  "message": [
    "llama", 
    "Arabian camel, dromedary, Camelus dromedarius", 
    "chimpanzee, chimp, Pan troglodytes", 
    "ostrich, Struthio camelus", 
    "patas, hussar monkey, Erythrocebus patas"
  ]
}

CPU times: user 8 ms, sys: 0 ns, total: 8 ms
Wall time: 5.69 s
In [25]:
%%time
imagepath = 'neko.jpg'
image_request = open(imagepath, 'rb')
files = {'image': image_request}
res = requests.post(end_point, files=files)
print(res.text)
{
  "message": [
    "Egyptian cat", 
    "tabby, tabby cat", 
    "tiger cat", 
    "Siamese cat, Siamese", 
    "carton"
  ]
}

CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 784 ms

Bombardment

Finally, let's get to the fun part. Now that we have the API set up with CherryPy, let's test how the system performs under a large number of concurrent requests. The first step is to select the request to execute. We are going to use this handsome hippo.

In [36]:
# Get hippo
hippo_url = "http://www.awf.org/sites/default/files/media/gallery/wildlife/Hippo/Hipp_joe.jpg"

fname = urlretrieve(hippo_url, "bhippo.jpg")[0]
img_bomb = read_image_from_file(fname)
plot_image(img_bomb)
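
The aiohttp and asyncio imports at the top of the notebook are used for this kind of load test. The fan-out pattern is sketched below with a simulated request (asyncio.sleep standing in for the actual aiohttp POST to /api/v1/classify_image), so the sketch runs without a live server; classify_request and bombard are hypothetical names for illustration:

```python
import asyncio
import random

async def classify_request(i):
    # Stand-in for an aiohttp POST to the classify_image endpoint
    await asyncio.sleep(random.uniform(0.01, 0.05))
    return {'request': i, 'message': ['hippopotamus, hippo, river horse']}

async def bombard(n_requests):
    # Launch all requests concurrently and wait for every response
    tasks = [classify_request(i) for i in range(n_requests)]
    return await asyncio.gather(*tasks)

results = asyncio.run(bombard(20))
print(len(results))  # 20
```

With a real server, classify_request would open an aiohttp.ClientSession and post the image bytes; the gather call is what makes the 20 requests hit the thread pool concurrently.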