Tensorflow 获取model中的变量列表

1、动态获取 

(1)朴素获取法
       1) 朴素获取可训练变量:t_vars = tf.trainable_variables()
       2)朴素获取全部变量,包含声明training=False变量:all_vars = tf.global_variables()
(2)使用tensorflow.contrib.slim
       1) 获取常规变量(是slim里面与model变量对应的一个类型):regular_variables = slim.get_variables()
       2)直接获取:vars = slim.get_variables_to_restore()
       3)slim用于筛选方法
            a. 通过name筛选: variables = slim.get_variables_by_name("d_")
            b. 通过name后缀筛选:variables = slim.get_variables_by_suffix("_b")
            c. 通过namespace筛选:variables = slim.get_variables(scope="layer1")
            d. 通过include和exclude筛选
                d0. variables_to_restore = slim.get_variables_to_restore(include=["d_"])
                d1. variables_to_restore = slim.get_variables_to_restore(exclude=["_w"])
(3) 离线获取(从一个已保存好的模型中获取var_list)
    1) 离线文件: checkpoint、model.data-xxxx、model.index、model.meta
    2) 将离线文件载入当前环境,变成动态获取
        

#记住,要先清空现有的图
#不然的话import_meta_graph会把原model里面的数据追加到现有的model中
#一片混乱
tf.reset_default_graph()
 
with tf.Session(graph=tf.get_default_graph()) as sess:
    new_saver = tf.train.import_meta_graph('e:/mytrain/results/20190227_01/model/model.meta')
    new_saver.restore(sess, 'e:/mytrain/results/20190227_01/model/model')
    #加载进来之后还不是为所欲为
    var_list=tf.global_variables()
        


    3) 直接从离线文件获取
        

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
         
#文件夹地址改成自己的
model_dir="'e:\\20190227_01\\mytrain\\results\\20190227_01\\model"
         
ckpt = tf.train.get_checkpoint_state(model_dir)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
         
#返回一个dict= {'name':[shape] }
#例如 'd_w2/Adam':[4, 4, 32, 64]
var_to_shape_map = reader.get_variable_to_shape_map()
         
#我们可以用遍历的方式,取出字典里所有的key
for key in var_to_shape_map:
    print(key)        #key是str类型的
    #再用key去找这个tensor的值
    a=reader.get_tensor(key)
    print(type(a))    #输出: <class 'numpy.ndarray'>

上一篇:如何快速构建Slim Docker映像


下一篇:使用camera在tensorflow/slim下调用pb文件进行图像识别的预测