Tensorflow模型保存与载入的两种办法

                          TensorFlow模型保存与加载的两种方法

             1.saver.save/saver.restore:

                       这种方法保存的模型有四个文件:

                       Tensorflow模型保存与载入的两种办法

                      其中model.ckpt为模型的名称

                      1.checkpoint 文本文件、记录了模型文件的路径信息

                      2.model.ckpt.data-00000-of-00001 保存了模型的网络权重信息

                      3.model.ckpt.index 是二进制文件,保存了模型中的变量参数信息

                      4.model.ckpt.mete 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

                       关于protobuf的详解:https://www.jianshu.com/p/419efe983cb2,了解下大概是什么即可(菜鸡是这么认为的,要不过程太过耗费时间,大神勿喷)

                      使用saver进行模型保存:

import tensorflow as tf
from os.path import join as pjoin
import os
graph = tf.Graph()
with graph.as_default():
    X_1 = tf.placeholder(tf.float32,name= 'input_x1')
    X_2 = tf.placeholder(tf.float32,name='input_x2')
    b = tf.Variable(1.3,tf.float32,name='input_b')
    x1_mul_x2 = tf.add(X_1,X_2)
    add_x_b = tf.add(x1_mul_x2,b,name='add_op')

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    feed_dict = {X_1:1.2,X_2:3.1}
    y = sess.run(add_x_b,feed_dict=feed_dict)
    saver.save(sess,'model_save1/model.ckpt')

           程序会生成上述的四个文件。

          恢复模型:

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    saver = tf.train.import_meta_graph(pjoin(os.getcwd(),'model_save1/model.ckpt.meta')) #加载模型结构
    saver.restore(sess,tf.train.latest_checkpoint(pjoin(os.getcwd(),'model_save1')))  #指定模型保存目录恢复变量信息

    #获取保存的变量b
    print(sess.run('input_b:0'))
    
    #获取placeholder占位符变量
    input_x1 = sess.graph.get_tensor_by_name('input_x1:0')
    input_x2 = sess.graph.get_tensor_by_name('input_x2:0')

    #获取要计算的句柄Op
    op = sess.graph.get_tensor_by_name('add_op:0')

    #加入新的操作:
    add_new_op = tf.multiply(op,2)

    out = sess.run(add_new_op,feed_dict={input_x1:3,input_x2:4})
    print(out)

               2.builder/load:

             这种保存方式会生成pb格式的文件:

Tensorflow模型保存与载入的两种办法

Tensorflow模型保存与载入的两种办法

使用builder建立模型(例子是之前做的一个小deom,主要看保存过程):

     init = tf.global_variables_initializer()
        builder = tf.saved_model.builder.SavedModelBuilder(model_save_path)

        with tf.Session() as sess:
            writer = tf.summary.FileWriter("logs/", sess.graph)
            sess.run(init)
            
            for epoch in range(self.max_epochs):
                _,pred_train,loss_= sess.run([train_op,pred,loss],feed_dict = {
                    X : train_x,
                    Y : train_y,
                    keep_prob : 1,
                })
                if epoch % 100 == 0:
                    print('Number of iterations :' ,epoch,'loss',loss_) 

            X_TensorInfo = tf.saved_model.utils.build_tensor_info(X)
            Y_TensorInfo = tf.saved_model.utils.build_tensor_info(Y)
            keep_prob_TensorInfo = tf.saved_model.utils.build_tensor_info(keep_prob)

            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={'input_X' : X_TensorInfo,
                            'keep_prob' : keep_prob_TensorInfo},
                    outputs={'output' : Y_TensorInfo},
                    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
                ))
        
            legacy_init_op = tf.group(tf.tables_initializer(),name = 'legacy_init_op')
            builder.add_meta_graph_and_variables(sess,
                                                [tf.saved_model.tag_constants.SERVING],
                                                signature_def_map = {
                                                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:prediction_signature
                                                },
                                                legacy_init_op = legacy_init_op)
            builder.save()
            writer.close()

这个帖子对要是用的函数和参数介绍狠详细,大家可以参考下

https://blog.csdn.net/thriving_fcl/article/details/75213361  

然后就是模型恢复:

#建立会话对象,将模型恢复
        with tf.Session() as sess:
            MetaGraphDef = tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.SERVING],model_save_path)
            #解析得到SignatureDef protobuf
            SignatureDef_d = MetaGraphDef.signature_def
            SignatureDef = SignatureDef_d[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

            #解析得到变量对应的TensorInfo protobuf
            X_TensorInfo = SignatureDef.inputs['input_X']
            keep_prob_TensorInfo = SignatureDef.inputs['keep_prob']
            # Y_TensorInfo = SignatureDef.outputs['output']

            #解析得到具体的tensor
            input_X = sess.graph.get_tensor_by_name(X_TensorInfo.name)
            keep_prob = sess.graph.get_tensor_by_name(keep_prob_TensorInfo.name)
            pred = sess.graph.get_tensor_by_name('dnn_model/Add_1:0')
            
            # with tf.variable_scope('dnn_model',reuse = tf.AUTO_REUSE):
            #     pred = DNN(input_X)

            fre_predict = []
            for step in range(length):
                prediction = sess.run(pred,feed_dict = {input_X:np.array(pre_data[step]).reshape((1,2)),keep_prob:1.})
                prediction = prediction.reshape((-1))
                fre_predict.extend(prediction)

   pred = sess.graph.get_tensor_by_name('dnn_model/Add_1:0') 

   预测的变量,pred没有使用Y_TensroInfo来解析获得是因为使用这种方法,会出现placeholder占位符输入的错误,是因为在恢复模型后,使用模型进行预测的话,是不用placeholder的而是将pred作为返回值返回,所以会出现这个错误。

   由于时间比较紧张,具体到模型的部署,后续项目完成后,会出一个比较详细的过程。

上一篇:Tensorflow【实战Google深度学习框架】编程基础小漂亮总结


下一篇:golang中的Session支持