首更:
由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe...好吧,caffe也没有这种python风格的设定...
废话少说,导入包:
import numpy as np
import tensorflow as tf
保存会话:
W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32)
b = tf.Variable([[1,2,3]],dtype=tf.float32) init = tf.global_variables_initializer()
saver = tf.train.Saver() # <--------- with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,'./my_net/saver_net.ckpt') # <---------
载入会话:
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32) saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess,'./my_net/saver_net.ckpt') # <---------
print('Weight:\n',sess.run(W))
print('biases:\n',sess.run(b))
输出如下:
Weight:
[[ 1. 2. 3.]
[ 4. 5. 6.]]
biases:
[[ 1. 2. 3.]]
载入会话会加载之前保存的变量,所以不需要tf.global_variables_initializer()激活本次变量了。
再更:
引入节点名称后,只要tf变量节点的名称一致,python变量名不一致也能完美继承,也就是说tf变量节点的名称识别权限大于python变量名
详细的命名规则下节有介绍:『TensorFlow』第八弹_变量与命名空间_固有结界
保存模型:
W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name='W') # <------
b = tf.Variable([[1,2,3]],dtype=tf.float32,name='b') # <------ init = tf.global_variables_initializer()
saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,'./my_net/saver_net.ckpt')
W--’W‘,b--’b‘
载入模型:
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32') # <------
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32') # <------ saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess,'./my_net/saver_net.ckpt')
print('Weight:\n',sess.run(W))
print('biases:\n',sess.run(b))
W,b
结果报错
载入模型:
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name='W') # <------
a = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name='b') # <------ saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess,'./my_net/saver_net.ckpt')
print('Weight:\n',sess.run(W))
print('biases:\n',sess.run(a))
W-’W‘,a--’b'
INFO:tensorflow:Restoring parameters from ./my_net/saver_net.ckpt
Weight:
[[ 1. 2. 3.]
[ 4. 5. 6.]]
biases:
[[ 1. 2. 3.]]