In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

MNIST 99.82%

Here we describe our experiment with the venerable MNIST dataset of hand-written digits.

  • We reliably get a classification accuracy of 99.82 percent on the ten thousand test images, after training on the sixty thousand training images. This is with a committee of 35 neural net classifiers. Single-net classification accuracy is usually between 99.7 and 99.8 percent.

  • Each net trains in about three minutes on our GTX 1080 Ti GPU, with only ten presentations of the training set (ten epochs).

  • Along with this notebook, we release the source code of our system, written in Python using PyTorch, so others can reproduce or use our results.

Data

In [2]:
from mnist import *
trainset = mnist_trainset(heldout=0)
testset = mnist_testset()

Our training dataset includes on-the-fly augmentation. Each time an image is retrieved, it is randomly cropped to shift it around a little, given a small random rotation, and then a small elastic distortion is applied.
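
Roughly, the pipeline looks something like the sketch below, expressed with torchvision transforms. This is only an illustration, not the code in mnist.py: the parameters are made up, and ElasticTransform requires torchvision 0.13 or later.

import torchvision.transforms as T

# Illustrative only -- the real pipeline and its parameters live in mnist.py.
sketch_augment = T.Compose([
    T.RandomCrop(28, padding=2),       # small random shift via a padded crop
    T.RandomRotation(degrees=10),      # small random rotation
    T.ElasticTransform(alpha=8.0),     # mild elastic distortion (torchvision >= 0.13)
    T.ToTensor(),
])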

In [3]:
# Here we take the first image in the train dataset and display it
# To see the effect of the augmentation, you can simply refresh this
# cell repeatedly with Ctrl-Enter.
image, label = trainset[0]
show_image(image, interpolation='nearest')
In [4]:
# We set up to retrieve batches at a time.
batcher = Batcher(trainset, bs=144)
In [5]:
# Display a batch of (augmented) images.  You can refresh 
# this cell with Ctrl-Enter to see additional batches.
images, labels = next(batcher)
plot_images(images)

Model

Below is the model we use for the MNIST classifier, expressed in PyTorch.

  • It is a fairly conventional multilayer convolutional neural net with some Resnet-like elements.
  • The Residual blocks are ordinary 3x3 convolutions with a shortcut connection going around.
  • It makes heavy use of batchnorm, perhaps to a fault.
  • It uses old-school maxpool instead of strided convolutions, but we haven't tried the latter.
  • It has 7 convolutional layers, including those inside the Residual blocks, and ends with 512 channels of 4x4 activations.
  • These are averaged into a vector of 512 activations, and a single small dense layer then converts them to ten output logits.
In [6]:
class Residual(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.bn = nn.BatchNorm2d(d)
        self.conv3x3 = nn.Conv2d(in_channels=d, out_channels=d,
                                 kernel_size=3, padding=1)

    def forward(self, x):
        x = self.bn(x)
        return x + F.relu(self.conv3x3(x))

def mnist_model():
    "Returns an initialized but untrained model for MNIST."
    return nn.Sequential(
               nn.Conv2d(in_channels=1, out_channels=128, kernel_size=5, padding=2),
               nn.ReLU(),

               Residual(128),
               nn.MaxPool2d(2),
               Residual(128),

               nn.BatchNorm2d(128),
               nn.Conv2d(128, 256, 3, padding=1),
               nn.ReLU(),
               nn.MaxPool2d(2),
               Residual(256),

               nn.BatchNorm2d(256),
               nn.Conv2d(256, 512, 3, padding=1),
               nn.ReLU(),
               nn.MaxPool2d(2, ceil_mode=True),
               Residual(512),

               nn.BatchNorm2d(512),
               nn.AvgPool2d(kernel_size=4),
               Flatten(),

               nn.Linear(512,10),
               # Softmax provided during training.
           )

Training

We do a learning rate sweep (like the one used in the fastai library) and plot the learning rate vs the training loss.
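
For readers unfamiliar with the technique, a sweep of this kind can be sketched as follows: the learning rate is increased geometrically batch by batch, the training loss is recorded, and the sweep stops once the loss clearly diverges. The train_step function here is a hypothetical stand-in for one minibatch of training; the actual lr_find in mnist.py may differ in detail.

def lr_sweep(train_step, lr_lo=1e-6, lr_hi=1.0, steps=400):
    "Sketch of an LR sweep: train_step(lr) runs one minibatch at that rate and returns its loss."
    lrs, losses = [], []
    for i in range(steps):
        lr = lr_lo * (lr_hi / lr_lo) ** (i / (steps - 1))   # geometric ramp from lr_lo to lr_hi
        loss = train_step(lr)
        lrs.append(lr)
        losses.append(loss)
        if loss > 4 * min(losses):   # stop once the loss has clearly blown up
            break
    return lrs, losses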

In [7]:
trainer = Trainer(mnist_model(), trainset, mnist_testset(augmented=True))
lr_find(trainer, p=0.90, wd=0, bs=100);
stopping at step 378

Based on this, we choose a maximum learning rate of 4e-4, which is about as high as we can go before the curve flattens out and eventually becomes bumpy.

We somewhat arbitrarily choose a factor of 10 and let the learning rate vary between 4e-4 and 4e-5. We use a conventional exponentially-falling learning rate schedule, starting at the maximum and ending at the minimum. We set the momentum to 0.9 and leave it there.
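
For concreteness, a schedule of the kind exp_interpolator produces can be sketched like this (an assumption about its behavior, not the actual implementation in mnist.py):

import numpy as np

def exp_schedule(lr_max, lr_min, nbatches):
    "Learning rates decaying geometrically from lr_max to lr_min over nbatches steps."
    t = np.arange(nbatches) / max(nbatches - 1, 1)
    return lr_max * (lr_min / lr_max) ** t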

Finally, at the end we train for one additional epoch while rapidly annihilating the learning rate, reducing it toward zero by about three orders of magnitude.

In [8]:
def annihilate(cb, epochs=1):
    "Quickly reduce learning rate towards zero."
    lr_start = cb.lrs[-1]
    batches = epochs_to_batches(cb.trainer.train_set, epochs, cb.bs)
    lr = cos_interpolator(lr_start, lr_start/1e3, batches)
    return cb.trainer.train(lr=lr, p=cb.moms[-1], wd=cb.wd, epochs=epochs, bs=cb.bs, callback=cb)

def train_one(trainset, testset, callback=None):
    "Train MNIST model according to desired parameters and schedule."
    epochs, lr_max, bs, wd = (10, 4e-4, 100, 0.0)

    model = mnist_model()
    trainer = Trainer(model, trainset)
    lr = exp_interpolator(lr_max, lr_max/10, epochs_to_batches(trainset, epochs, bs))
    
    cb = None if callback is None else callback(trainer)
    cb = trainer.train(lr, p=0.90, epochs=epochs, bs=bs, wd=wd, callback=cb)
    cb = annihilate(cb)

    acc, lss = accuracy(Classifier(model), ds=testset, include_loss=True)
    print(f"test set: loss = {lss:.3g}, accuracy = {percent(acc)}")
 
    return model, cb
In [9]:
m1, cb = train_one(trainset, testset)
100.00% [6000/6000 02:45<00:00]
100.00% [600/600 00:17<00:00]
test set: loss = 0.00754, accuracy = 99.76%

This net trained in about three minutes and the single-net performance is typically near the top of the systems listed on Yann LeCun's MNIST page (see References).

In [10]:
# The callback value returned, cb, holds a history of the training
# session, including the learning rates used during training and the
# training loss at each batch.  (If we give trainer.train() a
# ValidationCallback parameter, it will additionally collect validation
# loss and other data.)
cb.plot_lr()
cb.plot_loss(start=1000, halflife=60)
In [11]:
# Let's train four more.
m2, cb2 = train_one(trainset, testset)
m3, cb3 = train_one(trainset, testset)
m4, cb4 = train_one(trainset, testset)
m5, cb5 = train_one(trainset, testset)
100.00% [6000/6000 02:51<00:00]
100.00% [600/600 00:17<00:00]
test set: loss = 0.00849, accuracy = 99.72%
100.00% [6000/6000 02:52<00:00]
100.00% [600/600 00:17<00:00]
test set: loss = 0.00863, accuracy = 99.71%
100.00% [6000/6000 02:52<00:00]
100.00% [600/600 00:17<00:00]
test set: loss = 0.00644, accuracy = 99.76%
100.00% [6000/6000 02:54<00:00]
100.00% [600/600 00:17<00:00]
test set: loss = 0.008, accuracy = 99.75%
In [12]:
# Wrap each of the trained models into a Classifier object.
classifiers = [Classifier(m) for m in [m1, m2, m3, m4, m5]]
# And combine those classifiers into a committee classifier.
voter = VotingClassifier(classifiers)

# Let's see how it does!
acc = accuracy(voter, ds=testset)
print(f"Committee of {len(classifiers)} accuracy: {percent(acc)}")
Committee of 5 accuracy: 99.81%
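
We have not shown the internals of VotingClassifier; the sketch below is one plausible way a committee vote can work, a plurality vote over the members' predictions. The actual implementation in mnist.py may combine predictions differently (for example, by averaging probabilities).

import torch

def committee_predict(models, x):
    "Plurality vote of the models on a batch x of images, shape (N, 1, 28, 28)."
    # assumes the models are already in eval() mode
    with torch.no_grad():
        votes = torch.stack([m(x).argmax(dim=1) for m in models])  # (n_models, N)
    return torch.mode(votes, dim=0).values   # most common predicted label per image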

For a committee of five, we typically see at least 99.79 percent, which is better than the best listed on Yann LeCun's MNIST page (last updated circa 2012).

In [19]:
# Let's see the items in the test set that our committee-of-five got wrong.
show_mistakes(voter, testset);

For convenience, we wrote a simple Python script that trains a large population of models and writes them to disk, then samples random committees of a specified size from that population.

That way we can easily see how committees of different sizes perform. It seems conventional to use a committee of 35, so that is what we use to justify our claim of 99.82 percent accuracy on the test set.
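
We do not reproduce committee.py here, but its core can be sketched with the same helpers used earlier in this notebook; the actual script may be organized differently.

import random

def random_committee_accuracy(model_files, size, ds):
    "Sample a random committee of the given size from saved models and score it on ds."
    files = random.sample(model_files, size)
    members = [Classifier(read_model(mnist_model(), f)) for f in files]
    return accuracy(VotingClassifier(members), ds=ds)

# e.g. random_committee_accuracy([f"model{n+1}.pt" for n in range(250)], 35, testset)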

In [20]:
%run committee.py --population 250 --committee 5 --trials 4
Population of 250: training is complete.
1 of 4: Committee of 5 accuracy: 99.80%
2 of 4: Committee of 5 accuracy: 99.77%
3 of 4: Committee of 5 accuracy: 99.81%
4 of 4: Committee of 5 accuracy: 99.80%
mean: 99.80% (0.99795)
In [21]:
%run committee.py --population 250 --committee 15 --trials 5
Population of 250: training is complete.
1 of 5: Committee of 15 accuracy: 99.82%
2 of 5: Committee of 15 accuracy: 99.82%
3 of 5: Committee of 15 accuracy: 99.83%
4 of 5: Committee of 15 accuracy: 99.84%
5 of 5: Committee of 15 accuracy: 99.82%
mean: 99.83% (0.99826)
In [25]:
%run committee.py --population 250 --committee 35
Population of 250: training is complete.
1 of 7: Committee of 35 accuracy: 99.84%
2 of 7: Committee of 35 accuracy: 99.82%
3 of 7: Committee of 35 accuracy: 99.81%
4 of 7: Committee of 35 accuracy: 99.82%
5 of 7: Committee of 35 accuracy: 99.82%
6 of 7: Committee of 35 accuracy: 99.83%
7 of 7: Committee of 35 accuracy: 99.83%
mean: 99.82% (0.998243)

So that is our claimed 99.82% accuracy on the test set, with a committee of 35.

Just for fun, let's see how well it does if we throw every single net from the population of trained nets--all 250 of them--into a big committee and classify the test set.

In [26]:
%run committee.py --population 250 --committee 250
Population of 250: training is complete.
1 of 1: Committee of 250 accuracy: 99.84%
mean: 99.84% (0.9984)

A little bit better! Let's see what the mistakes are:

In [27]:
filenames = [ f"model{n+1}.pt" for n in range(250) ]
classifiers = [ Classifier(read_model(mnist_model(), f)) for f in filenames ]
big_committee = VotingClassifier(classifiers)
mistakes = show_mistakes(big_committee, mnist_testset());
In [28]:
# This is our favorite MNIST image, taken from above.
img, label = testset[mistakes[7][0]]
show_image(img, figsize=(3,3))

Thoughts

We are pleased to discover that it is possible to get very good results on MNIST using fast training of neural nets--ten epochs of training plus one epoch to annihilate the learning rate. These single nets are near the top of what has been published, and when combined into committees they significantly exceed prior results published on the MNIST site or elsewhere (that we know of).

For regularization, we rely on batchnorm and, probably, the fact that we do not train for very long. Adding dropout to the small dense layer at the end seemed to have little or no effect.
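
Concretely, the dropout variant amounts to a classifier head like the sketch below (the dropout probability is illustrative); it is not the model we use.

head_with_dropout = nn.Sequential(
    nn.Dropout(p=0.5),    # illustrative probability; this seemed to give little or no benefit
    nn.Linear(512, 10),
)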

We do not use any weight decay. We implemented weight decay independent of the Adam optimizer's version (which we set to zero) and tried various amounts. Zero seemed about right, perhaps because we are not training for very long and the weights do not rise very much anyway. It is possible that a small amount of weight decay would be appropriate even for our quick training, but we have not found it yet.
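
The idea of decoupled weight decay can be sketched as below; this is an illustration of the idea, not the actual code inside Trainer.

import torch

def apply_weight_decay(model, lr, wd):
    "Shrink the weights directly each step, independent of the Adam update."
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(1.0 - lr * wd)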

The data augmentation made a significant difference, taking our single-net performance from the 99.6s into the 99.7s. We suspect that this works primarily by keeping the network from overfitting to the one-hot labels and training images too quickly (as opposed to, for example, learning more robust representations).

Interestingly, the network never sees the unaugmented data during training. Yet experiments indicate that by the end of training it is starting to "overfit" on that unseen data: it classifies the unaugmented training images better than it does the next epoch of augmented training data.
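
One way to check this is to compare accuracy on the unaugmented training images with accuracy on a freshly augmented pass over the same images. The snippet below sketches that comparison; it assumes mnist_trainset() accepts an augmented flag like mnist_testset() does, which may not match the actual interface of mnist.py.

clean_train = mnist_trainset(heldout=0, augmented=False)   # hypothetical flag, by analogy with mnist_testset()
print(f"unaugmented train accuracy: {percent(accuracy(Classifier(m1), ds=clean_train))}")
print(f"augmented train accuracy:   {percent(accuracy(Classifier(m1), ds=trainset))}")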

Consistent with this, at the end of our ten epochs of training we have not yet reached perfect accuracy on the training set. We are, in effect, doing "early stopping" by indirect means, and this is part of why we think the benefit of the data augmentation has more to do with early stopping than with learning more robust representations.

Todd Doucet
Pittsburgh, January 31, 2019

References

Jeremy Howard's fast.ai site and fastai library
The code we wrote is all in native PyTorch, but many of the ideas and some of the style are inspired by the new fastai 1.0 library. We think the library is wonderful, and Jeremy's courses are a treasure. The only reason we didn't use fastai directly is that it was just too much fun to write what we needed. Also, this was an exercise in learning PyTorch.

Yann LeCun's MNIST page
It looks like the results have not been updated since about 2012. So we do not claim that our result is the best out there, only that we think we have a good and interesting result.

Regularization of neural networks using DropConnect
Li Wan, Matthew Zeiler, Sixin Zhang, Yann LeCun, and Rob Fergus
This paper's focus is not on MNIST, but its MNIST result beats the best one listed on Yann LeCun's page (though it is not as good as the result we get here, and it uses many more epochs of training).