In this tutorial, we will demonstrate the complete process of training an MNIST model in TensorFlow and exporting the trained model to ONNX.
Firstly, we initiate the training script by issuing the command python tf-train-mnist.py
in a terminal. Shortly afterwards, we should obtain a trained MNIST model. The training process itself needs no special instrumentation. However, to successfully convert the trained model, onnx-tensorflow requires three pieces of information, all of which can be obtained after training is complete. The first is the serialized graph definition, which we save at the end of the training script:
with open("graph.proto", "wb") as file:
    graph = tf.get_default_graph().as_graph_def(add_shapes=True)
    file.write(graph.SerializeToString())
Note that as_graph_def
does not serialize any information about the shapes of intermediate tensors by default, and such information is required by onnx-tensorflow. Thus we ask TensorFlow to serialize the shape information by adding the keyword argument add_shapes=True, as demonstrated above.

Secondly, we freeze the graph. Here, we quote the TensorFlow documentation on what graph freezing is:
One confusing part about this is that the weights usually aren't stored inside the file format during training. Instead, they're held in separate checkpoint files, and there are Variable ops in the graph that load the latest values when they're initialized. It's often not very convenient to have separate files when you're deploying to production, so there's the freeze_graph.py script that takes a graph definition and a set of checkpoints and freezes them together into a single file.
Thus we build the freeze_graph tool in the TensorFlow source folder and execute it, telling it where the graph definition is, where the checkpoint file is, and where to put the frozen graph. One caveat is that you need to supply the name of the output node to this utility. If you are having trouble finding the name of the output node, please refer to this article for help.
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/home/mnist-tf/graph.proto \
--input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \
--output_graph=/tmp/frozen_graph.pb \
--output_node_names=fc2/add \
--input_binary=True
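One caveat mentioned above is knowing the output node's name. If in doubt, a common heuristic is to look for nodes whose outputs no other node consumes. Below is a sketch of that idea over a hypothetical stand-in for graph_def.node (the node list and structure are invented for illustration; in practice you would iterate over the real GraphDef):

```python
from collections import namedtuple

# Hypothetical stand-in for entries of graph_def.node.
Node = namedtuple("Node", ["name", "op", "inputs"])

nodes = [
    Node("Placeholder", "Placeholder", []),
    Node("fc2/MatMul", "MatMul", ["Placeholder", "fc2/weights"]),
    Node("fc2/add", "Add", ["fc2/MatMul", "fc2/bias"]),
]

# A node that no other node lists as an input is a candidate output.
consumed = {inp for n in nodes for inp in n.inputs}
candidates = [n.name for n in nodes if n.name not in consumed]
print(candidates)  # ['fc2/add']
```

In a real graph there may be several unconsumed nodes (e.g. summaries or savers), so treat the result as a shortlist to inspect rather than a definitive answer.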
Note that we have now obtained frozen_graph.pb, which contains both the graph definition and the weights in a single file.
Thirdly, we convert the model to ONNX format using the tensorflow_graph_to_onnx_model
function from the onnx-tensorflow API (documentation available at https://github.com/onnx/onnx-tensorflow/blob/master/doc/API.md):
import tensorflow as tf
from onnx_tf.frontend import tensorflow_graph_to_onnx_model

with tf.gfile.GFile("frozen_graph.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    onnx_model = tensorflow_graph_to_onnx_model(graph_def,
                                                "fc2/add",
                                                opset=6)

with open("mnist.onnx", "wb") as file:
    file.write(onnx_model.SerializeToString())
As a simple sanity check to ensure that we have obtained the correct model, we print out the first node of the converted ONNX model graph, which corresponds to the reshape operation that converts the 1-D serialized input into a 2-D image tensor:
print(onnx_model.graph.node[0])
input: "Placeholder"
input: "reshape/Reshape/shape"
output: "reshape/Reshape"
op_type: "Reshape"
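The reshape this first node performs can be illustrated in plain NumPy. The shapes below are assumed from the standard MNIST layout (a flattened 784-vector reshaped into a 28x28 single-channel image); the actual target shape is stored in the reshape/Reshape/shape tensor:

```python
import numpy as np

flat = np.zeros((1, 784), dtype=np.float32)  # the serialized 1-D input
img = flat.reshape(1, 28, 28, 1)             # an NHWC image tensor
print(img.shape)  # (1, 28, 28, 1)
```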
We continue our demonstration by performing inference with the obtained ONNX model. Here, we have exported an image representing a handwritten 7 and stored the numpy array as image.npz. Using our backend, we will classify this image using the converted ONNX model.
import onnx
import numpy as np
from onnx_tf.backend import prepare
model = onnx.load('mnist.onnx')
tf_rep = prepare(model)
img = np.load("./assets/image.npz")
output = tf_rep.run(img.reshape([1, 784]))
print("The digit is classified as", np.argmax(output))
The digit is classified as 7
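As an aside, the 784-vector passed to tf_rep.run is just a flattened grayscale image. A minimal preprocessing sketch is shown below; the [0, 1] normalization is an assumption and must match whatever the model was trained with:

```python
import numpy as np

def preprocess(img):
    # Scale 8-bit pixel values into [0, 1] and flatten into the
    # 1x784 layout that the model's input placeholder expects.
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return arr.reshape(1, 784)

digit = np.zeros((28, 28), dtype=np.uint8)  # stand-in for a real image
x = preprocess(digit)
print(x.shape)  # (1, 784)
```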