This is the second tutorial in our beginner series, which takes you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to train an image classification model that can recognize handwritten digits.
This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the Jupyter README.
// Add the snapshot repository to get the DJL snapshot artifacts
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
// Add the maven dependencies
%maven ai.djl:api:0.5.0
%maven ai.djl:basicdataset:0.5.0
%maven ai.djl:model-zoo:0.5.0
%maven ai.djl.mxnet:mxnet-engine:0.5.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0
// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-a
import java.nio.file.*;
import ai.djl.*;
import ai.djl.basicdataset.*;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.initializer.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.util.*;
import ai.djl.basicmodelzoo.cv.classification.*;
import ai.djl.basicmodelzoo.basic.*;
When training a deep learning network, it is important to first understand the dataset.
A Dataset is a collection of sample input/output pairs for the function represented by your neural network. Each single input/output pair is represented by a Record. A record can have multiple arrays of inputs or outputs. For example, in an image question-and-answer dataset, the input is both an image and a question about the image, while the output is the answer to the question.
Because deep learning is highly parallelizable, training is often done not with a single record at a time but with a Batch of records at a time. This can lead to significant performance gains, especially when working with images.
The dataset we will be using is MNIST, a database of handwritten digits. Each image contains a black and white digit from 0-9 in a 28x28 image. It is commonly used when getting started with deep learning because it is small and fast to train.
Once you understand your dataset, you should create an implementation of the Dataset class. In this case, we provide the MNIST dataset built-in to make it easy for you to use it.
Then, we must decide the parameters for loading data from the dataset. The only parameter we need for MNIST is the choice of Sampler. The sampler decides which elements of the dataset, and how many of them, are part of each batch when iterating through it. We will have it randomly shuffle the elements for each batch and use a batchSize of 32. The batchSize is usually the largest power of 2 that fits within memory.
int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
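To make the role of the sampler more concrete, here is a minimal plain-Java sketch of shuffling record indices and cutting them into batches. It is not DJL's implementation: the class name BatchSamplerSketch, the sample method, and the seed parameter are all hypothetical, and keeping the final partial batch is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Illustrative only: a plain-Java stand-in for what a batch sampler does.
// The class and method names are hypothetical and are not part of DJL.
class BatchSamplerSketch {

    // Returns the record indices 0..datasetSize-1 grouped into batches.
    // Assumption: the last batch is kept even if it is smaller than batchSize.
    static List<List<Integer>> sample(int datasetSize, int batchSize, boolean shuffle, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < datasetSize; i++) {
            indices.add(i);
        }
        if (shuffle) {
            // A fixed seed keeps the example reproducible
            Collections.shuffle(indices, new Random(seed));
        }
        List<List<Integer>> batches = new ArrayList<>();
        for (int start = 0; start < datasetSize; start += batchSize) {
            int end = Math.min(start + batchSize, datasetSize);
            batches.add(new ArrayList<>(indices.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        List<List<Integer>> batches = sample(10, 4, true, 42L);
        int lastSize = batches.get(batches.size() - 1).size();
        System.out.println(batches.size() + " batches; the last one holds " + lastSize + " records");
    }
}
```

With 10 records and a batch size of 4, the sketch yields two full batches and one partial batch, and every record index appears exactly once per pass.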
A Model contains a neural network Block along with additional artifacts used for the training process. It possesses additional information about the inputs, outputs, shapes, and data types you will use. Generally, you will use Model once you have fully completed your Block.
In this tutorial, we will use the built-in Multilayer Perceptron Block from the Model Zoo. To learn more, see the previous tutorial: Create Your First Network.
Because images in the MNIST dataset are 28x28 grayscale images, we will create an MLP block with 28 x 28 input. The output will be 10 because there are 10 possible classes (0 to 9) each image could be. For the hidden layers, we have chosen new int[] {128, 64} by experimenting with different values.
Model model = Model.newInstance();
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
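As a rough sense of scale, the sketch below counts the trainable parameters such a network would have. It assumes every layer is a plain fully connected layer with one bias per output unit; that is an assumption about the block's internals, and the class MlpParameterCount is hypothetical rather than part of DJL.

```java
// Illustrative only: counts trainable parameters of a fully connected network,
// assuming each layer is dense and has one bias per output unit.
class MlpParameterCount {

    static long count(int input, int output, int[] hidden) {
        long total = 0;
        int previous = input;
        for (int width : hidden) {
            // weights (previous x width) plus biases (width)
            total += (long) previous * width + width;
            previous = width;
        }
        // final layer from the last hidden width to the output classes
        total += (long) previous * output + output;
        return total;
    }

    public static void main(String[] args) {
        System.out.println(count(28 * 28, 10, new int[] {128, 64}));
    }
}
```

Under those assumptions, the 784 to 128 to 64 to 10 layout comes to 109,386 parameters, almost all of them in the first layer.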
Now, you can create a Trainer to train your model. The trainer is the main class that orchestrates the training process. Usually, it is opened using a try-with-resources block and closed after training is over.
The trainer takes an existing model and attempts to optimize the parameters inside the model's Block to best match the dataset. Most optimization is based upon Stochastic Gradient Descent (SGD).
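The core update rule behind gradient descent can be sketched in a few lines of plain Java. This example minimizes the fixed function f(w) = (w - 3)^2 rather than a neural network loss, and it uses the exact gradient. SGD differs in that it estimates the gradient from one batch at a time. The class GradientDescentSketch and its parameters are hypothetical.

```java
// Illustrative only: gradient descent on f(w) = (w - 3)^2, whose minimum is at w = 3.
class GradientDescentSketch {

    static double minimize(double start, double learningRate, int steps) {
        double w = start;
        for (int i = 0; i < steps; i++) {
            double gradient = 2.0 * (w - 3.0); // derivative of (w - 3)^2
            w -= learningRate * gradient;      // move against the gradient
        }
        return w;
    }

    public static void main(String[] args) {
        System.out.println(minimize(0.0, 0.1, 100));
    }
}
```

Each step moves the parameter a fraction of the way toward the minimum. The learning rate controls how large that fraction is.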
Before you create your trainer, you will need a training configuration that describes how to train your model.
The following are a few common items you may need to configure your training:
Loss function: A loss function is used to measure how well our model matches the dataset. Because a lower value of the function is better, it is called the "loss" function. The Loss is the only required argument of the training configuration.
Evaluator function: An evaluator function is also used to measure how well our model matches the dataset. Unlike the loss, evaluators are only there for people to look at and are not used for optimizing the model. Since many losses are not very intuitive, adding other evaluators such as Accuracy can help you understand how your model is doing. If you know of any useful evaluators, we recommend adding them.
Device: The device is the hardware that should be used to train your model. Typically, this is either CPU or GPU. DJL can automatically detect whether a GPU is available. If GPUs are available, it will run on a single GPU by default. If you need to train with multiple GPUs, you need to set the devices with: config.setDevices(Devices.getDevices(maxNumberOfGPUs)).
Initializer: An Initializer is used to set the initial values of the model's parameters before training. This can usually be left as the default initializer.
Optimizer: The optimizer is the code that updates the model parameters to minimize the loss function. There are a variety of optimizers, most of which offer improvements upon basic SGD. When just starting, you can use the default optimizer. Later on, customizing the optimizer can result in faster training.
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    // softmaxCrossEntropyLoss is a standard loss for classification problems
    .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
    .addTrainingListeners(TrainingListener.Defaults.logging());
// Now that we have our training configuration, we should create a new trainer for our model
Trainer trainer = model.newTrainer(config);
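To illustrate the difference between a loss and an evaluator, the following plain-Java sketch computes softmax cross-entropy for one example and accuracy over a small set of predictions. It is a textbook formulation, not DJL's implementation, and the class LossAndAccuracySketch and its method names are hypothetical.

```java
// Illustrative only: textbook softmax cross-entropy and accuracy in plain Java.
class LossAndAccuracySketch {

    // Loss for one example: -log(softmax(logits)[label]).
    // Subtracting the maximum logit first keeps Math.exp from overflowing.
    static double softmaxCrossEntropy(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        for (double v : logits) {
            sum += Math.exp(v - max);
        }
        return Math.log(sum) + max - logits[label];
    }

    // Index of the largest value, i.e. the predicted class.
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of examples whose predicted class equals the label.
    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(logits[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 1.0}, {0.0, 3.0}, {1.0, 0.0}};
        int[] labels = {0, 1, 1};
        System.out.println("loss of first example: " + softmaxCrossEntropy(logits[0], labels[0]));
        System.out.println("accuracy: " + accuracy(logits, labels));
    }
}
```

The loss changes smoothly as the scores change, which is what makes it usable for optimization. Accuracy only changes when a prediction flips, which is why it is easier to read but is not used to update the parameters.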
Before training your model, you have to initialize all of the parameters with default values. You can use the trainer for this initialization by passing in the input shape.
trainer.initialize(new Shape(1, 28 * 28));
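The shape above describes one example with 28 * 28 = 784 values. As a sketch of where that number comes from, the plain-Java snippet below flattens a 28x28 pixel grid into a single array. The row-major ordering and the class FlattenSketch are assumptions made for illustration, not a description of how DJL lays out its data.

```java
// Illustrative only: flattens a 2-D pixel grid into a 1-D array in row-major order.
class FlattenSketch {

    static float[] flatten(float[][] image) {
        int rows = image.length;
        int cols = image[0].length;
        float[] flat = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            // Row r occupies positions r * cols through r * cols + cols - 1
            System.arraycopy(image[r], 0, flat, r * cols, cols);
        }
        return flat;
    }

    public static void main(String[] args) {
        float[][] image = new float[28][28];
        System.out.println(flatten(image).length);
    }
}
```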
Now, we can train the model.
// Deep learning is typically trained in epochs, where each epoch trains the model on each item in the dataset once.
int epoch = 2;
for (int i = 0; i < epoch; ++i) {
    // We iterate through the dataset once during this epoch
    for (Batch batch : trainer.iterateDataset(mnist)) {
        // During trainBatch, we update the loss and evaluators with the results for the training batch.
        trainer.trainBatch(batch);
        // Now, we update the model parameters based on the results of the latest trainBatch
        trainer.step();
        // We must make sure to close the batch to ensure all the memory associated with the batch is cleared quickly.
        // If the memory isn't closed after each batch, you will very quickly run out of memory on your GPU
        batch.close();
    }
    // Reset training and validation evaluators at the end of the epoch
    trainer.endEpoch();
}
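One way to make sure a resource is released on every iteration, even if the loop body throws, is Java's try-with-resources. The sketch below shows the pattern with a hypothetical DemoBatch class that merely counts open instances. It does not use DJL, and applying the same pattern to a real batch assumes the batch type implements AutoCloseable.

```java
// Illustrative only: shows that try-with-resources closes a resource on every iteration.
// DemoBatch is a hypothetical stand-in and is not part of DJL.
class BatchClosingSketch {

    // Number of DemoBatch instances that are currently open
    static int open = 0;

    static class DemoBatch implements AutoCloseable {
        DemoBatch() {
            open++;
        }

        @Override
        public void close() {
            open--;
        }
    }

    // Processes the given number of batches and returns how many were handled.
    static int run(int batches) {
        int processed = 0;
        for (int i = 0; i < batches; i++) {
            try (DemoBatch batch = new DemoBatch()) {
                processed++; // stand-in for training on the batch and stepping
            } // batch.close() is called here automatically
        }
        return processed;
    }

    public static void main(String[] args) {
        System.out.println(run(5) + " batches processed, " + open + " still open");
    }
}
```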
Once your model is trained, you should save it so that it can be reloaded later. You can also add metadata to it, such as the training accuracy or the number of epochs trained, which can be used when loading or examining the model.
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");
model
Now, you've successfully trained a model that can recognize handwritten digits. You'll learn how to apply this model in the next chapter: Run image classification with your model.
You can find the complete source code for this tutorial in the examples project.