[tf1] Saving and loading parameters

tf.keras

Reference: https://github.com/tensorflow/docs/blob/529ba4346b8fc5e830e762a2f0ee87b3c345c0c9/site/en/r1/guide/keras.ipynb

# Save weights to a TensorFlow Checkpoint file
model.save_weights('./weights/my_model')

# Restore the model's state,
# this requires a model with the same architecture.
model.load_weights('./weights/my_model')
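A minimal, self-contained sketch of that round trip, assuming tf.keras from TF 1.15 or TF 2.x before 2.16 (Keras 2). Keras 3 only accepts filenames ending in .weights.h5, so this checkpoint-format variant will not run there. The build_model() helper and its layer sizes are hypothetical, and a temporary directory stands in for ./weights. Building the restoring model with the same function is what "same architecture" means in practice.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical toy architecture; the saving and the restoring model
    # must be built by the same code so their variables line up.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
        tf.keras.layers.Dense(1),
    ])


# No '.h5' suffix, so tf.keras (Keras 2) writes a TensorFlow checkpoint:
# my_model.index plus my_model.data-* shard files.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'my_model')

model = build_model()
x = np.random.rand(5, 4).astype('float32')
expected = model.predict(x)
model.save_weights(ckpt_path)

restored = build_model()  # starts with different random weights
restored.load_weights(ckpt_path)
actual = restored.predict(x)

print(np.allclose(expected, actual))  # True
```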

Weights can also be saved in the Keras HDF5 format:

# Save weights to an HDF5 file
model.save_weights('my_model.h5', save_format='h5')

# Restore the model's state
model.load_weights('my_model.h5')
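A sketch of the HDF5 variant, with a few assumptions: h5py is installed, and the filename my_model.weights.h5 is my choice. In tf.keras 1.15/2.x the .h5 suffix alone already selects HDF5 (so save_format can be omitted), and the .weights.h5 ending also satisfies Keras 3. build_model() is again a hypothetical toy model; here the check compares the raw weight arrays instead of predictions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical toy architecture used for both saving and restoring.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
        tf.keras.layers.Dense(1),
    ])


# '.h5' suffix -> HDF5 in tf.keras 1.15/2.x; Keras 3 requires '.weights.h5'.
h5_path = os.path.join(tempfile.mkdtemp(), 'my_model.weights.h5')

model = build_model()
model.save_weights(h5_path)  # a single self-contained file

restored = build_model()
restored.load_weights(h5_path)

# Every kernel and bias array should match exactly after loading.
same = all(np.array_equal(a, b)
           for a, b in zip(model.get_weights(), restored.get_weights()))
print(same)  # True
```

Unlike the checkpoint format, which writes an .index file plus .data-* shards, HDF5 produces one file, which is convenient for copying weights around.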

tf.train.Saver

Lets you save only a chosen subset of variables via var_list. Reference: https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver

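import os
import tensorflow as tf

# Save only the variables returned by _get_var_list(), keeping at most
# max_tf_checkpoints_to_keep checkpoints on disk.
self._saver = tf.train.Saver(var_list=self._get_var_list(),
                             max_to_keep=self.max_tf_checkpoints_to_keep)
# Writes checkpoint_dir/tf_ckpt-<iteration>.*
self._saver.save(self._sess,
                 os.path.join(checkpoint_dir, 'tf_ckpt'),
                 global_step=self.iteration)
# Restore the checkpoint that was saved at iteration_number.
self._saver.restore(self._sess,
                    os.path.join(checkpoint_dir,
                                 'tf_ckpt-{}'.format(iteration_number)))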
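A self-contained sketch of the same Saver pattern, assuming the tf.compat.v1 API (present in TF 1.15 and 2.x). The variable names and the scope-based filter are hypothetical stand-ins for the author's _get_var_list(). It shows three things: var_list limits what gets written, max_to_keep prunes older checkpoints, and global_step is appended to the prefix as tf_ckpt-<step>, the same name the restore call rebuilds by hand.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style API, also available under TF 2.x
ckpt_dir = tempfile.mkdtemp()

graph = tf.Graph()
with graph.as_default():
    # Hypothetical variables: everything under 'policy' is checkpointed,
    # 'optimizer/lr' is deliberately left out of var_list.
    with tf1.variable_scope('policy'):
        w = tf1.get_variable('w', initializer=tf.constant([1.0, 2.0]))
        b = tf1.get_variable('b', initializer=tf.constant(0.5))
    with tf1.variable_scope('optimizer'):
        lr = tf1.get_variable('lr', initializer=tf.constant(0.1))

    policy_vars = tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES, scope='policy')
    saver = tf1.train.Saver(var_list=policy_vars, max_to_keep=2)

    init = tf1.global_variables_initializer()
    bump_w = w.assign_add([1.0, 1.0])
    set_lr = lr.assign(0.01)

    with tf1.Session(graph=graph) as sess:
        sess.run(init)
        sess.run(set_lr)  # changed here, but never saved
        for iteration in range(3):
            sess.run(bump_w)
            # Writes tf_ckpt-<iteration>.*; only the 2 newest are kept.
            saver.save(sess, os.path.join(ckpt_dir, 'tf_ckpt'), global_step=iteration)

    with tf1.Session(graph=graph) as sess:
        sess.run(init)  # fresh session: all variables back to initial values
        saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        restored_w, restored_lr = sess.run([w, lr])

print(restored_w, restored_lr)
```

After the restore, w is back at [4, 5] from tf_ckpt-2, while lr stays at its initial 0.1 because it was never in var_list. tf.train.latest_checkpoint reads the "checkpoint" state file that Saver.save maintains, so the step number does not have to be tracked separately.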

tf.train.Checkpoint

I don't understand this one yet. Reference: https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/Checkpoint
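A minimal sketch of tf.train.Checkpoint, assuming eager execution (the default in TF 2.x; the guard at the top turns it on for TF 1.15). The variables step and w are hypothetical. Unlike Saver, which matches variables by their graph names, Checkpoint is object-based: saved values are keyed by the attribute names passed to it (step=, w=). tf.train.CheckpointManager handles numbering and max_to_keep, roughly the job that global_step and max_to_keep do for Saver.

```python
import os
import tempfile

import tensorflow as tf

if not tf.executing_eagerly():  # TF 1.15: must run at program start
    tf.compat.v1.enable_eager_execution()

ckpt_dir = tempfile.mkdtemp()

# Hypothetical state to track; the keyword names become checkpoint keys.
step = tf.Variable(0, dtype=tf.int64)
w = tf.Variable([1.0, 2.0])
ckpt = tf.train.Checkpoint(step=step, w=w)
manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=2)

for _ in range(3):
    step.assign_add(1)
    w.assign_add([1.0, 1.0])
    manager.save()  # ckpt-1, ckpt-2, ckpt-3; only the last two are kept

# Simulate losing the in-memory state, then restore the newest checkpoint.
step.assign(0)
w.assign([0.0, 0.0])
status = ckpt.restore(manager.latest_checkpoint)
status.assert_consumed()  # every saved value found a matching object

print(int(step.numpy()), w.numpy())  # 3 [4. 5.]
```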
