#!/usr/bin/env python
# coding: utf-8
#
# TensorFlow 1.x walkthrough (notebook export): build a tiny graph, save it
# with tf.train.Saver, and restore it in two different ways.

# In[17]:

import tensorflow as tf

tf.reset_default_graph()

# Define a simple graph: output = (a * input) ** 2.
a = tf.Variable([3.], dtype=tf.float32, name='a')
b = tf.placeholder(tf.float32, shape=(), name='input')
c = tf.multiply(a, b)
d = tf.multiply(c, c, name='output')

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(d, feed_dict={b: 2}))


# # The following defines the basic graph and uses a session to compute the value

# In[35]:

tf.reset_default_graph()

# Case 1: normal save and restore.
# Same graph as above, but the intermediate product is named 'output_0' so it
# can be looked up by name later on.
a = tf.Variable([3.], dtype=tf.float32, name='a')
b = tf.placeholder(tf.float32, shape=(), name='input')
c = tf.multiply(a, b, name='output_0')
d = tf.multiply(c, c, name='output')

init = tf.global_variables_initializer()

# The saver (and the session below) bind to the global default graph.
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(d, feed_dict={b: 2}))
    saver.save(sess, './tmp/model.ckpt')

# Under ./tmp you will now find (among others):
#   model.ckpt.meta                 : the definition of the graph
#   model.ckpt.data-00000-of-00001  : the data (the values of the variables)


# # The first way to restore the model (define a new set of variables)

# In[36]:

# There are two ways to restore the graph.
tf.reset_default_graph()

# First way: define a graph identical to the original one and load the values
# stored in model.ckpt.data back into it. The names 'a', 'input' and 'output'
# should be consistent with the original graph.
restored_a = tf.Variable([3.], dtype=tf.float32, name='a')
new_input = tf.placeholder(tf.float32, shape=(), name='input')
product = tf.multiply(restored_a, new_input)
squared = tf.multiply(product, product, name='output')

saver = tf.train.Saver()

with tf.Session() as sess:
    # restore() assigns every saved variable, so running the initializer
    # beforehand is not needed.
    saver.restore(sess, './tmp/model.ckpt')
    # Now we can feed new samples to the placeholder.
    # The answer for the following should be: (3*3) * (3*3)
    print(sess.run(squared, feed_dict={new_input: 3}))

# This may look trivial here, but the method is useful for re-training the same
# graph. Usually the inference model is defined as a function, so the nodes do
# not have to be created by hand as in this example. Something like:
#   logits = inference(X)
#   sess.run(logits, feed_dict={X: 123})


# # The second way to restore the model (load the graph definition)

# In[37]:

tf.reset_default_graph()

# Second way: do not define a new set of nodes by hand. Instead, load the graph
# that was saved before and restore the data into the loaded nodes.

# Import the meta file, which holds the graph of the model. The call returns a
# Saver built from the saved definition, so no second Saver has to be created.
saver = tf.train.import_meta_graph('./tmp/model.ckpt.meta')

with tf.Session() as sess:
    # Restore all the values into the graph.
    saver.restore(sess, './tmp/model.ckpt')

    # Before calling sess.run we need the input and output tensors.
    # The ':0' suffix is appended by TensorFlow: it selects the first output
    # of the operation with that name.
    _input = sess.graph.get_tensor_by_name("input:0")
    _output = sess.graph.get_tensor_by_name("output:0")

    # Then we can run!
    print(sess.run(_output, feed_dict={_input: 3}))


# # Freeze the model and strip out the nodes irrelevant to inference

# In[49]:

# The previous two approaches are the common flow for training and re-training
# a model. When we want to deploy the model, we may instead want to freeze it
# (convert the variables to constants) and strip out all the training-related
# nodes (like optimizer, minimize, ...).
import tensorflow as tf
from tensorflow.python.platform import gfile

tf.reset_default_graph()

# As before, import the meta graph and restore the weights first.
# import_meta_graph returns the Saver recorded in the meta file.
saver = tf.train.import_meta_graph('./tmp/model.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, './tmp/model.ckpt')

    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        ['output_0'],  # the output node list; here we only need output_0
    )

    # Write the frozen graph to .pb files: one human-readable text file and
    # one binary file, both under ./tmp.
    tf.train.write_graph(output_graph_def, "./tmp", "output_text.pb", True)
    tf.train.write_graph(output_graph_def, "./tmp", "output_binary.pb", False)

# ### Notice that we don't call saver.save(sess, './tmp/model.ckpt') here.
# The variables are stored as constants inside the graph, so there is no need
# to write their values to model.ckpt.data.


# # Restore the model from the Protocol Buffer (PB) file

# In[50]:

# Start from an empty default graph. Otherwise the nodes imported above are
# still present, the frozen nodes get renamed on import, and the name lookups
# below would silently return the old variable-based tensors.
tf.reset_default_graph()

with tf.Session() as sess:
    # Load the graph_def from the pb file.
    with gfile.GFile('./tmp/output_binary.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    tf.import_graph_def(graph_def, name='')

    # Get the input and output nodes.
    _input = sess.graph.get_tensor_by_name('input:0')
    _output = sess.graph.get_tensor_by_name('output_0:0')

    # The frozen graph holds only constants, so no variable initialization
    # is required before running it.
    print(sess.run(_output, feed_dict={_input: 9}))

    # The 'output' node was stripped while freezing (only 'output_0' was
    # requested), so looking it up as before raises an exception.
    try:
        sess.graph.get_tensor_by_name('output:0')
    except KeyError as err:
        print('output:0 is not part of the frozen graph:', err)


# In[11]:

tf.reset_default_graph()

# Case 1: normal save and restore.
# Define a simple graph: a 1x1 variable multiplied by a 1x1 input.
a = tf.get_variable('a', [1, 1], dtype=tf.float32)
b = tf.placeholder(tf.float32, shape=(1, 1), name='input')
c = tf.matmul(a, b, name='output_0')

init = tf.global_variables_initializer()

# The saver (and the session below) bind to the global default graph.
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(c, feed_dict={b: [[9]]}))
    saver.save(sess, './tmp/model.ckpt')

# Freeze the graph.
tf.reset_default_graph()
saver = tf.train.import_meta_graph('./tmp/model.ckpt.meta')

with tf.Session() as sess:
    # As before, restore the weights into the imported meta graph first.
    saver.restore(sess, './tmp/model.ckpt')

    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        ['output_0'],  # the output node list; here we only need output_0
    )

    # Write the frozen graph to the output pb files.
    tf.train.write_graph(output_graph_def, "./tmp", "output_text.pb", True)
    tf.train.write_graph(output_graph_def, "./tmp", "output_binary.pb", False)

# Transform to TF-Lite.
graph_def_file = "./tmp/output_binary.pb"
input_arrays = ["input"]
output_arrays = ["output_0"]

tf.reset_default_graph()
# NOTE(review): tf.contrib.lite.TocoConverter is the converter entry point this
# file was written against; newer TF 1.x releases appear to expose it under a
# different name - confirm against the installed TensorFlow version.
converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()

# Use a context manager so the file handle is flushed and closed.
with open("converterd_model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)