TensorFlow -- 模型保存与读取

最近学习Google的深度学习框架TensorFlow,CNN模型训练什么的都是OK的,官方也有代码,中文详解请参照:
http://www.soaringroad.com/?p=115

但是在实际使用的时候不可能每次预测都训练一遍模型,这样太浪费时间,所以需要我们在训练完成的时候保存模型,并且在需要预测的时候加载。官方提供的例子和解释不够具体,让我踩了很多的坑,所以写个笔记分享一下,希望帮助大家跳过或者少踩这些坑。

模型保存:

①首先对于需要保存的变量进行定义,记得variable和placeholder保存用变量名的定义一定不能忘了

定义的形式大体上如下:

var_name_1= tf.Variable(........,name='var_name_1_store')
var_name_2= tf.argmax(var_name_1,name='var_name_2_store')
var_name_3=tf.placeholder(........,name='var_name_3_store')
var_name_4=tf.matmul(var_name_1,var_name_3,name='var_name_4_store')

②其次就是保存处理

需要利用 tf.train.Saver来保存模型,其中global_step不定义的情况下,默认为0

saver = tf.train.Saver()
saver.save(sess,'./data.chkp',global_step=XX)

模型加载:

①首先读取刚刚保存的meta文件,然后全局变量初始化,需要用到tf.train.import_meta_graph

saver = tf.train.import_meta_graph("./data.chkp.meta")
sess.run(tf.global_variables_initializer())

②其次加载我们需要的变量,并预测,这里用到var_name_3_store,这就是为什么前面placeholder定义的时候一定要定义name

 predict = tf.get_default_graph().get_tensor_by_name("var_name_4_store:0")
 predict.eval(feed_dist={'var_name_3_store':XXXXX})

 

上一篇:Java-学习笔记-4-抽象


下一篇:Java-学习笔记-3-环境搭建