tensorflow保存模型、恢复模型

1、模型训练(部分代码):

X = tf.placeholder(tf.float64,X_data.shape,name='X')
Y = tf.placeholder(tf.float64,Y_data.shape,name='Y')
epoch_num = 500
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_data = []
    # 创建FileWriter对象,用当前计算图初始化
    writer = tf.summary.FileWriter('./summary/', sess.graph)

    # 保存模型
    saver_path = './model/checkpoint/model.ckpt' # 模型保存路径
    saver = tf.train.Saver() # 新建Saver()对象

    for i in range(1,epoch_num+1):
        _, loss = sess.run([optimizer,loss_func],feed_dict={X:X_data,Y:Y_data})
        loss_data.append(loss)
        saved_path = saver.save(sess, saver_path) # 保存模型
        print("epoch:%d,loss:%.4g" % (i,loss))
    # 关闭FileWriter
    writer.close()

 

2、保存模型

# 模型保存路径
saver_path = './model/checkpoint/model.ckpt' 
# 新建Saver()对象
saver = tf.train.Saver()
# 保存模型
saved_path = saver.save(sess, saver_path)

执行之后,在目录./model/checkpoint/model.ckpt下,生成模型相关文件,如图:

tensorflow保存模型、恢复模型

 

3、恢复模型并使用模型、变量

meta_path = './model/checkpoint/model.ckpt.meta'
model_path = './model/checkpoint/model.ckpt'
# 导入计算图
saver = tf.train.import_meta_graph(meta_path)
config = tf.ConfigProto()
with tf.Session(config=config) as sess:
    # 恢复模型
    saver.restore(sess, model_path)
    # 此时默认图就是导入的图
    graph_restore = tf.get_default_graph()
    # 恢复变量
    W = graph_restore.get_tensor_by_name('W:0')
    b = graph_restore.get_tensor_by_name('b:0')
    # 预测模型
    predict_func = tf.matmul(test_data, W)
    predict_value = sess.run([predict_func],feed_dict={x:test_data})

 

 

 

上一篇:leetcode——169. 多数元素


下一篇:LeetCode——169. 多数元素