In [17]:
import tensorflow as tf

tf.reset_default_graph()

# define simple graphs
a = tf.Variable([3.], dtype=tf.float32, name='a')
b = tf.placeholder(tf.float32, shape=(), name='input')
c = tf.multiply(a, b)
d = tf.multiply(c, c, name='output')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(d, feed_dict={b:2}))
[36.]

The following define the basic graph and use a session to compute the value

In [35]:
import tensorflow as tf

tf.reset_default_graph()

# case1 normal save and restore
# define simple graphs
a = tf.Variable([3.], dtype=tf.float32, name='a')
b = tf.placeholder(tf.float32, shape=(), name='input')
c = tf.multiply(a, b, name='output_0')
d = tf.multiply(c, c, name='output')
init = tf.global_variables_initializer()

# session will bind to the global default graph
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(d, feed_dict={b:2}))
    saver.save(sess, './tmp/model.ckpt')
# then under the directory ./tmp you will find the two files
# model.ckpt.meta : The definition of graph
# model.ckpt.data-00000-of-00001 : The data (the value for the nodes)
[36.]

The first way to restore the model (Defined the new set of variables)

In [36]:
# Then we have two ways to restore the graph

tf.reset_default_graph()

# This cell show the first case is we define a complete same graph with the original one
# and load the model.ckpt.data's value back to this new graph
# the name: 'a', 'input', and 'output' should be consistent with original graph
j = tf.Variable([3.], dtype=tf.float32, name='a')
k = tf.placeholder(tf.float32, shape=(), name='input')
l = tf.multiply(j, k)
m = tf.multiply(l, l, name='output')
init = tf.global_variables_initializer()

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, './tmp/model.ckpt')
    # ok now we can test with the new samples to the placeholder!
    # the answer for the following should be: (3*3) * (3*3)
    print(sess.run(m, feed_dict={k:3}))
    
# You may found a little trivial for this example, but this method is useful for re-training the same
# graph. Since usually we'll define a inference model as a function, we don't need to manually create
# every nodes on as in this example.

# something like:
# logits = inference(X)
# sess.run(logits, feed_dict={X: 123})
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
[81.]

The second way to restore the model (Load the graph by DEF)

In [37]:
tf.reset_default_graph()

# This cell we want to introduce another way for restoring the model
# This method we don't want to define a set of new nodes(j, k, l, m) but just load
# graph defined before into the current session and restore the data back to the loading
# nodes of graph

# we import the meta file which is the graph for the model
tf.train.import_meta_graph('./tmp/model.ckpt.meta')
saver = tf.train.Saver()
with tf.Session() as sess:
    # now we want restore all the values to the graph
    saver.restore(sess, './tmp/model.ckpt')
    # Before we use sess.run, we need to got the output tensor
    # and input tensors
    
    # you may notice that there are :0, this is created by tensorflow
    _input = sess.graph.get_tensor_by_name("input:0")
    _output = sess.graph.get_tensor_by_name("output:0")
    # then we can run!
    print(sess.run(_output, feed_dict={_input:3}))
    
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
[81.]

Frozen model and strip out the nodes irrelevant to inference

In [49]:
# The following two approaches are common flow for training and re-training the model
# but when we want to deploy the model we may want to frozen the model (convert nodes to constants)
# and strip out all the training related nodes (like optimizer, minimize, ...)

import tensorflow as tf

tf.reset_default_graph()

tf.train.import_meta_graph('./tmp/model.ckpt.meta')

saver = tf.train.Saver()
with tf.Session() as sess:
    # like before we import meta graph and restore the weights first
    saver.restore(sess, './tmp/model.ckpt')
    output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def,
            ['output_0'] # this is the output node list, here we only need the output_0
    )
    # we can write it to the output pb file
    tf.train.write_graph(output_graph_def, "./tmp", "output_text.pb", True)
    tf.train.write_graph(output_graph_def, "./tmp", "output_binary.pb", False)
# now observe under the directory of ./tmp, there will be two files
# one is human readable text file and the other is binary file
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
INFO:tensorflow:Froze 1 variables.
Converted 1 variables to const ops.

notice that we don't do saver.save(sess, './tmp/model.ckpt')

Since we have stored the nodes as constants so there is no need to store the data to model.ckpt.data, right~

Restore the model from the Protocol Buffer (PB) file

In [50]:
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    # we now want to load the graph_def from the pb file
    with gfile.FastGFile('./tmp/output_binary.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        
        # get the input node
        _input = sess.graph.get_tensor_by_name('input:0')
        _output = sess.graph.get_tensor_by_name('output_0:0')
        sess.run(tf.global_variables_initializer())
        print (sess.run(_output, feed_dict={_input: 9}))
        # if you try to access the output node as before you will get exception
        output = sess.graph.get_tensor_by_name('output:0')
        sess.run(output, feed)
[27.]
In [11]:
import tensorflow as tf

tf.reset_default_graph()

# case1 normal save and restore
# define simple graphs
a = tf.get_variable('a', [1,1], dtype=tf.float32)
b = tf.placeholder(tf.float32, shape=(1,1), name='input')
c = tf.matmul(a, b, name='output_0')
init = tf.global_variables_initializer()

# session will bind to the global default graph
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(c, feed_dict={b:[[9]]}))
    saver.save(sess, './tmp/model.ckpt')

# Frozen the graph    
tf.reset_default_graph()
tf.train.import_meta_graph('./tmp/model.ckpt.meta')

saver = tf.train.Saver()
with tf.Session() as sess:
    # like before we import meta graph and restore the weights first
    saver.restore(sess, './tmp/model.ckpt')
    output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def,
            ['output_0'] # this is the output node list, here we only need the output_0
    )
    # we can write it to the output pb file
    tf.train.write_graph(output_graph_def, "./tmp", "output_text.pb", True)
    tf.train.write_graph(output_graph_def, "./tmp", "output_binary.pb", False)    

# Transform to TF-Lite
graph_def_file = "./tmp/output_binary.pb"
input_arrays = ["input"]
output_arrays = ["output_0"]

tf.reset_default_graph()

converter = tf.contrib.lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                            input_arrays,
                                                            output_arrays)
tflite_model = converter.convert()
open("converterd_model.tflite", "wb").write(tflite_model)
[[-15.0248]]
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
INFO:tensorflow:Froze 1 variables.
INFO:tensorflow:Converted 1 variables to const ops.
Out[11]:
564