%maven ai.djl:api:0.2.0
%maven ai.djl:repository:0.2.0
%maven ai.djl:model-zoo:0.2.0
%maven ai.djl.mxnet:mxnet-engine:0.2.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.2.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0
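This tutorial uses the MXNet engine as its backend. MXNet is distributed in several build flavors, and each one is platform specific. Please see the DJL MXNet engine documentation for how to select an MXNet engine flavor. The cell below uses the mxnet-native-mkl flavor with a classifier chosen for macOS or Linux.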
String classifier = System.getProperty("os.name").startsWith("Mac") ? "osx-x86_64" : "linux-x86_64";
%maven ai.djl.mxnet:mxnet-native-mkl:jar:${classifier}:1.6.0-a
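The classifier expression above falls back to linux-x86_64 for any OS name that does not start with "Mac", so on an unsupported platform the failure only shows up later, when the native library is loaded. The sketch below is a hypothetical helper, not part of DJL. It maps os.name to the two classifiers this tutorial uses and fails immediately otherwise. Limiting it to these two classifiers is an assumption taken from this tutorial, not a list of every platform MXNet supports.

```java
import java.util.Locale;

public class MxnetClassifier {
    /** Hypothetical helper: maps an os.name value to a native-jar classifier. */
    static String classifierFor(String osName) {
        String os = osName.toLowerCase(Locale.ROOT);
        if (os.startsWith("mac")) {
            return "osx-x86_64";
        }
        if (os.startsWith("linux")) {
            return "linux-x86_64";
        }
        // Assumption: only the two classifiers used in this tutorial are handled.
        throw new IllegalArgumentException("No MXNet classifier configured for OS: " + osName);
    }

    public static void main(String[] args) {
        System.out.println(classifierFor(System.getProperty("os.name")));
    }
}
```

In the notebook, the first line of the cell could then become String classifier = MxnetClassifier.classifierFor(System.getProperty("os.name"));.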
import java.awt.image.*;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;
import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.index.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.modality.cv.transform.*;
import ai.djl.mxnet.zoo.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.util.*;
import ai.djl.zoo.cv.classification.*;
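This tutorial assumes that you have an MXNet model trained using Python. An MXNet symbolic model usually contains the following files:
- Symbol file: {MODEL_NAME}-symbol.json, a JSON file that describes the network
- Parameters file: {MODEL_NAME}-{EPOCH}.params, a binary file that stores the trained weights and biases
- Synset file: synset.txt, an optional text file that lists the classification labels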
This tutorial uses a pre-trained MXNet resnet18_v1 model.
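MXNet model files follow a naming convention: {MODEL_NAME}-symbol.json for the network and {MODEL_NAME}-{EPOCH}.params for the weights, with the epoch zero-padded to four digits (as in resnet18_v1-0000.params, downloaded below). The following sketch is a hypothetical helper, not DJL code. It finds the symbol file name and the highest-epoch parameter file for a model name in a list of file names. Treating the latest epoch as the one to load is an assumption made for illustration.

```java
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class MxnetModelFiles {
    /** Returns the expected symbol file name for a model prefix. */
    static String symbolFile(String modelName) {
        return modelName + "-symbol.json";
    }

    /** Returns the parameter file with the highest four-digit epoch, if any. */
    static Optional<String> latestParamsFile(List<String> fileNames, String modelName) {
        // Pattern.quote guards against regex metacharacters in the model name.
        Pattern pattern = Pattern.compile(Pattern.quote(modelName) + "-(\\d{4})\\.params");
        String best = null;
        int bestEpoch = -1;
        for (String name : fileNames) {
            Matcher m = pattern.matcher(name);
            if (m.matches()) {
                int epoch = Integer.parseInt(m.group(1));
                if (epoch > bestEpoch) {
                    bestEpoch = epoch;
                    best = name;
                }
            }
        }
        return Optional.ofNullable(best);
    }

    public static void main(String[] args) {
        List<String> files = List.of("resnet18_v1-symbol.json", "resnet18_v1-0000.params", "synset.txt");
        System.out.println(symbolFile("resnet18_v1"));
        System.out.println(latestParamsFile(files, "resnet18_v1").orElse("<none>"));
    }
}
```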
We use DownloadUtils.java to download files from the internet.
%load DownloadUtils.java
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json", "build/resnet/resnet18_v1-symbol.json", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz", "build/resnet/resnet18_v1-0000.params", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt", "build/resnet/synset.txt", new ProgressBar());
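The source of DownloadUtils.java is not shown in this tutorial. From the calls above, it saves a URL to a local path, and for resnet18_v1-0000.params.gz it writes a decompressed resnet18_v1-0000.params. The sketch below is a hypothetical stand-in for that behavior, not the actual helper. It creates the parent directories and streams the input to the file, gunzipping it when the URL ends in .gz. Progress reporting is omitted. The main method does an offline round trip so it runs without network access.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class SimpleDownload {
    /** Writes the stream to dest, gunzipping it first when gzipped is true. */
    static void save(InputStream in, Path dest, boolean gzipped) throws IOException {
        Path parent = dest.toAbsolutePath().getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
        try (InputStream src = gzipped ? new GZIPInputStream(in) : in) {
            Files.copy(src, dest, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    /** Hypothetical stand-in for DownloadUtils.download, without progress reporting. */
    static void download(String url, String output) throws IOException {
        try (InputStream in = URI.create(url).toURL().openStream()) {
            save(in, Paths.get(output), url.endsWith(".gz"));
        }
    }

    public static void main(String[] args) throws IOException {
        // Offline round trip: gzip a few bytes in memory, then save them decompressed.
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (GZIPOutputStream gz = new GZIPOutputStream(bos)) {
            gz.write("hello".getBytes(StandardCharsets.UTF_8));
        }
        Path out = Files.createTempDirectory("download-demo").resolve("hello.txt");
        save(new ByteArrayInputStream(bos.toByteArray()), out, true);
        System.out.println(Files.readString(out)); // prints "hello"
    }
}
```

With network access, SimpleDownload.download(url, "build/resnet/resnet18_v1-0000.params") with the .params.gz URL would mirror the second call above.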
Path modelDir = Paths.get("build/resnet");
// Load resnet18_v1-symbol.json and the resnet18_v1 parameter file from modelDir
Model model = Model.newInstance();
model.load(modelDir, "resnet18_v1");
Create a Translator
Pipeline pipeline = new Pipeline();
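// Preprocessing: center-crop to a square, resize to 224x224, then convert to a float tensor
pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());
// The class labels for the output are read from synset.txt in the model directory
Translator<BufferedImage, Classifications> translator = new ImageClassificationTranslator.Builder()
        .setPipeline(pipeline)
        .setSynsetArtifactName("synset.txt")
        .build();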
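To make the three preprocessing transforms concrete, here is a plain-Java approximation that uses only java.awt. It is not DJL's implementation. It assumes that CenterCrop with no arguments crops the largest centered square, that Resize interpolates bilinearly, and that ToTensor turns the HWC image with 0-255 values into a CHW float array scaled to [0, 1]. Interpolation details may differ from DJL's, so individual values need not match exactly.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

public class Preprocess {
    /** Center-crops to a square, resizes to size x size, and returns a CHW float array in [0, 1]. */
    static float[] toTensor(BufferedImage img, int size) {
        // Center crop: the largest square that fits, centered in the image.
        int side = Math.min(img.getWidth(), img.getHeight());
        int x0 = (img.getWidth() - side) / 2;
        int y0 = (img.getHeight() - side) / 2;
        BufferedImage cropped = img.getSubimage(x0, y0, side, side);

        // Resize with bilinear interpolation (an assumption; DJL may interpolate differently).
        BufferedImage resized = new BufferedImage(size, size, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(cropped, 0, 0, size, size, null);
        g.dispose();

        // To tensor: HWC bytes in [0, 255] become CHW floats in [0, 1].
        int plane = size * size;
        float[] chw = new float[3 * plane];
        for (int y = 0; y < size; y++) {
            for (int x = 0; x < size; x++) {
                int rgb = resized.getRGB(x, y);
                int i = y * size + x;
                chw[i] = ((rgb >> 16) & 0xFF) / 255f;         // red plane
                chw[plane + i] = ((rgb >> 8) & 0xFF) / 255f;  // green plane
                chw[2 * plane + i] = (rgb & 0xFF) / 255f;     // blue plane
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(300, 200, BufferedImage.TYPE_INT_RGB);
        float[] tensor = toTensor(img, 224);
        System.out.println(tensor.length); // 150528 = 3 * 224 * 224
    }
}
```

Channel-first ordering is used here because it is the layout conventionally expected by MXNet vision models such as resnet18_v1.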
var img = BufferedImageUtils.fromUrl("https://djl-ai.s3.amazonaws.com/resources/images/kitten.jpg");
img
// The predictor runs the translator's pre- and post-processing around the model
Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator);
Classifications classifications = predictor.predict(img);
classifications
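The Classifications result pairs class labels from synset.txt with probabilities. A plain-Java sketch of that post-processing step follows. It assumes the translator applies a softmax to the network's raw scores and ranks the labels in descending order. TopK, Prediction, softmax and topK are hypothetical names, and this is not DJL's ImageClassificationTranslator code.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class TopK {
    /** A class label paired with its probability. */
    static final class Prediction {
        final String label;
        final double probability;

        Prediction(String label, double probability) {
            this.label = label;
            this.probability = probability;
        }
    }

    /** Softmax, subtracting the max score first for numerical stability. */
    static double[] softmax(float[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (float s : scores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Pairs each score with its synset label and returns the k most probable, highest first. */
    static List<Prediction> topK(float[] scores, List<String> synset, int k) {
        if (scores.length != synset.size()) {
            throw new IllegalArgumentException("scores and synset must have the same length");
        }
        double[] probs = softmax(scores);
        List<Prediction> all = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            all.add(new Prediction(synset.get(i), probs[i]));
        }
        all.sort(Comparator.comparingDouble((Prediction p) -> p.probability).reversed());
        return new ArrayList<>(all.subList(0, Math.min(k, all.size())));
    }

    public static void main(String[] args) {
        // Hypothetical labels and scores, only to exercise the code.
        List<String> synset = List.of("cat", "dog", "bird");
        float[] scores = {0.5f, 1.0f, 3.0f};
        for (Prediction p : topK(scores, synset, 2)) {
            System.out.printf("%s: %.3f%n", p.label, p.probability);
        }
    }
}
```

Subtracting the maximum score before exponentiating leaves the probabilities unchanged but keeps Math.exp from overflowing on large logits.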
Now you can load other MXNet symbolic models in the same way and run inference with them.