In [17]:
import tensorflow as tf

tf.reset_default_graph()

# define simple graphs
a = tf.Variable([3.], dtype=tf.float32, name='a')
b = tf.placeholder(tf.float32, shape=(), name='input')
c = tf.multiply(a, b)
d = tf.multiply(c, c, name='output')
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(d, feed_dict={b:2}))

[36.]


# The following define the basic graph and use a session to compute the value¶

In [35]:
import tensorflow as tf

tf.reset_default_graph()

# case1 normal save and restore
# define simple graphs
a = tf.Variable([3.], dtype=tf.float32, name='a')
b = tf.placeholder(tf.float32, shape=(), name='input')
c = tf.multiply(a, b, name='output_0')
d = tf.multiply(c, c, name='output')
init = tf.global_variables_initializer()

# session will bind to the global default graph
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
print(sess.run(d, feed_dict={b:2}))
saver.save(sess, './tmp/model.ckpt')
# then under the directory ./tmp you will find the two files
# model.ckpt.meta : The definition of graph
# model.ckpt.data-00000-of-00001 : The data (the value for the nodes)

[36.]


# The first way to restore the model (Defined the new set of variables)¶

In [36]:
# Then we have two ways to restore the graph

tf.reset_default_graph()

# This cell show the first case is we define a complete same graph with the original one
# and load the model.ckpt.data's value back to this new graph
# the name: 'a', 'input', and 'output' should be consistent with original graph
j = tf.Variable([3.], dtype=tf.float32, name='a')
k = tf.placeholder(tf.float32, shape=(), name='input')
l = tf.multiply(j, k)
m = tf.multiply(l, l, name='output')
init = tf.global_variables_initializer()

saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
saver.restore(sess, './tmp/model.ckpt')
# ok now we can test with the new samples to the placeholder!
# the answer for the following should be: (3*3) * (3*3)
print(sess.run(m, feed_dict={k:3}))

# You may found a little trivial for this example, but this method is useful for re-training the same
# graph. Since usually we'll define a inference model as a function, we don't need to manually create
# every nodes on as in this example.

# something like:
# logits = inference(X)
# sess.run(logits, feed_dict={X: 123})

INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
[81.]


# The second way to restore the model (Load the graph by DEF)¶

In [37]:
tf.reset_default_graph()

# This cell we want to introduce another way for restoring the model
# This method we don't want to define a set of new nodes(j, k, l, m) but just load
# graph defined before into the current session and restore the data back to the loading
# nodes of graph

# we import the meta file which is the graph for the model
tf.train.import_meta_graph('./tmp/model.ckpt.meta')
saver = tf.train.Saver()
with tf.Session() as sess:
# now we want restore all the values to the graph
saver.restore(sess, './tmp/model.ckpt')
# Before we use sess.run, we need to got the output tensor
# and input tensors

# you may notice that there are :0, this is created by tensorflow
_input = sess.graph.get_tensor_by_name("input:0")
_output = sess.graph.get_tensor_by_name("output:0")
# then we can run!
print(sess.run(_output, feed_dict={_input:3}))


INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
[81.]


# Frozen model and strip out the nodes irrelevant to inference¶

In [49]:
# The following two approaches are common flow for training and re-training the model
# but when we want to deploy the model we may want to frozen the model (convert nodes to constants)
# and strip out all the training related nodes (like optimizer, minimize, ...)

import tensorflow as tf

tf.reset_default_graph()

tf.train.import_meta_graph('./tmp/model.ckpt.meta')

saver = tf.train.Saver()
with tf.Session() as sess:
# like before we import meta graph and restore the weights first
saver.restore(sess, './tmp/model.ckpt')
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph_def,
['output_0'] # this is the output node list, here we only need the output_0
)
# we can write it to the output pb file
tf.train.write_graph(output_graph_def, "./tmp", "output_text.pb", True)
tf.train.write_graph(output_graph_def, "./tmp", "output_binary.pb", False)
# now observe under the directory of ./tmp, there will be two files
# one is human readable text file and the other is binary file

INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
INFO:tensorflow:Froze 1 variables.
Converted 1 variables to const ops.


### notice that we don't do saver.save(sess, './tmp/model.ckpt')¶

Since we have stored the nodes as constants so there is no need to store the data to model.ckpt.data, right~

# Restore the model from the Protocol Buffer (PB) file¶

In [50]:
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
# we now want to load the graph_def from the pb file
with gfile.FastGFile('./tmp/output_binary.pb', 'rb') as f:
graph_def = tf.GraphDef()
tf.import_graph_def(graph_def, name='')

# get the input node
_input = sess.graph.get_tensor_by_name('input:0')
_output = sess.graph.get_tensor_by_name('output_0:0')
sess.run(tf.global_variables_initializer())
print (sess.run(_output, feed_dict={_input: 9}))
# if you try to access the output node as before you will get exception
output = sess.graph.get_tensor_by_name('output:0')
sess.run(output, feed)

[27.]

In [11]:
import tensorflow as tf

tf.reset_default_graph()

# case1 normal save and restore
# define simple graphs
a = tf.get_variable('a', [1,1], dtype=tf.float32)
b = tf.placeholder(tf.float32, shape=(1,1), name='input')
c = tf.matmul(a, b, name='output_0')
init = tf.global_variables_initializer()

# session will bind to the global default graph
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
print(sess.run(c, feed_dict={b:[[9]]}))
saver.save(sess, './tmp/model.ckpt')

# Frozen the graph
tf.reset_default_graph()
tf.train.import_meta_graph('./tmp/model.ckpt.meta')

saver = tf.train.Saver()
with tf.Session() as sess:
# like before we import meta graph and restore the weights first
saver.restore(sess, './tmp/model.ckpt')
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph_def,
['output_0'] # this is the output node list, here we only need the output_0
)
# we can write it to the output pb file
tf.train.write_graph(output_graph_def, "./tmp", "output_text.pb", True)
tf.train.write_graph(output_graph_def, "./tmp", "output_binary.pb", False)

# Transform to TF-Lite
graph_def_file = "./tmp/output_binary.pb"
input_arrays = ["input"]
output_arrays = ["output_0"]

tf.reset_default_graph()

converter = tf.contrib.lite.TocoConverter.from_frozen_graph(graph_def_file,
input_arrays,
output_arrays)
tflite_model = converter.convert()
open("converterd_model.tflite", "wb").write(tflite_model)

[[-15.0248]]
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
INFO:tensorflow:Froze 1 variables.
INFO:tensorflow:Converted 1 variables to const ops.

Out[11]:
564