Two ways to save and load TensorFlow models
1. saver.save / saver.restore:
Saving a model this way produces four files. Here model.ckpt is the model name, i.e. the checkpoint prefix:
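1. checkpoint: a text file that records the paths of the saved checkpoint files (the latest one and any others still kept).
2. model.ckpt.data-00000-of-00001: stores the values of the variables, i.e. the network weights.
3. model.ckpt.index: a binary file that indexes the variables (their names and shapes, and where each value is stored in the data file).
4. model.ckpt.meta: a binary protobuf file (a MetaGraphDef) that stores the computation graph, i.e. the network structure.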
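For a detailed introduction to protobuf, see https://www.jianshu.com/p/419efe983cb2. A rough idea of what it is should be enough (at least in this beginner's opinion; going deeper takes too much time, so experts, please go easy on me).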
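To make "records the path info" concrete: the checkpoint file is a small text-format protobuf (a CheckpointState message) whose fields are model_checkpoint_path and all_model_checkpoint_paths. The sketch below is a hypothetical pure-Python parser, written only to show that structure. The sample text is illustrative, not copied from a real run, and the parser ignores protobuf string escaping. In real code, tf.train.get_checkpoint_state or tf.train.latest_checkpoint reads this file for you.

```python
import re
from typing import Dict, List

# Matches lines of the form: field_name: "some value"
_FIELD = re.compile(r'^\s*(\w+)\s*:\s*"(.*)"\s*$')


def parse_checkpoint_state(text: str) -> Dict[str, List[str]]:
    """Hypothetical helper: collect every quoted field of a text-format
    CheckpointState into a dict mapping field name -> list of values."""
    state: Dict[str, List[str]] = {}
    for line in text.splitlines():
        match = _FIELD.match(line)
        if match:
            state.setdefault(match.group(1), []).append(match.group(2))
    return state


# Illustrative contents of model_save1/checkpoint (not taken from a real run)
sample = (
    'model_checkpoint_path: "model.ckpt"\n'
    'all_model_checkpoint_paths: "model.ckpt"\n'
)
state = parse_checkpoint_state(sample)
print(state["model_checkpoint_path"][0])  # → model.ckpt
```

all_model_checkpoint_paths repeats once per checkpoint still on disk, which is why each field maps to a list.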
Saving a model with saver:
import tensorflow as tf
from os.path import join as pjoin
import os

graph = tf.Graph()
with graph.as_default():
    X_1 = tf.placeholder(tf.float32, name='input_x1')
    X_2 = tf.placeholder(tf.float32, name='input_x2')
    # Pass dtype by keyword: the 2nd positional argument of tf.Variable is `trainable`
    b = tf.Variable(1.3, dtype=tf.float32, name='input_b')
    x1_add_x2 = tf.add(X_1, X_2)
    add_x_b = tf.add(x1_add_x2, b, name='add_op')
    saver = tf.train.Saver()  # by default saves every variable in the graph
    init = tf.global_variables_initializer()

os.makedirs('model_save1', exist_ok=True)  # create the target directory if it is missing
with tf.Session(graph=graph) as sess:
    sess.run(init)
    feed_dict = {X_1: 1.2, X_2: 3.1}
    y = sess.run(add_x_b, feed_dict=feed_dict)
    saver.save(sess, 'model_save1/model.ckpt')
Running this produces the four files described above in model_save1/.
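A self-contained way to check this is to save a single variable into a temporary directory, list the files, and read the variable names back from the .index/.data pair with tf.train.list_variables. This is a minimal sketch that assumes TensorFlow 1.x-style APIs, written against tensorflow.compat.v1 so it should also run under TensorFlow 2.x. The function name save_and_inspect and the temporary directory are illustrative.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions


def save_and_inspect(ckpt_dir):
    """Save one variable with tf.train.Saver, then report what landed on disk."""
    graph = tf.Graph()
    with graph.as_default():
        tf.Variable(1.3, dtype=tf.float32, name="input_b")
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        # save() returns the checkpoint prefix, here <ckpt_dir>/model.ckpt
        prefix = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))
    files = sorted(os.listdir(ckpt_dir))
    # Reads the .index/.data files: a list of (variable name, shape) pairs
    variables = tf.train.list_variables(prefix)
    return files, variables


with tempfile.TemporaryDirectory() as ckpt_dir:
    files, variables = save_and_inspect(ckpt_dir)
print(files)
print(variables)
```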
Restoring the model:
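import os
from os.path import join as pjoin
import tensorflow as tf

graph = tf.Graph()
# Entering the session also makes `graph` the default graph
with tf.Session(graph=graph) as sess:
    # Load the graph structure from the .meta file
    saver = tf.train.import_meta_graph(pjoin(os.getcwd(), 'model_save1/model.ckpt.meta'))
    # Restore the variable values from the latest checkpoint in the model directory
    saver.restore(sess, tf.train.latest_checkpoint(pjoin(os.getcwd(), 'model_save1')))
    # Read the saved variable b
    print(sess.run('input_b:0'))
    # Get the placeholder tensors
    input_x1 = sess.graph.get_tensor_by_name('input_x1:0')
    input_x2 = sess.graph.get_tensor_by_name('input_x2:0')
    # Get the output tensor of the op to evaluate
    op = sess.graph.get_tensor_by_name('add_op:0')
    # Add a new op on top of the restored graph
    add_new_op = tf.multiply(op, 2)
    out = sess.run(add_new_op, feed_dict={input_x1: 3, input_x2: 4})
    print(out)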
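import_meta_graph is most useful when the code that built the graph is not available. When it is available, one alternative is to rebuild the same graph in code and let saver.restore fill in the variables. In that case only the .index and .data files are read, and the variable names must match those stored in the checkpoint. This is a minimal sketch under the same tensorflow.compat.v1 assumption. The names build_graph, save and restore_and_run are hypothetical. b is overwritten with 2.5 before saving so the restored value is distinguishable from the initializer's 1.3.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_graph():
    """Hypothetical helper: rebuild the graph from the saving example in code."""
    graph = tf.Graph()
    with graph.as_default():
        x1 = tf.placeholder(tf.float32, name="input_x1")
        x2 = tf.placeholder(tf.float32, name="input_x2")
        b = tf.Variable(1.3, dtype=tf.float32, name="input_b")
        out = tf.add(tf.add(x1, x2), b, name="add_op")
        handles = {
            "x1": x1, "x2": x2, "b": b, "out": out,
            "saver": tf.train.Saver(),
            "init": tf.global_variables_initializer(),
        }
    return graph, handles


def save(ckpt_dir):
    graph, h = build_graph()
    with tf.Session(graph=graph) as sess:
        sess.run(h["init"])
        sess.run(h["b"].assign(2.5))  # differ from the initial 1.3 so restoring is visible
        h["saver"].save(sess, os.path.join(ckpt_dir, "model.ckpt"))


def restore_and_run(ckpt_dir):
    graph, h = build_graph()
    with tf.Session(graph=graph) as sess:
        # No initializer here: restore() assigns every saved variable
        h["saver"].restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        b_value = sess.run(h["b"])
        out_value = sess.run(h["out"], feed_dict={h["x1"]: 3.0, h["x2"]: 4.0})
    return b_value, out_value


with tempfile.TemporaryDirectory() as ckpt_dir:
    save(ckpt_dir)
    b_value, out_value = restore_and_run(ckpt_dir)
print(b_value, out_value)
```

The trade-off is that this approach ties the restore code to the graph-building code, while import_meta_graph only needs the files on disk plus the tensor names.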
2. builder / load:
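This method saves the model in pb format as a SavedModel: a directory containing saved_model.pb (the graph and its signatures) plus a variables/ subdirectory holding the weights.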
Building and saving the model with builder (the example comes from a small demo I wrote earlier; focus on the saving steps):
# X, Y, keep_prob, pred, loss, train_op, train_x, train_y, model_save_path and
# self.max_epochs come from the surrounding demo code (not shown)
init = tf.global_variables_initializer()
# SavedModelBuilder raises an error if model_save_path already holds a previous export
builder = tf.saved_model.builder.SavedModelBuilder(model_save_path)
with tf.Session() as sess:
    writer = tf.summary.FileWriter("logs/", sess.graph)
    sess.run(init)
    for epoch in range(self.max_epochs):
        _, pred_train, loss_ = sess.run([train_op, pred, loss], feed_dict={
            X: train_x,
            Y: train_y,
            keep_prob: 1,
        })
        if epoch % 100 == 0:
            print('Number of iterations:', epoch, 'loss', loss_)
    X_TensorInfo = tf.saved_model.utils.build_tensor_info(X)
    keep_prob_TensorInfo = tf.saved_model.utils.build_tensor_info(keep_prob)
    # The signature output must be the prediction pred, not the label placeholder Y
    pred_TensorInfo = tf.saved_model.utils.build_tensor_info(pred)
    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_X': X_TensorInfo,
                    'keep_prob': keep_prob_TensorInfo},
            outputs={'output': pred_TensorInfo},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        ))
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
        },
        legacy_init_op=legacy_init_op)
    builder.save()
    writer.close()
This post explains the functions and parameters used here in great detail and is worth a look:
https://blog.csdn.net/thriving_fcl/article/details/75213361
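For the common case of a single prediction signature, TensorFlow 1.x also provides tf.saved_model.simple_save. It wraps the builder steps for that case, using the SERVING tag, the default serving signature key and a predict signature. This is a minimal sketch via tensorflow.compat.v1; the toy graph (a 2x1 weight w and a matmul named pred) and the name export_simple are hypothetical. It also shows the on-disk layout: saved_model.pb next to a variables/ directory. The variables.index and variables.data-* files in that directory use the same checkpoint format as method 1.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def export_simple(export_dir):
    """Export a toy graph with simple_save and return the resulting file layout."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 2], name="input_X")
        w = tf.Variable([[1.0], [2.0]], name="w")
        pred = tf.matmul(x, w, name="pred")
        init = tf.global_variables_initializer()
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        # SERVING tag + default signature key + predict signature, in one call
        tf.saved_model.simple_save(
            sess, export_dir, inputs={"input_X": x}, outputs={"output": pred})
    top_level = sorted(os.listdir(export_dir))
    var_files = sorted(os.listdir(os.path.join(export_dir, "variables")))
    return top_level, var_files


with tempfile.TemporaryDirectory() as tmp:
    # Use an export directory that does not exist yet, as with SavedModelBuilder
    top_level, var_files = export_simple(os.path.join(tmp, "export"))
print(top_level)
print(var_files)
```

simple_save offers no control over tags or multiple signatures, so the full builder is still needed for anything beyond this simple case.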
Next, restoring the model:
import numpy as np
import tensorflow as tf

# Create a session and restore the model into it. A fresh graph keeps the loaded op names
# (e.g. 'dnn_model/Add_1') from clashing with any graph already built in this process
with tf.Session(graph=tf.Graph()) as sess:
    MetaGraphDef = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_save_path)
    # Parse out the SignatureDef protobuf
    SignatureDef_d = MetaGraphDef.signature_def
    SignatureDef = SignatureDef_d[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    # Parse out the TensorInfo protobuf of each input
    X_TensorInfo = SignatureDef.inputs['input_X']
    keep_prob_TensorInfo = SignatureDef.inputs['keep_prob']
    # Y_TensorInfo = SignatureDef.outputs['output']
    # Resolve the actual tensors
    input_X = sess.graph.get_tensor_by_name(X_TensorInfo.name)
    keep_prob = sess.graph.get_tensor_by_name(keep_prob_TensorInfo.name)
    pred = sess.graph.get_tensor_by_name('dnn_model/Add_1:0')
    # with tf.variable_scope('dnn_model', reuse=tf.AUTO_REUSE):
    #     pred = DNN(input_X)
    fre_predict = []
    # length and pre_data come from the surrounding demo code (not shown)
    for step in range(length):
        prediction = sess.run(pred, feed_dict={input_X: np.array(pre_data[step]).reshape((1, 2)), keep_prob: 1.})
        prediction = prediction.reshape((-1))
        fre_predict.extend(prediction)
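A note on the prediction tensor: pred is fetched by name with pred = sess.graph.get_tensor_by_name('dnn_model/Add_1:0') instead of through Y_TensorInfo. When the signature's output is built from Y, the label placeholder, fetching that output after loading just fetches the placeholder. A placeholder has no value of its own and no labels exist at prediction time, so sess.run fails with an error asking for Y to be fed. The output registered in the signature should therefore be the prediction tensor pred (build_tensor_info(pred)). With that in place, pred can be resolved from SignatureDef.outputs['output'].name instead of the hard-coded name 'dnn_model/Add_1:0'.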
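Registering the prediction tensor as the signature output, rather than the label placeholder, lets the loader resolve every tensor from the SignatureDef alone. The small self-contained sketch below assumes tensorflow.compat.v1, and its toy graph and function names are hypothetical. It exports with pred as the output, then loads into a fresh graph and looks up both input and output through the signature, with no hard-coded tensor names. It also feeds a whole batch in one sess.run call instead of one call per sample.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

SERVING = tf.saved_model.tag_constants.SERVING
SIG_KEY = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def export(export_dir):
    """Export a toy model whose signature output is the prediction tensor."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 2], name="input_X")
        # Label placeholder: only needed for training, so it stays out of the signature
        tf.placeholder(tf.float32, shape=[None, 1], name="label_Y")
        w = tf.Variable([[1.0], [2.0]], name="w")
        pred = tf.matmul(x, w, name="pred")
        init = tf.global_variables_initializer()
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={"input_X": tf.saved_model.utils.build_tensor_info(x)},
            outputs={"output": tf.saved_model.utils.build_tensor_info(pred)},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            sess, [SERVING], signature_def_map={SIG_KEY: signature})
        builder.save()


def load_and_predict(export_dir, batch):
    """Load into a fresh graph and resolve input/output tensors from the signature."""
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf.saved_model.loader.load(sess, [SERVING], export_dir)
        signature = meta_graph_def.signature_def[SIG_KEY]
        x = sess.graph.get_tensor_by_name(signature.inputs["input_X"].name)
        pred = sess.graph.get_tensor_by_name(signature.outputs["output"].name)
        # One run for the whole batch instead of one run per sample
        return sess.run(pred, feed_dict={x: batch})


with tempfile.TemporaryDirectory() as tmp:
    export_dir = os.path.join(tmp, "export")
    export(export_dir)
    result = load_and_predict(export_dir, np.array([[1.0, 1.0], [2.0, 0.5]], dtype=np.float32))
print(result)
```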
Time is tight at the moment, so the details of deploying the model will come in a more detailed write-up once the project is finished.