1. Import data

We use data from the standard MNIST dataset of handwritten digits.

In [1]:
require 'torch'
require 'nn'
require 'optim'
mnist = require 'mnist'
In [2]:
fullset = mnist.traindataset()
testset = mnist.testdataset()
In [3]:
fullset
Out[3]:
{
  data : ByteTensor - size: 60000x28x28
  size : 60000
  label : ByteTensor - size: 60000
}

We inspect the data just to get an idea of its content.

In [4]:
itorch.image(fullset.data[1])
Out[4]: [image: the first training digit]

In [5]:
fullset.label[1]
Out[5]:
5

We can split the full dataset into a training component and a validation component; the latter will be used to tune hyperparameters and to decide when to stop training.

While doing so, we convert the image data to double precision, since nn modules expect DoubleTensor inputs by default.

In [6]:
trainset = {
    size = 50000,
    data = fullset.data[{{1,50000}}]:double(),
    label = fullset.label[{{1,50000}}]
}
In [7]:
validationset = {
    size = 10000,
    data = fullset.data[{{50001,60000}}]:double(),
    label = fullset.label[{{50001,60000}}]
}

2. Create the model

We use a model with a single hidden layer and a hyperbolic tangent activation, and we ask the output to reproduce the input: a simple autoencoder. Since the hidden layer (49 units) is much smaller than the input (28×28 = 784 pixels), the network must compress the data, and through this compression we hope to learn meaningful features.

In [8]:
layer_size = 49
In [9]:
model = nn.Sequential()
In [10]:
model:add(nn.Reshape(28*28))
model:add(nn.Linear(28*28, layer_size))
model:add(nn.Tanh())
model:add(nn.Linear(layer_size, 28*28))
model:add(nn.Reshape(28, 28))

We also define a loss function: the mean squared error between the reconstruction and the input, i.e. the average squared Euclidean distance.

In [11]:
criterion = nn.MSECriterion()

3. Define the descent algorithm

We will make use of the optim package to train the network. optim contains several optimization algorithms, and all of them expect the same arguments (a minimal sketch follows the list):

  • a closure that, given a point x, computes the loss and its gradient with respect to x
  • a point x
  • some parameters, which are algorithm-specific
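As a minimal sketch of this contract, here is a toy example (not part of the original notebook; the names w and feval are arbitrary) that minimizes a one-dimensional quadratic with optim.sgd:

-- toy example: minimize f(w) = (w - 3)^2
local w = torch.Tensor{0}                       -- the point x
local feval = function(w_new)
    if w ~= w_new then w:copy(w_new) end        -- keep our copy in sync
    local loss = (w[1] - 3)^2                   -- f(w)
    local grad = torch.Tensor{2 * (w[1] - 3)}   -- df/dw, as a tensor
    return loss, grad
end
for i = 1, 100 do
    optim.sgd(feval, w, {learningRate = 0.1})   -- updates w in place
end
print(w[1])  -- close to 3

The step function below follows exactly this pattern, with the model parameters playing the role of x.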

We define a step function that performs training for a single epoch and returns the loss accumulated over the epoch's mini-batches.

In [12]:
sgd_params = {
   learningRate = 1e-2,
   learningRateDecay = 1e-4,
   weightDecay = 1e-3,
   momentum = 1e-4
}
In [13]:
x, dl_dx = model:getParameters()
In [14]:
step = function(batch_size)
    local current_loss = 0
    local shuffle = torch.randperm(trainset.size)
    batch_size = batch_size or 200
    
    for t = 1,trainset.size,batch_size do
        -- set up the inputs for this mini-batch
        -- (no separate targets are needed: the target of each input is the input itself)
        local size = math.min(t + batch_size - 1, trainset.size) - t + 1
        local inputs = torch.Tensor(size, 28, 28)
        for i = 1,size do
            inputs[i] = trainset.data[shuffle[t + i - 1]]
        end
        
        local feval = function(x_new)
            -- reset data
            if x ~= x_new then x:copy(x_new) end
            dl_dx:zero()

            -- compute the loss and its gradient on this mini-batch
            -- (the actual descent step is performed by optim.sgd)
            local loss = criterion:forward(model:forward(inputs), inputs)
            model:backward(inputs, criterion:backward(model.output, inputs))

            return loss, dl_dx
        end
        
        local _, fs = optim.sgd(feval, x, sgd_params)
        -- fs is a table containing the values of the loss function
        -- (just one value for plain SGD)
        current_loss = current_loss + fs[1]
    end

    return current_loss
end

Before starting the training, we also need to be able to evaluate the loss on a separate dataset, in order to decide when to stop.

In [15]:
eval = function(dataset, batch_size)
    local loss = 0
    batch_size = batch_size or 200
    
    for i = 1,dataset.size,batch_size do
        local size = math.min(i + batch_size - 1, dataset.size) - i + 1
        local inputs = dataset.data[{{i,i+size-1}}]
        local outputs = model:forward(inputs)
        loss = loss + criterion:forward(model:forward(inputs), inputs)
    end

    return loss
end

4. Train the model

We are now ready to perform the actual training. After each epoch, we evaluate the loss on the validation dataset, in order to decide whether to stop early.

In [16]:
max_iters = 30
In [17]:
do
    local last_loss = math.huge
    local increasing = 0
    local threshold = 1 -- how many consecutive epochs of increasing validation loss we allow
    for i = 1,max_iters do
        local loss = step()
        print(string.format('Epoch: %d Current loss: %4f', i, loss))
        local validation_loss = eval(validationset)
        print(string.format('Loss on the validation set: %4f', validation_loss))
        if last_loss < validation_loss then
            if increasing > threshold then break end
            increasing = increasing + 1
        else
            increasing = 0
        end
        last_loss = validation_loss
    end
end
Out[17]:
Epoch: 1 Current loss: 1637970.365020
Loss on the validation set: 291774.068805
Epoch: 2 Current loss: 1386913.534027
Loss on the validation set: 256769.458954
Epoch: 3 Current loss: 1254450.715137
Loss on the validation set: 238316.014047
Epoch: 4 Current loss: 1183897.421079
Loss on the validation set: 228445.221533
Epoch: 5 Current loss: 1145644.729101
Loss on the validation set: 223092.028734
Epoch: 6 Current loss: 1131293.316987
Loss on the validation set: 221420.835633
Epoch: 7 Current loss: 1130158.413425
Loss on the validation set: 221974.697934
Epoch: 8 Current loss: 1119790.048007
Loss on the validation set: 219604.589293
Epoch: 9 Current loss: 1115697.839897
Loss on the validation set: 219448.853021
Epoch: 10 Current loss: 1109807.449230
Loss on the validation set: 218185.749676
Epoch: 11 Current loss: 1104321.193965
Loss on the validation set: 217454.653938
Epoch: 12 Current loss: 1101056.607462
Loss on the validation set: 217027.428803
Epoch: 13 Current loss: 1099132.262866
Loss on the validation set: 217671.219207
Epoch: 14 Current loss: 1102081.644264
Loss on the validation set: 217170.654118
Epoch: 15 Current loss: 1099611.012229
Loss on the validation set: 216871.036385
Epoch: 16 Current loss: 1098313.029555
Loss on the validation set: 216689.997670
Epoch: 17 Current loss: 1097272.843898
Loss on the validation set: 216580.349829
Epoch: 18 Current loss: 1097090.125116
Loss on the validation set: 216512.963884
Epoch: 19 Current loss: 1099055.421172
Loss on the validation set: 216913.903363
Epoch: 20 Current loss: 1098608.016052
Loss on the validation set: 216726.320380
Epoch: 21 Current loss: 1097503.027540
Loss on the validation set: 216608.285022
Epoch: 22 Current loss: 1096994.449487
Loss on the validation set: 216534.451293
Epoch: 23 Current loss: 1096401.283605
Loss on the validation set: 216487.939164
Epoch: 24 Current loss: 1096329.308006
Loss on the validation set: 216458.701079
Epoch: 25 Current loss: 1096135.436646
Loss on the validation set: 216440.259239
Epoch: 26 Current loss: 1096054.278895
Loss on the validation set: 216428.871156
Epoch: 27 Current loss: 1095721.807950
Loss on the validation set: 216422.030223
Epoch: 28 Current loss: 1095682.525967
Loss on the validation set: 216417.637669
Epoch: 29 Current loss: 1095853.565834
Loss on the validation set: 216415.127409
Epoch: 30 Current loss: 1095705.633747
Loss on the validation set: 216413.815042

Let us measure the model's loss on the test set.

In [18]:
testset.data = testset.data:double()
In [19]:
eval(testset)
Out[19]:
219458.52991679
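
As a quick qualitative check (not part of the original run), we could also compare an individual test digit with its reconstruction, reusing itorch.image as in the earlier cells:

-- show the first test digit and its reconstruction
itorch.image(testset.data[1])
itorch.image(model:forward(testset.data[1]))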

5. Visualize the features

We can try to see which features the network has actually learned. To do so, we take a basis vector in the feature space and decode it back into image space, using the second half of the model.

In [20]:
linear = model.modules[4] -- the decoder layer, nn.Linear(49, 28*28)
In [21]:
vec = torch.zeros(layer_size)
vec[1] = 1
In [22]:
translate = nn.Sequential()
translate:add(linear)
translate:add(nn.Reshape(28, 28))
In [23]:
itorch.image(translate:forward(vec))
Out[23]: [image: the first learned feature, decoded to image space]

We can do the same for all basis vectors at once, by feeding the identity matrix as a batch of inputs.

In [24]:
basis = torch.eye(layer_size)
In [25]:
itorch.image(translate:forward(basis))
Out[25]: [image: all 49 learned features, decoded to image space]
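
Although the notebook stops here, the point of the exercise was to learn features, and the encoder half of the network (its first three modules) is what would extract them for downstream use. A minimal sketch, reusing the trained modules (the names encoder and features are ours, not from the notebook):

-- the encoder: Reshape(784) + Linear(784, 49) + Tanh
encoder = nn.Sequential()
encoder:add(model.modules[1])
encoder:add(model.modules[2])
encoder:add(model.modules[3])
-- a 49-dimensional feature vector for the first training digit
features = encoder:forward(trainset.data[1])
print(features:size())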