This is one of the 100 recipes of the IPython Cookbook, the definitive guide to high-performance scientific computing and data science in Python.

8.3. Learning to recognize handwritten digits with a K-nearest neighbors classifier

  1. Let's do the traditional imports.
In [ ]:
import numpy as np
import sklearn
import sklearn.datasets as ds
import sklearn.model_selection as cv  # replaces the removed sklearn.cross_validation module
import sklearn.neighbors as nb
import matplotlib.pyplot as plt
%matplotlib inline
  1. Let's load the digits dataset, which is part of scikit-learn's datasets module. This dataset contains handwritten digits that have been manually labeled.
In [ ]:
digits = ds.load_digits()
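X = digits.data    # one flattened 8x8 image per row, 64 columns
y = digits.target  # the manual label (0 to 9) of each image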
print((X.min(), X.max()))

In the matrix X, each row contains the $8 \times 8=64$ pixels of one image (in grayscale, with values between 0 and 16). The pixels are stored in row-major order: the first 8 values form the top row of the image, the next 8 the second row, and so on.
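A small NumPy sketch of what row-major ordering means for one row of X. It uses synthetic values (`np.arange(64)`) in place of real pixels, so it runs without loading the dataset; in scikit-learn, `digits.images` holds the same data already reshaped to $8 \times 8$ arrays.

```python
import numpy as np

# Synthetic stand-in for one row of X (not real pixel data).
row = np.arange(64)

# Reshaping in row-major ("C") order rebuilds the 8x8 grid:
# values 0-7 form the top row, 8-15 the second row, and so on.
image = row.reshape(8, 8)

# Pixel (r, c) of the grid sits at index 8 * r + c of the flat row.
r, c = 2, 5
print(image[r, c], row[8 * r + c])  # → 21 21

# ravel() also flattens in row-major order, so it recovers the original row.
print(np.array_equal(image.ravel(), row))  # → True
```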

  1. Let's display some of the images.
In [ ]:
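nrows, ncols = 2, 5
plt.figure(figsize=(6, 3))
for i in range(ncols * nrows):
    ax = plt.subplot(nrows, ncols, i + 1)
    # digits.images[i] is the 8x8 version of row i of X.
    ax.imshow(digits.images[i], cmap='gray', interpolation='none')
    plt.xticks([]); plt.yticks([]);
    plt.title(str(digits.target[i]))  # show the manual label above each image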
  1. Now, let's fit a K-nearest neighbors classifier on the data.
In [ ]:
(X_train, X_test, 
 y_train, y_test) = cv.train_test_split(X, y, test_size=.25)
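To make the split concrete, here is a simplified stand-in for what a 25% random split does: shuffle the sample indices, then set the first quarter aside as test data. The helper `split_indices` and its `seed` argument are hypothetical and this is not scikit-learn's implementation; 1797 is the number of samples in the digits dataset.

```python
import math
import numpy as np

def split_indices(n_samples, test_size=0.25, seed=0):
    """Hypothetical helper: shuffle indices and cut off a test fraction."""
    rng = np.random.default_rng(seed)
    # Random permutation so both parts are drawn from the whole dataset.
    perm = rng.permutation(n_samples)
    # Round the test part up to a whole number of samples.
    n_test = math.ceil(test_size * n_samples)
    return perm[n_test:], perm[:n_test]  # (train indices, test indices)

train_idx, test_idx = split_indices(1797)
print(len(train_idx), len(test_idx))  # → 1347 450
```

Because the split is random, the test score will vary slightly from run to run; `train_test_split` accepts a `random_state` argument to make it reproducible.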
In [ ]:
knc = nb.KNeighborsClassifier()
In [ ]:
knc.fit(X_train, y_train);
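The idea behind a K-nearest neighbors classifier fits in a few lines: to label a new point, find the K training points closest to it and take a majority vote over their labels. The sketch below, with a hypothetical `knn_predict` function, uses Euclidean distance and k = 5 by default, matching `KNeighborsClassifier`'s default of 5 neighbors with uniform weights and Minkowski distance with p = 2. It assumes labels are non-negative integers, as in the digits dataset.

```python
import numpy as np

def knn_predict(X_train, y_train, X_query, k=5):
    """Minimal k-NN classifier: Euclidean distance plus majority vote."""
    # Squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b, shape (n_query, n_train).
    d2 = ((X_query ** 2).sum(axis=1)[:, None]
          + (X_train ** 2).sum(axis=1)[None, :]
          - 2 * X_query @ X_train.T)
    # Indices of the k closest training points for each query point.
    nearest = np.argsort(d2, axis=1)[:, :k]
    # Labels of those neighbors, shape (n_query, k).
    votes = y_train[nearest]
    # Majority vote; on a tie, argmax picks the smallest label.
    return np.array([np.bincount(v).argmax() for v in votes])

# Hypothetical toy data: two well-separated 2D clusters labeled 0 and 1.
X_toy = np.array([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]], dtype=float)
y_toy = np.array([0, 0, 0, 1, 1, 1])
queries = np.array([[0.5, 0.5], [5.5, 5.5]])
print(knn_predict(X_toy, y_toy, queries, k=3))  # → [0 1]
```

Note that "fitting" a K-NN classifier essentially just stores the training data; the work happens at prediction time. scikit-learn may use tree-based neighbor search and can break ties differently, so this sketch need not reproduce its predictions exactly.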
  1. Let's evaluate the score of the trained classifier on the test dataset.
In [ ]:
knc.score(X_test, y_test)
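For a classifier, `score` returns the mean accuracy on the given data: the fraction of test images whose predicted label equals the true label. The labels below are made up for illustration; with the real classifier, the same computation would compare `knc.predict(X_test)` against `y_test`.

```python
import numpy as np

# Hypothetical true labels and predictions for 8 test images.
y_true = np.array([3, 1, 4, 1, 5, 9, 2, 6])
y_pred = np.array([3, 1, 4, 7, 5, 9, 2, 6])

# Accuracy: fraction of samples where the prediction matches the truth.
accuracy = np.mean(y_pred == y_true)
print(accuracy)  # → 0.875
```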
  1. Now, let's see if our classifier can recognize a "hand-written" digit!
In [ ]:
# Let's draw a 1.
one = np.zeros((8, 8))
one[1:-1, 4] = 16  # The image values are in [0, 16].
one[2, 3] = 16
In [ ]:
plt.imshow(one, cmap='gray', interpolation='none');
plt.xticks([]); plt.yticks([]);
In [ ]:
knc.predict(one.reshape(1, -1))  # flatten the 8x8 image into a single 64-pixel sample
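scikit-learn estimators expect 2D input of shape (n_samples, n_features), and recent versions raise an error when given a 1D array. A single $8 \times 8$ image therefore has to become one row of 64 features, flattened in the same row-major order as the rows of X. This sketch rebuilds the drawn "1" so it runs on its own; the mirrored copy in `batch` is only there to show the multi-image case.

```python
import numpy as np

# The hand-drawn "1", rebuilt here so the snippet is self-contained.
one = np.zeros((8, 8))
one[1:-1, 4] = 16
one[2, 3] = 16

# One image -> one sample with 64 features.
sample = one.reshape(1, -1)
print(sample.shape)  # → (1, 64)

# Several images stacked along a first axis flatten the same way.
batch = np.stack([one, np.fliplr(one)])
print(batch.reshape(len(batch), -1).shape)  # → (2, 64)

# Row-major flattening: pixel (2, 3) becomes feature 8 * 2 + 3 = 19.
print(sample[0, 19])  # → 16.0
```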
