Tensorflow开发的基本步骤:
- 定义Tensorflow输入节点
- 通过占位符定义:
X = tf.placeholder("float")
2.通过字典类型定义:
inputdict = {
'x': tf.placeholder("float"),
'y': tf.placeholder("float")
}
3. 直接定义输入节点:
train_x = np.float32(np.linspace(-1,1,100))
- 定义“学习参数”的变量
- 定义“运算”
- 优化函数,优化目标
- 初始化所有变量
- 迭代更新参数到最优解
- 测试模型
- 使用模型
2、模型保存与载入
- 模型保存:
saver = tf.train.Saver() #生成saver
saverdir = "log/"
with tf.Session() as sess:
sess.run(init)
print("Finished")
saver.save(sess,saverdir+"linermodel.cpkt")
- 模型载入:
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
saver.restore(sess2,saverdir+"linermodel.cpkt")
print("x=0.2,z=",sess2.run(z,feed_dict={X:0.2}))
检查点(Checkpoint):Tensorflow训练模型时难免会出现中断的情况,希望能够将辛苦得到的中间参数保留下来,在训练中保存模型,习惯上称之为保存检查点。
saver = tf.train.Saver(max_to_keep=1) #生成saver
saver.restore(sess2,saverdir+"linermodel.cpkt-"+str(load_epoch))