# %load /Users/facai/Study/book_notes/preconfig.py
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
# Apply the seaborn theme once: enable seaborn's short color codes ('b', 'g', ...)
# and use the SimHei font (for Chinese labels) at a large font scale.
# A single call replaces the former two consecutive sns.set calls; the end state is the same.
sns.set(color_codes=True, font='SimHei', font_scale=2.5)
# SimHei has no glyph for the Unicode minus sign; fall back to ASCII '-' so
# negative tick labels do not render as empty boxes.
plt.rcParams['axes.unicode_minus'] = False
# sns.set's default 'darkgrid' style turns the axes grid on; switch it off.
# This must stay after the sns.set call, which would otherwise overwrite it.
plt.rcParams['axes.grid'] = False
import tensorflow as tf
def show_image(filename, figsize=None, res_dir=True):
    """Read an image file and draw it with ``plt.imshow``.

    Args:
        filename: Name of the image file. When ``res_dir`` is true it is
            resolved relative to ``./res/``.
        figsize: Optional ``(width, height)``. When truthy, a new figure of
            that size is created first; otherwise the current figure is used.
        res_dir: If true, prefix ``filename`` with ``./res/``.
    """
    path = './res/{}'.format(filename) if res_dir else filename
    if figsize:
        plt.figure(figsize=figsize)
    image = plt.imread(path)
    plt.imshow(image)
/usr/local/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. from ._conv import register_converters as _register_converters
参考: https://www.tensorflow.org/programmers_guide/graphs
tf.Graph
op, tensor
variable
name_scope, variable_scope, collection
save and restore
tf.Graph: GraphDef => *.pb文件
tf.Session():
# Open a session against a remote TensorFlow server rather than the local
# in-process runtime; "grpc://example.org:2222" is an example target address.
with tf.Session(target="grpc://example.org:2222"):
    pass
状态(Variable) => *.ckpt文件
# Build a tiny graph: a Const op for `a` and a Mul op created by the overloaded `*`.
# Evaluating `b` only displays the symbolic Tensor (its name/shape/dtype), not a value.
# NOTE(review): the `_1` suffixes in the output suggest this cell was run more than
# once in the same default graph (TF uniquifies op names) — confirm.
a = tf.constant(1)
b = a * 2
b
<tf.Tensor 'mul_1:0' shape=() dtype=int32>
# The Operation that produces tensor `b` (a Mul op).
b.op
<tf.Operation 'mul_1' type=Mul>
# Operations that use `b` as an input — none yet.
b.consumers()
[]
# The Const Operation that produces `a`.
a.op
<tf.Operation 'Const_1' type=Const>
# `a` is consumed by the Mul op that produces `b`.
a.consumers()
[<tf.Operation 'mul_1' type=Mul>]
tensorflow/python/framework/ops.py
__add__
# Output tensors of b's Mul op; `b` itself is the only one.
b.op.outputs
[<tf.Tensor 'mul_1:0' shape=() dtype=int32>]
# Inputs of the Mul op: `a` and the constant 2, which `*` turned into its own Const (`mul_1/y`).
list(b.op.inputs)
[<tf.Tensor 'Const_1:0' shape=() dtype=int32>, <tf.Tensor 'mul_1/y:0' shape=() dtype=int32>]
# The first input of b's Mul op prints exactly like `a`: it is the same tensor.
print(b.op.inputs[0])
print(a)
Tensor("Const_1:0", shape=(), dtype=int32) Tensor("Const_1:0", shape=(), dtype=int32)
# A Const op has no inputs.
list(a.op.inputs)
[]
Operation和Tensor构成有向图(数据流图):Tensor是Operation之间的有向边。
# run: evaluate `b` by executing the default graph in a session.
# No `sess` exists at this point in the notebook, so open one explicitly;
# the `with` block closes it afterwards. The last line displays the fetched value.
with tf.Session() as sess:
    result = sess.run([b])
