One factor behind deep learning's success is the availability of a wide range of layers that can be composed in creative ways to design architectures suitable for a wide variety of tasks. For instance, researchers have invented layers specifically for handling images, text, looping over sequential data, performing dynamic programming, etc. Sooner or later you will encounter (or invent) a layer that does not exist yet in DJL. In these cases, you must build a custom layer. In this section, we show you how.
To start, we construct a custom layer (a `Block`) that does not have any parameters of its own. This should look familiar if you recall our introduction to DJL's `Block` in :numref:`sec_model_construction`. The following `CenteredLayer` class simply subtracts the mean from its input. To build it, we only need to inherit from the `AbstractBlock` class and implement the `forwardInternal()` and `getOutputShapes()` methods.
```java
%load ../utils/djl-imports
```

```java
class CenteredLayer extends AbstractBlock {

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList current = inputs;
        // Subtract the mean from the input
        return new NDList(current.head().sub(current.head().mean()));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        // Output shape should be the same as input
        return inputs;
    }
}
```
Let us verify that our layer works as intended by feeding some data through it.
```java
NDManager manager = NDManager.newBaseManager();
CenteredLayer layer = new CenteredLayer();
Model model = Model.newInstance("centered-layer");
model.setBlock(layer);
Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());
NDArray input = manager.create(new float[]{1f, 2f, 3f, 4f, 5f});
predictor.predict(new NDList(input)).singletonOrThrow();
```
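As a cross-check on the predictor's output, here is a minimal plain-Java sketch of the same computation, written without DJL so the arithmetic is visible. The class `CenteringSketch` and its `center` method are hypothetical names used only for illustration. For the input `1, 2, 3, 4, 5` the mean is 3, so the centered values are `-2, -1, 0, 1, 2`.

```java
import java.util.Arrays;

// Hypothetical plain-Java sketch (no DJL) of what CenteredLayer computes:
// every element minus the mean of all elements.
class CenteringSketch {

    // Returns a new array holding x[i] - mean(x)
    static float[] center(float[] x) {
        float sum = 0f;
        for (float v : x) {
            sum += v;
        }
        float mean = sum / x.length;
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] - mean;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] y = center(new float[] {1f, 2f, 3f, 4f, 5f});
        System.out.println(Arrays.toString(y)); // prints [-2.0, -1.0, 0.0, 1.0, 2.0]
    }
}
```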
We can now incorporate our layer as a component in constructing more complex models.
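```java
SequentialBlock net = new SequentialBlock();
net.add(Linear.builder().setUnits(128).build());
net.add(new CenteredLayer());
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
// Initialize for the data we feed below: a batch of 4 examples with 8 features each
net.initialize(manager, DataType.FLOAT32, new Shape(4, 8));
```

As an extra sanity check, we can send random data through the network and check that the mean is in fact 0. Because we are dealing with floating-point numbers, we may still see a very small nonzero number due to rounding error.

```java
NDArray input = manager.randomUniform(-0.07f, 0.07f, new Shape(4, 8));
// Wrap `net` in its own model; the earlier predictor only runs the standalone CenteredLayer
Model model = Model.newInstance("centered-net");
model.setBlock(net);
Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());
NDArray y = predictor.predict(new NDList(input)).singletonOrThrow();
y.mean();
```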
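To see why the reported mean is close to, but not necessarily exactly, zero, the following plain-Java sketch (no DJL) centers 512 float values, the number of entries in a 4 by 128 output, and measures the mean of the result. The class `FloatMeanSketch` is hypothetical, and `java.util.Random` stands in for DJL's random number generator, so the exact residual will differ from what DJL reports.

```java
import java.util.Random;

// Hypothetical plain-Java sketch: center a batch of float values and
// measure how far the mean of the centered values is from 0.
class FloatMeanSketch {

    // Mean of an array, accumulated in float to mimic float32 arithmetic
    static float mean(float[] x) {
        float sum = 0f;
        for (float v : x) {
            sum += v;
        }
        return sum / x.length;
    }

    // Draws n values uniformly from [-0.07, 0.07), subtracts their mean,
    // and returns the mean of the centered values
    static float centeredMean(int n, long seed) {
        Random rng = new Random(seed);
        float[] x = new float[n];
        for (int i = 0; i < n; i++) {
            x[i] = -0.07f + 0.14f * rng.nextFloat();
        }
        float m = mean(x);
        float[] centered = new float[n];
        for (int i = 0; i < n; i++) {
            centered[i] = x[i] - m;
        }
        return mean(centered);
    }

    public static void main(String[] args) {
        // Close to 0, but float rounding means it need not be exactly 0.0
        System.out.println(centeredMean(4 * 128, 42L));
    }
}
```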
Now that we know how to define simple layers, let us move on to defining layers with parameters that can be adjusted through training. Registering these parameters tells DJL which values it needs to calculate gradients for. To automate some of the routine work, the `Parameter` class and the `ParameterList` class provide some basic housekeeping functionality. In particular, they govern access, initialization, sharing, saving, and loading of model parameters. This way, among other benefits, we will not need to write custom serialization routines for every custom layer.
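As a toy illustration of that last point, the sketch below keeps named parameters in one registry so that a single generic `save()`/`load()` pair serializes all of them. The `ToyParameterRegistry` class and its byte format are hypothetical and are not how DJL stores parameters; the sketch only shows why central registration removes the need for per-layer serialization code.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.LinkedHashMap;
import java.util.Map;

// Toy, hypothetical registry: NOT DJL's Parameter/ParameterList API.
// Layers register named parameters here, so one generic routine can
// save and load all of them without per-layer serialization code.
class ToyParameterRegistry {

    // Insertion order is preserved, so parameters are written in registration order
    private final Map<String, float[]> params = new LinkedHashMap<>();

    // Registers a zero-initialized parameter and returns its storage
    float[] register(String name, int size) {
        float[] values = new float[size];
        params.put(name, values);
        return values;
    }

    float[] get(String name) {
        return params.get(name);
    }

    // Serializes every parameter as (name, length, values)
    byte[] save() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(params.size());
            for (Map.Entry<String, float[]> e : params.entrySet()) {
                out.writeUTF(e.getKey());
                out.writeInt(e.getValue().length);
                for (float v : e.getValue()) {
                    out.writeFloat(v);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Copies saved values into parameters that are already registered
    void load(byte[] data) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            int count = in.readInt();
            for (int i = 0; i < count; i++) {
                String name = in.readUTF();
                int length = in.readInt();
                float[] target = params.get(name);
                if (target == null || target.length != length) {
                    throw new IllegalStateException("Unexpected parameter: " + name);
                }
                for (int j = 0; j < length; j++) {
                    target[j] = in.readFloat();
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

For example, a second registry with the same names and sizes can restore the values saved from a first one via `b.load(a.save())`, whatever layer those parameters came from.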
We now have all the basic ingredients that we need to implement our own version of DJL's `Linear` layer. Recall that this layer requires two parameters: one for the weight and one for the bias. In this implementation, we bake in the ReLU activation as a default. In the constructor, `outUnits` and `inUnits` denote the number of outputs and inputs, respectively (note that the number of outputs comes first). We create each `Parameter` with `Parameter.builder()`, setting its name, its type from `Parameter.Type`, and its `Shape` via `optShape()`. We then register the newly built `Parameter` by passing it to `addParameter()` in `MyLinear`'s constructor, which returns the registered parameter so that we can keep a reference to it. We do this for both the weight and the bias.
```java
class MyLinear extends AbstractBlock {

    private Parameter weight;
    private Parameter bias;

    private int inUnits;
    private int outUnits;

    // outUnits: the number of outputs in this layer
    // inUnits: the number of inputs in this layer
    public MyLinear(int outUnits, int inUnits) {
        this.inUnits = inUnits;
        this.outUnits = outUnits;
        weight = addParameter(
            Parameter.builder()
                .setName("weight")
                .setType(Parameter.Type.WEIGHT)
                .optShape(new Shape(inUnits, outUnits))
                .build());
        bias = addParameter(
            Parameter.builder()
                .setName("bias")
                .setType(Parameter.Type.BIAS)
                .optShape(new Shape(outUnits))
                .build());
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        // Since we added the parameter, we can now access it from the parameter store
        NDArray weightArr = parameterStore.getValue(weight, device, false);
        NDArray biasArr = parameterStore.getValue(bias, device, false);
        return relu(linear(input, weightArr, biasArr));
    }

    // Applies linear transformation
    public static NDArray linear(NDArray input, NDArray weight, NDArray bias) {
        return input.dot(weight).add(bias);
    }

    // Applies relu transformation
    public static NDList relu(NDArray input) {
        return new NDList(Activation.relu(input));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        // (batch, inUnits) -> (batch, outUnits)
        return new Shape[]{new Shape(inputs[0].get(0), outUnits)};
    }
}
```
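To make the arithmetic inside `forwardInternal()` concrete, the following plain-Java sketch spells out `relu(X W + b)` with explicit loops. It assumes row-major `float[][]` arrays instead of `NDArray`s, and the class `LinearReluSketch` is a hypothetical name; it only illustrates the shapes involved, namely `(batch, inUnits)` times `(inUnits, outUnits)` plus a bias of length `outUnits`.

```java
import java.util.Arrays;

// Hypothetical plain-Java sketch (no DJL) of MyLinear's forward pass:
// relu(X W + b) with X: (batch, inUnits), W: (inUnits, outUnits), b: (outUnits)
class LinearReluSketch {

    static float[][] forward(float[][] x, float[][] w, float[] b) {
        int batch = x.length;
        int inUnits = w.length;
        int outUnits = b.length;
        float[][] y = new float[batch][outUnits];
        for (int i = 0; i < batch; i++) {
            if (x[i].length != inUnits) {
                throw new IllegalArgumentException("expected " + inUnits + " input features");
            }
            for (int j = 0; j < outUnits; j++) {
                float acc = b[j];                // bias is broadcast over the batch
                for (int k = 0; k < inUnits; k++) {
                    acc += x[i][k] * w[k][j];    // accumulates (X W)[i][j]
                }
                y[i][j] = Math.max(acc, 0f);     // ReLU
            }
        }
        return y;
    }

    public static void main(String[] args) {
        float[][] x = {{1f, 2f}};
        float[][] w = {{1f, -1f}, {0.5f, 1f}};
        float[] b = {0f, -5f};
        // X W = [2, 1], plus b gives [2, -4], ReLU gives [2, 0]
        System.out.println(Arrays.toString(forward(x, w, b)[0])); // prints [2.0, 0.0]
    }
}
```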
Next, we instantiate the `MyLinear` class and access its model parameters.

```java
// 5 units in -> 3 units out
MyLinear linear = new MyLinear(3, 5);
var params = linear.getParameters();
for (Pair<String, Parameter> param : params) {
    System.out.println(param.getKey());
}
```
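The listing above should name the weight and the bias. As a rough plain-Java sketch of what those two parameters hold for `MyLinear(3, 5)`, the hypothetical class below records each shape in a `LinkedHashMap`, which preserves registration order, and counts the scalars: `5 * 3 + 3 = 18`.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical plain-Java sketch of the parameters MyLinear(outUnits, inUnits) registers:
// names in registration order, their shapes, and the total number of scalars.
class ParameterShapeSketch {

    // Mirrors the shapes used in MyLinear: weight (inUnits, outUnits), bias (outUnits)
    static Map<String, int[]> shapes(int outUnits, int inUnits) {
        Map<String, int[]> shapes = new LinkedHashMap<>(); // keeps insertion order
        shapes.put("weight", new int[] {inUnits, outUnits});
        shapes.put("bias", new int[] {outUnits});
        return shapes;
    }

    // Sums the product of each shape's dimensions
    static long countScalars(Map<String, int[]> shapes) {
        long total = 0;
        for (int[] shape : shapes.values()) {
            long size = 1;
            for (int d : shape) {
                size *= d;
            }
            total += size;
        }
        return total;
    }

    public static void main(String[] args) {
        Map<String, int[]> s = shapes(3, 5);
        for (Map.Entry<String, int[]> e : s.entrySet()) {
            System.out.println(e.getKey() + " " + Arrays.toString(e.getValue()));
        }
        System.out.println("total scalars: " + countScalars(s)); // 5 * 3 + 3 = 18
    }
}
```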
Let us initialize and test our `MyLinear` layer.

```java
NDArray input = manager.randomUniform(0, 1, new Shape(2, 5));
linear.initialize(manager, DataType.FLOAT32, input.getShape());
Model model = Model.newInstance("my-linear");
model.setBlock(linear);
Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());
predictor.predict(new NDList(input)).singletonOrThrow();
```
We can also construct models using custom layers. Once defined, `MyLinear` can be used just like the built-in `Linear` layer.

```java
NDArray input = manager.randomUniform(0, 1, new Shape(2, 64));
SequentialBlock net = new SequentialBlock();
net.add(new MyLinear(8, 64)); // 64 units in -> 8 units out
net.add(new MyLinear(1, 8)); // 8 units in -> 1 unit out
net.initialize(manager, DataType.FLOAT32, input.getShape());
Model model = Model.newInstance("lin-reg-custom");
model.setBlock(net);
Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());
predictor.predict(new NDList(input)).singletonOrThrow();
```
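The stacked network only works because each layer's input width matches the previous layer's output width. The sketch below, in plain Java with a hypothetical `ShapeFlowSketch` class, applies the shape rule `(batch, inUnits) -> (batch, outUnits)` that `MyLinear.getOutputShapes()` is meant to express, tracing `(2, 64)` to `(2, 8)` to `(2, 1)` and rejecting an input whose width does not match.

```java
import java.util.Arrays;

// Hypothetical plain-Java sketch of how shapes flow through stacked MyLinear layers
class ShapeFlowSketch {

    // (batch, inUnits) -> (batch, outUnits); rejects inputs of the wrong width
    static long[] linearOutputShape(long[] input, int inUnits, int outUnits) {
        if (input.length != 2 || input[1] != inUnits) {
            throw new IllegalArgumentException(
                    "expected (batch, " + inUnits + ") but got " + Arrays.toString(input));
        }
        return new long[] {input[0], outUnits};
    }

    public static void main(String[] args) {
        long[] shape = {2, 64};
        shape = linearOutputShape(shape, 64, 8); // MyLinear(8, 64)
        shape = linearOutputShape(shape, 8, 1);  // MyLinear(1, 8)
        System.out.println(Arrays.toString(shape)); // prints [2, 1]
    }
}
```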
Custom layers can have their own local parameters, which are stored in a `LinkedHashMap<String, Parameter>` object in each block's `parameters` attribute.