TensorFlow学习记录:saved_model模块的用法

TensorFlow中的saved_model模块用于生成冻结图文件,并且saved_model模块封装了平常用的Saver类。与Saver类不同的是,saved_model模块生成的模型文件集成了打标签的操作,可以更方便地部署在生产环境中。

关于为什么要用saved_model模块,这篇文章讲得挺好的。请点击这里

一个saved_model对象可以存储一个或多个MetaGraphDef。那什么时候需要多个MetaGraphDef呢?也许你想同时保存模型的CPU版本和GPU版本,或者你想同时保存模型的开发版本和生产版本。这个时候你就可以用tag(标签)来区分它们了。在加载模型的时候能根据tag标签来加载不同的MetaGraphDef。

TensorFlow中的saved_model模块可以给MetaGraphDef添加多个签名(signature)。每个签名的的结构都由输入节点、输出节点、名字3部分组成。并且,输入节点,输出节点的名字可以任意指定。

1.导出带有签名的模型文件

假设之前训练了一个模型,让模型在一组混乱的数据中找到y≈2x的规律。其中
(1)用saved_model模块的builder.SavedModelBuilder类实例化一个builder对象。
(2)构建签名的输入节点inputs。该输入节点的名字为“input_x”。该名字是模型文件中输入节点的名字(可以任意取名)。
(3)构建标签的输出节点outputs。该输出节点的名字为“output”。
(4)调用build_signature_def函数,并将输入节点、输出节点和名字(sig_name)传入,生成一个签名对象。
(5)用builder对象的add_meta_graph_and_variables方法将签名加入到模型中。
(6)调用builder对象的save方法导出带有签名的模型文件。

代码如下:

from tensorflow.python.saved_model import tag_constants
	#saveddir+"tfservingmodel"为模型的保存路径
    builder = tf.saved_model.builder.SavedModelBuilder(savedir+'tfservingmodel')
    
    #定义输入签名,X为输入tensor
    inputs = {'input_x': tf.saved_model.utils.build_tensor_info(X)}
    #定义输出签名, z为最终需要的输出结果tensor 
    outputs = {'output' : tf.saved_model.utils.build_tensor_info(z)}
    #调用build_signature_def()函数,并将输入节点、输出节点和名字(sig_name)传入,生成具体的签名对象
    signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'sig_name')
    
    #将节点的定义和值加到builder中,同时还加入了tag标签(tag_constants.SERVING), 还可以使用TRAINING、GPU或自定义
    builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING], {'my_signature':signature})
    builder.save()     

运行后,会生成如下图所示文件
TensorFlow学习记录:saved_model模块的用法
其中variables文件里的内容如下所示

TensorFlow学习记录:saved_model模块的用法
从第一张图可以看出,tfservingmodel文件夹包含了一个文件和一个文件夹,文件save_model.pb是模型的定义文件,文件夹variables中放置了具体的模型文件。
从第二张图可以看出,variables文件夹包含了两个模型文件,variables.data-00000-of-00001文件保存了模型中参数的值,variables.index文件保存了模型中节点符号的定义。

我们可以看下saved_model.pb文件中保存的张量名字和属性

import tensorflow  as tf
from tensorflow import saved_model as sm

model_path = "log/tfservingmodel"

with tf.Session()as sess:
	meta_graph_def = tf.saved_model.loader.load(sess,[sm.tag_constants.SERVING],model_path)
	
	op_list = sess.graph.get_operations()   #load完后可以直接从sess.graph中获取所有节点
	with open("operations.txt",'a+')as f:
		for index,op in enumerate(op_list):
			f.write(str(op.name)+"\n")
			f.write(str(op.values())+"\n")

运行结果截图(部分):
TensorFlow学习记录:saved_model模块的用法

2.根据tag标签导入模型文件,并根据签名找到网络节点

导入刚刚保存的模型
(1)用saved_model模块中的loader.load方法根据tag标签导入对应的模型文件。
(2)用signature_def方法从导入的模型中提取签名。
(3)以字典取值的方式取出输入、输出节点。
(4)向模型注入数据,并输出结果。

代码如下:

from tensorflow.python.saved_model import tag_constants
with tf.Session() as sess:
	#根据tag_constants.SERVING标签找到对应的计算图
    meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], savedir+'tfservingmodel')
    # 从meta_graph_def中取出SignatureDef对象
    signature = meta_graph_def.signature_def
    
    # 从signature中找出具体输入输出的tensor name 
    x = signature['my_signature'].inputs['input_x'].name
    result = signature['my_signature'].outputs['output'].name

    y = sess.run(result, feed_dict={x: 5})#传入5,进行预测
    print(y)       

运行结果:
TensorFlow学习记录:saved_model模块的用法

3.用saved_model_cli工具查看及使用saved_model模型

在命令行中,用saved_model_cli工具查看和使用生成的saved_model模型。具体内容如下 :
(1)找出tag标签对应的MetaGraphDef。
(2)找出MetaGraphDef中的signature、输入、输出节点等相关信息。
(3)以命令行的方式向模型输入数据,使其运行并输出结果。

saved_model_cli工具工具共有两个主要的参数:

  • show参数:侧重于查看模型中的信息。
  • run参数:侧重于运行模型。

3.1查看模型文件中的tag(标签)

saved_model_cli show --dir log/tfservingmodel

运行结果:
TensorFlow学习记录:saved_model模块的用法
我们可以看到输出结果为serve,表明SavedModel对象里面只有一个MetaGraphDef,这个serve对应于tag_constants.SERVING。

3.2查看serve对应的MetaGraphDef中的签名

saved_model_cli show --dir log/tfservingmodel --tag_set serve

运行结果:
TensorFlow学习记录:saved_model模块的用法
我们可以看到输出结果为SignatureDef Key:“my_signature”,表明serve对应的MetaGraphDef中有一个签名为"my_signature",与上面1中生成带有签名时的一致。

3.3查看signature中定义的输入、输出节点的名称

saved_model_cli show --dir log/tfservingmodel --tag_set serve --signature_def my_signature

运行结果:
TensorFlow学习记录:saved_model模块的用法
我们可以看到,模型的输入节点的张量为input_x,输出节点的张量为output。

上面的内容可以用saved_model_cli工具中的“–all”参数查看模型文件中的全部信息。

saved_model_cli show --dir log/tfservingmodel --all

运行结果:
TensorFlow学习记录:saved_model模块的用法

4.用run参数运行模型

用saved_model_cli 工具的run参数时,需要先指定好模型的路径、tag(标签)及signature(签名),再往模型里面输入数据,并运行。
在输入数据部分,可以用参数来指定不同的输入方式。

  • —inputs:后面跟具体的文件。文件类型支持numpy文件(npy、npz)和pickle文件(plk)。
  • —input_exprs:指定某个变量,向模型注入数据。
  • —input_examples:用字典方式向模型注入数据。

以“–input_exprs”为例,具体命令如下:

saved_model_cli run --dir log/tfservingmodel --tag_set serve --signature_def my_signature --input_exprs"input_x=4.2"

运行结果:
TensorFlow学习记录:saved_model模块的用法

参考书籍:《深度学习之TensorFlow工程化项目实战》 李金洪 编著

上一篇:为Flickr身份验证创建签名(Android SDK)


下一篇:Java 添加、验证PDF 数字签名