result
参考:
# Add a Variable to the graph: `b + v` creates an Add op. Its second input is the
# variable's read tensor, not the variable itself (see the inputs listed below).
# The scalar `b` broadcasts against the shape-(1,) variable, so `c` has shape (1,).
v = tf.Variable([0])
c = b + v
c
<tf.Tensor 'add:0' shape=(1,) dtype=int32>
# Inputs of the Add op: `b` and `Variable/read:0`.
list(c.op.inputs)
[<tf.Tensor 'mul_1:0' shape=() dtype=int32>, <tf.Tensor 'Variable/read:0' shape=(1,) dtype=int32>]
# `Variable/read` is an Identity op.
c.op.inputs[1].op
<tf.Operation 'Variable/read' type=Identity>
# The Identity op reads from the variable's reference tensor `Variable:0` (dtype int32_ref).
list(c.op.inputs[1].op.inputs)
[<tf.Tensor 'Variable:0' shape=(1,) dtype=int32_ref>]
# The tf.Variable object wraps that same `Variable:0` reference tensor.
v
<tf.Variable 'Variable:0' shape=(1,) dtype=int32_ref>
实际上,对变量的读取是通过tf.identity算子完成的,即:
# Spelled out: read `v` through an explicit tf.identity, then add the result to `b`.
# This mirrors the Add <- Identity (read) <- Variable chain inspected above.
c = tf.add(b, tf.identity(v))
参考:https://www.tensorflow.org/versions/master/api_docs/python/tf/Variable
===============
class Layer:
    """Minimal sketch of a layer with separate ``build`` and ``call`` steps.

    Both methods are placeholders: they do nothing and return ``None``.
    """

    def build(self):
        """Create the layer's state (placeholder, does nothing)."""

    def call(self, inputs):
        """Apply the layer to ``inputs`` (placeholder, returns ``None``)."""
参考:https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard
import os

# Build graph A: one float variable v1 of shape [3], initialised to zeros,
# plus an op that increments it by 1. Then save it in two ways:
#   * Saver.save -> checkpoint with the variable values, plus model.ckpt.meta
#     (the MetaGraphDef) because write_meta_graph=True;
#   * write_graph -> only the GraphDef, as a text proto (graph.pbtxt).
graph_a = tf.Graph()
with graph_a.as_default():
    # Pass an initializer instance (the original passed the class itself).
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
    print(v1)
    inc_v1 = v1.assign(v1 + 1)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Saver.save refuses to write when the parent directory is missing,
    # so make sure ./tmp exists before saving.
    os.makedirs("./tmp", exist_ok=True)
    with tf.Session() as sess:
        sess.run(init_op)
        inc_v1.op.run()  # v1 becomes [1., 1., 1.] before it is saved
        save_path = saver.save(sess, "./tmp/model.ckpt", write_meta_graph=True)
        print("Model saved in path: %s" % save_path)

    pb_path = tf.train.write_graph(graph_a.as_graph_def(), "./tmp/", "graph.pbtxt", as_text=True)
    print("Graph saved in path: %s" % pb_path)
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref> Model saved in path: ./tmp/model.ckpt Graph saved in path: ./tmp/graph.pbtxt
graph.pbtxt部分示意(v1 + 1 对应的节点):
node {
name: "add"
op: "Add"
input: "v1/read"
input: "add/y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
# Restore into a fresh graph. import_meta_graph rebuilds graph A's ops from the
# .meta file and returns a Saver; restore() then loads the saved variable values.
graph_b = tf.Graph()
with graph_b.as_default(), tf.Session() as sess:
    saver = tf.train.import_meta_graph('./tmp/model.ckpt.meta')
    saver.restore(sess, "./tmp/model.ckpt")
    print(graph_b.get_operations())

    # Look the variable up by tensor name ("<op name>:<output index>").
    v1 = graph_b.get_tensor_by_name("v1:0")
    print("------------------")
    print("v1 : %s" % sess.run(v1))
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt [<tf.Operation 'v1/Initializer/zeros' type=Const>, <tf.Operation 'v1' type=VariableV2>, <tf.Operation 'v1/Assign' type=Assign>, <tf.Operation 'v1/read' type=Identity>, <tf.Operation 'add/y' type=Const>, <tf.Operation 'add' type=Add>, <tf.Operation 'Assign' type=Assign>, <tf.Operation 'init' type=NoOp>, <tf.Operation 'save/Const' type=Const>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>] ------------------ v1 : [1. 1. 1.]
总结:
tf.train.Saver会保存GraphDef(在.meta文件的MetaGraphDef中)和Variable信息,用它可以直接恢复图及变量值。tf.train.write_graph、tf.GraphDef和tf.import_graph_def主要用于固化模型(只有GraphDef信息,不含变量值)。参考: