In this lab we will explore MNIST, a classic machine learning data set of images of handwritten digits (i.e., 0, 1, 2, 3, ...). In addition, we will investigate an intuitive, yet powerful learning method called k-nearest neighbors (KNN). Even though the type of data is different from what we've worked with so far, we'll see how to apply familiar tools to it, namely, scikit-learn and matplotlib for machine learning and plotting in Python.
Don't forget to fill out the response form.
And if you haven't already, fetch the data by running the following code (it will download into the DATA_PATH directory):
%pylab inline
import pylab
from sklearn.datasets import fetch_mldata

DATA_PATH = '~/data'
mnist = fetch_mldata('MNIST original', data_home=DATA_PATH)
The MNIST database of handwritten digits is a collection of labeled images that has been used to evaluate machine learning techniques since the 1990s. The core application of the MNIST data is to train computer vision systems to recognize handwritten text. The post office, for example, is a major user of such systems: addresses on letters and packages are all photographed, read, and routed digitally, with only a few ambiguous cases verified by a human.
The MNIST data set has also become a reliable benchmark for learning methods. It's small, but not tiny, and the data dimensionality (28x28 pixels, or 784 features) is big enough to cause some "curse of dimensionality" issues. Also, the problem is highly non-linear, meaning linear classification methods (like linear regression, but for predicting discrete categories) don't perform well on the raw data. The MNIST website reports an extensive list of results obtained by different machine learning models, including neural nets, SVMs, nearest neighbors, and others.
The data consists of 60,000 training images and 10,000 test images. Each image is a 28x28 pixel, grayscale picture of a digit written either by a high school student or an employee of the US Census Bureau. The images have all been preprocessed to be clean and regular: only one digit appears in each image, and it appears directly in the center of the image.
The goal of the benchmark is to fit a model to the training set, and then use that model to predict which digit is in each of the test images. The best results achieve a classification error rate of less than half of one percent. This is often described as the "human error rate," because if you ask people to classify the images, they will also find about 0.5% of them to be completely inscrutable.
Let's take a look at the data.
MNIST data is easy to load with scikit-learn's built-in fetch_mldata function.
Make sure you've loaded the data set by running the code at the top of the lab.
The image data itself is in
mnist.data, and it's stored in a numpy n-dimensional array (n=2 in this case).
Numpy arrays are a vector/matrix data structure that provide high performance numerical computing.
We can find out the dimensions of the array with its shape attribute.
There are 70,000 rows and 784 columns. So each row is an image (60,000 training images plus 10,000 test ones gives 70,000 total), and each column is a pixel value (784 = 28 * 28).
Like Pandas DataFrames, numpy arrays give us a simple interface to summary statistics and subsets of the data.
Numpy arrays are indexed dimension-wise, with each dimension separated by a comma:
row = mnist.data[0,:]  # First row of the array
col = mnist.data[:,0]  # First column of the array
print row.shape
print col.shape
In this syntax, the ":" means "ALL", as in standard python indexing. All of the usual python range indexing syntax works for each dimension of the array. We can compute summary statistics, too:
print row.sum(), row.max(), row.min()
print col.sum(), col.max(), col.min()
print mnist.data[:10,:]   # First ten rows
print mnist.data[:,-10:]  # Last ten columns
Let's divide the array into two sets, one for training images and one for test:
train = mnist.data[:60000]
test = mnist.data[60000:]
Note that we can drop the trailing ",:" when we want to just index the first dimension.
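This indexing behavior is easiest to see on a small example. Here is a minimal sketch on a made-up array (a stand-in for mnist.data, not the real thing) showing that dropping the trailing ",:" gives the same result:

```python
import numpy as np

# A small stand-in for mnist.data: 6 "images" of 4 pixels each
data = np.arange(24).reshape(6, 4)

row = data[0, :]   # first row
col = data[:, 0]   # first column
print(row.shape)   # (4,)
print(col.shape)   # (6,)

# Dropping the trailing ",:" indexes just the first dimension
print(np.array_equal(data[0], data[0, :]))     # True
print(np.array_equal(data[:3], data[:3, :]))   # True
```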
To start, we want to work with just a sample of the training data. Create a sample consisting of every 100th image in train.
Find the mean value of the 300th column in the sample data set.
test_sample = None # Fix me
One of the nicest things about image data is that it is naturally visualized and understood. First, let's take a look at the raw data in the first image in the data set:
img = mnist.data[0]  # The first image in the data set
print img
These are all the pixel values in the image.
We can see some patterns (e.g., the edges are empty), but it's hard to interpret.
In fact, we can get a much better view if we use the
imshow method from matplotlib to display an image:
pylab.imshow(img)
Annnnnd... it breaks.
What went wrong?
Since images are two dimensional objects,
imshow expects a two dimensional array of data to plot, but the rows of our data array are flat vectors.
Luckily, numpy arrays provide a
reshape method that lets us change the dimensions of our data (as long as we leave the total length the same).
Let's reshape our image data into a 28x28 pixel square and try again:
pylab.imshow(img.reshape(28, 28), cmap="Greys")
It's a zero! Now we're getting somewhere.
Use imshow to visualize a number of images from
sample. What can you say about how the data set is ordered?
Next, we're interested in uncovering more structure in the MNIST data. For example, we want to be able to answer questions like "how similar are people's handwriting?" and "how distinct are the different digits?" If we get a sense of the variance in the data, and of how tightly it is clustered, we can begin to see a good approach to modeling.
In the spirit of doing the simplest possible thing that might work, we can look at nearest neighbors using simple Euclidean distance between pixels. The assumption is that most images of the same digit look alike, so they should have similar values in individual pixels. Let's find out if this assumption is a good one.
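As a quick sanity check of the idea, here is a minimal sketch of Euclidean distance in pixel space, using made-up 4-pixel "images" rather than real MNIST data:

```python
import numpy as np

# Three tiny made-up "images", flattened to pixel vectors
a = np.array([0., 0., 255., 255.])
b = np.array([0., 10., 250., 255.])   # a slightly perturbed copy of a
c = np.array([255., 255., 0., 0.])    # a very different image

def pixel_distance(x, y):
    # Euclidean distance between two flattened images
    return np.sqrt(((x - y) ** 2).sum())

# The perturbed copy should be much closer to a than c is
print(pixel_distance(a, b) < pixel_distance(a, c))  # True
```

If two images of the same digit mostly agree pixel by pixel, this distance will be small; that's the assumption the rest of the lab tests.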
%%time
from sklearn.neighbors import NearestNeighbors

model = NearestNeighbors(algorithm='brute').fit(train)
Note how fast we built a nearest neighbors model, just a few microseconds!
This is because we're using the brute force implementation (
algorithm='brute'), which simply stores the training data to build a model, and does a full pairwise comparison at query time.
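Under the hood, a brute-force query amounts to measuring the distance from the query to every stored training point and keeping the k smallest. A rough numpy sketch of that idea, on random stand-in data rather than MNIST:

```python
import numpy as np

rng = np.random.RandomState(0)
fake_train = rng.rand(100, 784)   # 100 fake "images"
query = rng.rand(784)             # one fake query "image"

# Distance from the query to every training row
dists = np.sqrt(((fake_train - query) ** 2).sum(axis=1))

# Indices of the k nearest training rows
k = 4
nearest = np.argsort(dists)[:k]
print(nearest.shape)  # (4,)
```

This is why fitting is nearly free but each query costs a full pass over the training set.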
Let's query our new model. We can fetch the k nearest neighbors with the kneighbors method:
%%time
query_img = test[:1]  # The first test image
_, result = model.kneighbors(query_img, n_neighbors=4)
Notice that query time is significant, even for a single image.
Also, notice that
kneighbors returns two values.
The first, which we will ignore, contains the distances to the nearest neighbors.
The second is a list of indices where we can look up the nearest neighbors in the training set.
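To make the two return values concrete, here is a small sketch on a made-up 2-D point set (not MNIST):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

pts = np.array([[0., 0.], [1., 0.], [5., 5.], [2., 0.]])
model = NearestNeighbors(algorithm='brute').fit(pts)

# One query point, two neighbors requested
dist, idx = model.kneighbors(np.array([[0., 0.]]), n_neighbors=2)
print(dist)  # distances to the two nearest points: [[0. 1.]]
print(idx)   # their row indices in pts:            [[0 1]]
```

Each row of both arrays corresponds to one query point, which is why we index into the result even for a single query.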
With the results, now we can see how we did.
There are four results, as expected. Let's print them out with the utility function below:
# Display several images in a row
def show(imgs, n=1):
    fig = pylab.figure()
    for i in xrange(0, n):
        fig.add_subplot(1, n, i + 1, xticks=[], yticks=[])
        if n == 1:
            img = imgs
        else:
            img = imgs[i]
        pylab.imshow(img.reshape(28, 28), cmap="Greys")
show(query_img)
show(train[result[0], :], len(result[0]))
The neighbors look pretty good! Importantly, they are all zeros. That means that to some extent, at least, our assumption about images of the same digit being "close" to one another in pixel-space is a good one.
We can validate our model in a more rigorous way by using it to predict digits.
Scikit-learn provides a class for supervised nearest neighbors fitting called KNeighborsClassifier.
It is very similar to the
NearestNeighbors class, but it accepts labels when fitting a model, and it provides methods for making label predictions for test data.
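The fit/predict workflow is easiest to see on toy data first. A minimal sketch on made-up 1-D points (not MNIST), where points near 2 are class 0 and points near 8 are class 1:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy training data with two clearly separated classes
X = np.array([[1.], [2.], [3.], [7.], [8.], [9.]])
y = np.array([0, 0, 0, 1, 1, 1])

clf = KNeighborsClassifier(n_neighbors=3, algorithm='brute').fit(X, y)

# Each query is labeled by majority vote among its 3 nearest neighbors
print(clf.predict([[2.5], [8.5]]))  # [0 1]
```

The MNIST version below works exactly the same way, just with 784-dimensional points and ten classes.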
The MNIST labels are in mnist.target.
Let's split them into training and test sets as we did with the image data:
train_labels = mnist.target[:60000]
test_labels = mnist.target[60000:]
test_labels_sample = test_labels[::100]
Next, as before, we fit a model to the training data:
%%time
from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier(n_neighbors=4, algorithm='brute').fit(train, train_labels)
%%time
# Score the model!
preds = model.predict(test_sample)
errors = [i for i in xrange(0, len(test_sample)) if preds[i] != test_labels_sample[i]]

for i in errors:
    pass  # Visualize error image and its nearest neighbors
Use sklearn.metrics.confusion_matrix to generate a confusion matrix between model predictions and test labels. Which pair of digits is confused most frequently?
test_sample = test[::10]
test_labels_sample = test_labels[::10]

def plot_cm(cm):
    # Add 1 before taking the log so empty cells don't produce -inf
    pylab.matshow(np.log(cm + 1))

from sklearn.metrics import confusion_matrix

# Compute and plot the confusion matrix for test_sample
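To see how to read the matrix before computing it on MNIST, here is a tiny sketch with made-up labels, where the model mistakes one "4" for a "9":

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up true and predicted labels (not MNIST output)
y_true = np.array([4, 4, 9, 9, 9])
y_pred = np.array([4, 9, 9, 9, 9])

cm = confusion_matrix(y_true, y_pred)
print(cm)
# Rows are true labels (4 then 9), columns are predictions:
# [[1 1]
#  [0 3]]
```

Each off-diagonal cell counts one kind of mistake, so the largest off-diagonal entry in your MNIST matrix identifies the most frequently confused pair of digits.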