In this tutorial, you'll walk through the BERT QA model trained with MXNet. You provide the model with a question and a paragraph containing the answer, and the model finds the best answer in that paragraph.
Example:
Q: When did BBC Japan start broadcasting?
Answer paragraph:
BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.
It ceased operations after its Japanese distributor folded.
The model picks the right answer:
A: December 2004
The following command defines the repo that the djl.ai package will be fetched from:
%mavenRepo s3 https://djl-ai.s3.amazonaws.com/dev
Run the following command to load the djl.ai package and its dependencies:
%maven ai.djl:api:0.1.0
%maven ai.djl.mxnet:mxnet-engine:0.1.0
%maven ai.djl:repository:0.1.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.1.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0
Due to a bug in the Java kernel's %maven macro, it fails to resolve dependencies with an explicit classifier. You need to use %%loadFromPOM to load the MXNet package.
Specify the MXNet package you would like to use by changing the <classifier> tag. The following are the options for Mac and Linux:
<classifier>osx-x86_64</classifier>
<classifier>linux-x86_64</classifier>
%%loadFromPOM
<repositories>
<repository>
<id>djl.ai</id>
<url>https://djl-ai.s3.amazonaws.com/dev</url>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-native-mkl</artifactId>
<version>1.6.0</version>
<classifier>osx-x86_64</classifier>
</dependency>
</dependencies>
Import the required libraries by running the following:
import java.io.*;
import java.nio.charset.*;
import java.nio.file.*;
import java.util.*;
import java.util.concurrent.*;
import com.google.gson.*;
import com.google.gson.annotations.*;
import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.metric.*;
import ai.djl.mxnet.zoo.*;
import ai.djl.mxnet.zoo.nlp.bertqa.*;
import ai.djl.repository.zoo.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.translate.*;
import ai.djl.util.*;
Now that all of the prerequisites are complete, start writing code to run inference with this example.
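The model requires three inputs: the token indices, the token types, and the valid length (the combined number of question and answer tokens).
First, load the input: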
var question = "When did BBC Japan start broadcasting?";
var resourceDocument = "BBC Japan was a general entertainment Channel.\n" +
"Which operated between December 2004 and April 2006.\n" +
"It ceased operations after its Japanese distributor folded.";
QAInput input = new QAInput(question, resourceDocument, 384);
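Then load the model and vocabulary. Create a variable model by using Model.load(<model_directory>, <model_name>) to load your model. After that, use the getArtifact("fileName", function) method to load the vocabulary and create a BertDataFormatter to prepare for preprocessing.
The code below takes a shorter route: it loads a pre-trained model from the model zoo by criteria, creates a predictor, and runs the prediction directly.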
Map<String, String> criteria = new ConcurrentHashMap<>();
criteria.put("backbone", "bert");
criteria.put("dataset", "book_corpus_wiki_en_uncased");
ZooModel<QAInput, String> model = MxModelZoo.BERT_QA.loadModel(criteria);
Predictor<QAInput, String> predictor = model.newPredictor();
String answer = predictor.predict(input);
answer
Inference in Deep Learning is the process of predicting the output for a given input based on a pre-defined model. djl.ai abstracts the whole process away from you. It can load the model, perform inference on the input, and provide output. djl.ai also allows you to provide user-defined inputs. The workflow looks like the following:
The red block ("Images") in the workflow is the input that djl.ai expects from you. The green block ("Images bounding box") is the output that you expect. Since djl.ai does not know what input to expect or what output format you prefer, it provides the Translator interface so you can define your own input and output.
The Translator interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing component converts the user-defined input objects into an NDList, so that the Predictor in djl.ai can understand the input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the Predictor and lets you convert it to the desired output format.
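The pre-process, predict, post-process flow described above can be sketched in plain Java without any djl.ai classes. Everything in this block is hypothetical and for illustration only: SketchTranslator is a simplified stand-in (not the djl.ai Translator interface), float[] stands in for NDList, and the "model" is a toy function.

```java
import java.util.function.Function;

final class TranslatorSketch {
    // Hypothetical stand-in for the Translator idea; NOT the djl.ai interface.
    interface SketchTranslator<I, O> {
        float[] processInput(I input);    // pre-processing: user object -> raw numbers
        O processOutput(float[] output);  // post-processing: raw numbers -> user object
    }

    // Runs pre-processing, the model, and post-processing in sequence.
    static <I, O> O predict(SketchTranslator<I, O> translator,
                            Function<float[], float[]> model, I input) {
        float[] modelInput = translator.processInput(input);
        float[] modelOutput = model.apply(modelInput);
        return translator.processOutput(modelOutput);
    }

    // Toy translator: a word becomes its length; the model output becomes a label.
    static final SketchTranslator<String, String> LENGTH_TRANSLATOR =
            new SketchTranslator<String, String>() {
                @Override
                public float[] processInput(String input) {
                    return new float[] {input.length()};
                }

                @Override
                public String processOutput(float[] output) {
                    return output[0] > 0.5f ? "long" : "short";
                }
            };

    // Toy "model": outputs 1 when the input value exceeds 5, otherwise 0.
    static final Function<float[], float[]> TOY_MODEL =
            in -> new float[] {in[0] > 5 ? 1f : 0f};

    public static void main(String[] args) {
        System.out.println(predict(LENGTH_TRANSLATOR, TOY_MODEL, "broadcasting")); // long
        System.out.println(predict(LENGTH_TRANSLATOR, TOY_MODEL, "BBC"));          // short
    }
}
```

The caller only ever sees a String going in and a String coming out; the numeric representation stays inside the translator, which is the separation the Translator interface is meant to provide.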
Now, you need to convert the sentences into tokens. You can use BertDataFormatter.tokenizer to convert questions and answers into tokens. Then, use BertDataFormatter.formTokens to create BERT-formatted tokens. Once you have properly formatted tokens, use parser.token2idx to create the indices.
The following code block converts the question and answer defined earlier into tokens and prints the valid length.
// Create token lists for question and answer
List<String> tokenQ = BertDataFormatter.tokenizer(question);
List<String> tokenA = BertDataFormatter.tokenizer(resourceDocument);
System.out.println("Question Token: " + tokenQ);
System.out.println("Answer Token: " + tokenA);
System.out.println("Valid length: " + (tokenQ.size() + tokenA.size()));
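The internals of BertDataFormatter.tokenizer are not shown in this tutorial. As a rough illustration of what tokenizing a sentence means, here is a minimal sketch that assumes lowercasing and splitting on whitespace and punctuation. TokenizerSketch is a hypothetical class, and the real tokenizer may behave differently (for example, by splitting rare words into sub-word pieces).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

final class TokenizerSketch {
    // Lowercases the text, keeps runs of letters/digits as tokens,
    // and emits each punctuation character as its own token.
    static List<String> tokenize(String text) {
        List<String> tokens = new ArrayList<>();
        StringBuilder word = new StringBuilder();
        for (char c : text.toLowerCase(Locale.ROOT).toCharArray()) {
            if (Character.isLetterOrDigit(c)) {
                word.append(c);
            } else {
                if (word.length() > 0) {
                    tokens.add(word.toString());
                    word.setLength(0);
                }
                if (!Character.isWhitespace(c)) {
                    tokens.add(String.valueOf(c));
                }
            }
        }
        if (word.length() > 0) {
            tokens.add(word.toString());
        }
        return tokens;
    }

    public static void main(String[] args) {
        // prints [when, did, bbc, japan, start, broadcasting, ?]
        System.out.println(tokenize("When did BBC Japan start broadcasting?"));
    }
}
```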
Normally, words and sentences are represented as indices instead of Strings for training. They typically work like a vector in an n-dimensional space. In this case, you need to map the tokens to their indices in the vocabulary. formTokens also pads the sentence to the required length.
// Create Bert-formatted tokens
List<String> tokens = BertDataFormatter.formTokens(tokenQ, tokenA, 384);
// Convert tokens into indices in the vocabulary
List<Integer> indices = parser.token2idx(tokens);
System.out.println("The indices of tokens: " + indices);
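To make the padding and index lookup concrete, here is a minimal sketch with a toy vocabulary. The class VocabularySketch, the eight-entry vocabulary, and the use of "[PAD]" and "[UNK]" as padding and unknown-word tokens are assumptions for illustration; the vocabulary and special tokens used by BertDataFormatter may differ.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class VocabularySketch {
    // Toy vocabulary (assumption); a real vocabulary is far larger.
    static final Map<String, Integer> VOCAB = new HashMap<>();

    static {
        String[] words = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "when", "did", "bbc", "japan"};
        for (int i = 0; i < words.length; i++) {
            VOCAB.put(words[i], i);
        }
    }

    // Pads the token list with "[PAD]" up to seqLength.
    // Assumes tokens.size() <= seqLength; longer inputs are returned unchanged.
    static List<String> pad(List<String> tokens, int seqLength) {
        List<String> padded = new ArrayList<>(tokens);
        while (padded.size() < seqLength) {
            padded.add("[PAD]");
        }
        return padded;
    }

    // Maps each token to its vocabulary index; unknown tokens fall back to "[UNK]".
    static List<Integer> toIndices(List<String> tokens) {
        int unknown = VOCAB.get("[UNK]");
        List<Integer> indices = new ArrayList<>();
        for (String token : tokens) {
            indices.add(VOCAB.getOrDefault(token, unknown));
        }
        return indices;
    }

    public static void main(String[] args) {
        List<String> tokens =
                Arrays.asList("[CLS]", "when", "did", "bbc", "japan", "start", "[SEP]");
        // "start" is not in the toy vocabulary, so it maps to the "[UNK]" index 1.
        // prints [2, 4, 5, 6, 7, 1, 3, 0, 0, 0]
        System.out.println(toIndices(pad(tokens, 10)));
    }
}
```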
Finally, the model needs to understand which part is the Question and which part is the Answer. Mask the tokens as follows:
[Question tokens...AnswerTokens...padding tokens] => [000000...11111....0000]
// Get token types
List<Float> tokenTypes = BertDataFormatter.getTokenTypes(tokenQ, tokenA, 384);
System.out.println("The type mask for tokens: " + tokenTypes);
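The mask layout above can be reproduced with a few lines of plain Java. This is a simplified sketch under stated assumptions: TokenTypeSketch is a hypothetical class, it works only from the two token counts, and it ignores any special separator tokens that BertDataFormatter.getTokenTypes may account for.

```java
import java.util.ArrayList;
import java.util.List;

final class TokenTypeSketch {
    // Builds [0...0 1...1 0...0]: zeros for the question tokens,
    // ones for the answer tokens, and zeros again for the padding.
    static List<Float> tokenTypes(int questionLength, int answerLength, int seqLength) {
        if (questionLength + answerLength > seqLength) {
            throw new IllegalArgumentException("tokens do not fit into seqLength");
        }
        List<Float> types = new ArrayList<>(seqLength);
        for (int i = 0; i < seqLength; i++) {
            boolean inAnswer = i >= questionLength && i < questionLength + answerLength;
            types.add(inAnswer ? 1f : 0f);
        }
        return types;
    }

    public static void main(String[] args) {
        // 3 question tokens, 4 answer tokens, padded to length 10
        // prints [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
        System.out.println(tokenTypes(3, 4, 10));
    }
}
```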
To convert these lists into float[] for NDArray creation, use the following helper function:
/**
* Convert a List of Number to float array.
*
* @param list the list to be converted
* @return float array
*/
public static float[] toFloatArray(List<? extends Number> list) {
    float[] ret = new float[list.size()];
    int idx = 0;
    for (Number n : list) {
        ret[idx++] = n.floatValue();
    }
    return ret;
}
float[] indicesFloat = toFloatArray(indices);
float[] types = toFloatArray(tokenTypes);
Now that you have everything you need, you can create an NDList and populate all of the inputs you formatted earlier. You're done with pre-processing!
Translator
You need to do this processing within an implementation of the Translator interface. Translator is designed to do pre-processing and post-processing. You are required to define the input object and the output object. The interface contains the following two methods to override:
public NDList processInput(TranslatorContext ctx, I input)
public O processOutput(TranslatorContext ctx, NDList list)
Every translator takes in input and returns output in the form of generic objects. In this case, the translator takes input in the form of QAInput (I) and returns output as a String (O). QAInput is just an object that holds the question and the answer paragraph. We have prepared the input class for you.
public class QAInput {
    private String question;
    private String answer;

    QAInput(String question, String answer) {
        this.question = question;
        this.answer = answer;
    }

    public String getQuestion() {
        return question;
    }

    public String getAnswer() {
        return answer;
    }
}
Below is one implementation of the translator. Complete the TODO sections in the processInput method below. (HINT: use the code snippets in the previous cells to guide you.) You can create the NDArrays with the following NDManager methods:
manager.create(Number[] data, Shape)
manager.create(Number[] data)
The Shape for data0 and data1 is (num_of_batches, sequence_length). For data2, the shape is just 1.
public class BertTranslator implements Translator<QAInput, String> {
    private BertDataFormatter parser;
    // Kept from processInput so that processOutput can map positions back to tokens
    private List<String> tokens;
    private int seqLength;

    BertTranslator(BertDataFormatter parser) {
        this.parser = parser;
        this.seqLength = 384;
    }

    @Override
    public Batchifier getBatchifier() {
        return null;
    }

    @Override
    public NDList processInput(TranslatorContext ctx, QAInput input) {
        // Pre-processing - tokenize sentence
        // TODO: Create token lists for question and answer
        // Read from the input object rather than from the notebook's global variables
        List<String> tokenQ = BertDataFormatter.tokenizer(input.getQuestion());
        List<String> tokenA = BertDataFormatter.tokenizer(input.getAnswer());
        // TODO Calculate valid length (length(Question tokens) + length(resourceDocument tokens))
        var validLength = tokenQ.size() + tokenA.size();
        // TODO: Create Bert-formatted tokens
        tokens = BertDataFormatter.formTokens(tokenQ, tokenA, seqLength);
        if (tokens == null) {
            throw new IllegalStateException("tokens is not defined");
        }
        // TODO: Convert tokens into indices in the vocabulary
        List<Integer> indices = parser.token2idx(tokens);
        // TODO: Get token types
        List<Float> tokenTypes = BertDataFormatter.getTokenTypes(tokenQ, tokenA, seqLength);
        NDManager manager = ctx.getNDManager();
        // TODO Using the manager created above, create NDArrays for the indices, types, and valid length.
        // in that order. The type of the NDArray should all be float
        NDArray indicesNd = manager.create(toFloatArray(indices), new Shape(1, seqLength));
        NDArray typesNd = manager.create(toFloatArray(tokenTypes), new Shape(1, seqLength));
        NDArray validLengthNd = manager.create(new float[]{validLength});
        NDList list = new NDList(3);
        list.add("data0", indicesNd);
        list.add("data1", typesNd);
        list.add("data2", validLengthNd);
        return list;
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        NDArray array = list.head();
        NDList output = array.split(2, 2);
        // Get the formatted logits result
        NDArray startLogits = output.get(0).reshape(new Shape(1, -1));
        NDArray endLogits = output.get(1).reshape(new Shape(1, -1));
        // Get Probability distribution
        NDArray startProb = startLogits.softmax(-1);
        NDArray endProb = endLogits.softmax(-1);
        // The most probable start and end positions delimit the answer span
        int startIdx = (int) startProb.argmax(1, true).getFloat(0);
        int endIdx = (int) endProb.argmax(1, true).getFloat(0);
        return tokens.subList(startIdx, endIdx + 1).toString();
    }
}
Congrats! You have created your first Translator! We have pre-filled processOutput(), which processes the NDList and returns the result in the desired format. Together, processInput() and processOutput() offer the flexibility to get the predictions from the model in any format you desire.
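The post-processing step boils down to a softmax over the start and end logits, an argmax on each, and a sub-list of the tokens. The sketch below reproduces that logic on plain float arrays so you can experiment without a model. SpanSketch and the logit values are made up for illustration, and the guard for an end position before the start position is an addition that the translator above does not have.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class SpanSketch {
    // Numerically stable softmax: subtract the maximum before exponentiating.
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] probs = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    // Index of the largest value (the first one wins on ties).
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Picks the most probable start and end positions and returns the tokens between them.
    static List<String> answerSpan(List<String> tokens, float[] startLogits, float[] endLogits) {
        int start = argmax(softmax(startLogits));
        int end = argmax(softmax(endLogits));
        if (end < start) {
            return new ArrayList<>(); // guard added in this sketch only
        }
        return tokens.subList(start, end + 1);
    }

    public static void main(String[] args) {
        List<String> tokens = Arrays.asList("between", "december", "2004", "and", "april", "2006");
        float[] startLogits = {0.1f, 3.2f, 0.5f, 0.0f, 1.1f, 0.2f}; // made-up values
        float[] endLogits = {0.0f, 0.3f, 4.0f, 0.1f, 0.2f, 1.5f};   // made-up values
        // prints [december, 2004]
        System.out.println(answerSpan(tokens, startLogits, endLogits));
    }
}
```

Since softmax preserves the ordering of its inputs, taking the argmax of the raw logits would select the same positions; the softmax is only needed if you also want the probabilities.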
With the Translator implemented, you need to bring up the predictor to start making predictions. You can find the usage for Predictor in the Javadoc. Create a translator and use the question and resourceDocument provided previously.
String predictResult = null;
QAInput input = new QAInput(question, resourceDocument);
BertTranslator translator = new BertTranslator(parser);
// TODO: Create a Predictor and predict the output using the predictor
try (Predictor<QAInput, String> predictor = model.newPredictor(translator)) {
    predictResult = predictor.predict(input);
}
System.out.println(predictResult);
Based on the input, the following result will be shown:
[december, 2004]
That's it!
You can try with more questions and answers. Here are the samples:
Answer Material
The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.
Question
Q: When were the Normans in Normandy? A: 10th and 11th centuries
Q: In what country is Normandy located? A: france