两种从 TensorFlow 的 checkpoint生成 frozenpb 的方法

1. 从 ckpt-.data,ckpt-.index 和 .meta 生成 frozenpb

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "outputs"
    saver = tf.train.import_meta_graph(os.path.join(os.path.split(input_checkpoint)[0], 'graph.meta'), clear_devices=True)
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  
            # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) 
        #得到当前图有几个操作节点

if __name__ == "__main__":
    # 输入ckpt模型路径
    input_checkpoint='ckpt_path/ckpt-10000'
    # 输出pb模型的路径
    out_pb_path="some_path/frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint,out_pb_path)

2. 从网络代码和 ckpt-.data 文件生成 frozenpb

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph

import network  # 导入网络结构

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 设置GPU
model_path = "ckpt_path/ckpt-10000"

def main():
    tf.reset_default_graph()
    input_node = tf.placeholder(
        tf.float32, shape=(None,112, 96, 3)
    ) 
    input_node = tf.identity(input_node,name="inputs") # 设置输入节点的名字,这里可以自定义名称
    flow = network(input_node)
    flow = tf.identity(flow, name="outs") # 设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        # 保存图
        tf.train.write_graph(sess.graph_def, "logdir/", "graph.pb")
        # 把图和参数结构一起
        freeze_graph.freeze_graph(
            "logdir/graph.pb", # 上面保存的图结构 graph.pb
            "",
            False,
            model_path,
            "outs",
            "save/restore_all", # 默认恢复所有
            "save/Const:0", # 默认常量
            "some_path/frozen.pb", # 保存frozen.pb
            False,
            "",
        )
    print("done")


if __name__ == "__main__":
    main()

3. 打印 网络中节点的名字

import tensorflow as tf


if __name__ == "__main__":
    checkpoint_path = '../model_fintune/ckpt-1400'  
    reader = tf.train.NewCheckpointReader(checkpoint_path)  
    var_to_shape_map = reader.get_variable_to_shape_map()  
    
    for key in var_to_shape_map:  
        print("tensor name: ", key)  
        # print(reader.get_tensor(key))

或者通过

import tensorflow as tf

def printTensors(pb_file):

    # read pb into graph_def
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import graph_def
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    # print operations
    for op in graph.get_operations():
        print(op.name)

printTensors("path-to-my-pbfile.pb")

4. 两种方法对比

如果是自己的代码训练的模型,有网络结构,有 ckpt 文件,最好是使用第二种方法,使用起来很灵活,可以进行各种自定义,比如修改输入输出的节点名字,网络有多个路径的时候可以自定义输出路径。第一种方法,应该也能达到第二种方法的效果,因为它们本来就是等价的,可能会有些麻烦。第一种方法的好处就是快,不要去翻那些杂糅在一起的网络结构。

上一篇:[Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model


下一篇:VC++内存泄漏检测方法(2):Checkpoint/DumpStatistics