【转载】 tensorflow batch_normalization的正确使用姿势

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/computerme/article/details/80836060

 


————————————————
 

 

 

 

BN在如今的CNN结果中已经普遍应用,在tensorflow中可以通过tf.layers.batch_normalization()这个op来使用BN。该op隐藏了对BN的mean var alpha beta参数的显示声明,因此在训练和部署测试中需要特征注意正确使用BN的姿势。

 

 

 

 

###正确使用BN训练

 

注意把tf.layers.batch_normalization(x, training=is_training,name=scope)输入参数的training=True。另外需要在来训练中添加update_ops以便在每一次训练完后及时更新BN的参数。

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  
with tf.control_dependencies(update_ops): #保证train_op在update_ops执行之后再执行。  
   train_op = optimizer.minimize(loss) 

 

 

 

 

 

###正确保存带BN的模型
保存模型的时候不能只保存trainable_variables因为BN的参数不属于trainable_variables。为了方便,可以用tf.global_variables()。使用姿势如下

saver = tf.train.Saver(var_list=tf.global_variables())
savepath = saver.save(sess, 'here_is_your_personal_model_path’)

 

 

 

 

###正确读取带BN的模型
与保存类似,读的时候变量也需要为 global_variables 。如下:

saver = tf.train.Saver()
or saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, 'here_is_your_personal_model_path')

 

 

 

 

 

PS:inference的时候还需要把tf.layers.batch_normalization(x, training=is_training,name=scope) 这里的training设为False

 

 

 

 

 

 

 

Reference:
https://*.com/questions/48260394/whats-the-differences-between-tf-graphkeys-trainable-variables-and-tf-graphkeys

 

上一篇:python学习教程:tensorflow实现训练变量checkpoint的保存与读取


下一篇:机器学习-模型保存和加载