tf.keras
# Save weights to a TensorFlow Checkpoint file
model.save_weights('./weights/my_model')
# Restore the model's state,
# this requires a model with the same architecture.
model.load_weights('./weights/my_model')
Weights can also be saved in the Keras HDF5 format:
# Save weights to an HDF5 file
model.save_weights('my_model.h5', save_format='h5')
# Restore the model's state
model.load_weights('my_model.h5')
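A self-contained round trip, as a sketch assuming TensorFlow 2.x. The layer sizes, the helpers build_model and hdf5_round_trip, and the file name my_model.weights.h5 are illustrative. The .weights.h5 suffix is chosen so the sketch also runs on Keras 3 (the tf.keras default from TF 2.16), which, as far as I know, requires that suffix for save_weights and no longer takes save_format. On older tf.keras, a name ending in .h5 is already enough to select HDF5.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # load_weights expects the same layers in the same order as the saved model.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


def hdf5_round_trip(directory):
    """Save weights to HDF5, load them into a fresh copy, return both outputs."""
    model = build_model()
    # Hypothetical file name; the .h5 ending selects the HDF5 format.
    path = os.path.join(directory, "my_model.weights.h5")
    model.save_weights(path)

    restored = build_model()  # new model with its own random initial weights
    restored.load_weights(path)  # overwritten with the saved weights

    x = np.ones((2, 4), dtype="float32")
    return np.asarray(model(x)), np.asarray(restored(x))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        original, reloaded = hdf5_round_trip(d)
        print(np.allclose(original, reloaded))
```

Both models produce the same outputs only because the second one received the first one's weights. This is also why the architecture has to match.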
tf.train.Saver
Saves a chosen subset of variables (passed via var_list). See https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver
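# Excerpt from an agent class: save only the variables returned by _get_var_list(),
# keeping at most max_tf_checkpoints_to_keep recent checkpoints.
# (TF 1.x API; in TF 2.x it is tf.compat.v1.train.Saver and works only in graph mode.)
self._saver = tf.train.Saver(var_list=self._get_var_list(),
                             max_to_keep=self.max_tf_checkpoints_to_keep)
# global_step is appended to the prefix, producing <checkpoint_dir>/tf_ckpt-<iteration>.
self._saver.save(
    self._sess,
    os.path.join(checkpoint_dir, 'tf_ckpt'),
    global_step=self.iteration)
# Restore from the same prefix plus the step suffix.
self._saver.restore(self._sess,
                    os.path.join(checkpoint_dir,
                                 'tf_ckpt-{}'.format(iteration_number)))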
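Here is a standalone sketch of the same save/restore cycle outside a class. It assumes TF 2.x, where the Saver is tf.compat.v1.train.Saver and must be built inside a graph, so everything runs under tf.Graph().as_default(). The variables w and b, the helper saver_round_trip and step number 3 are hypothetical. Only w is listed in var_list, so the change to b is never written, and b keeps its initial value after restore.

```python
import os
import tempfile

import tensorflow as tf


def saver_round_trip(checkpoint_dir, iteration_number=3):
    """Save only `w` with a v1 Saver, then restore it in a second session."""
    with tf.Graph().as_default():  # the v1 Saver needs graph mode, not eager
        w = tf.Variable([1.0, 2.0], name="w")
        b = tf.Variable(0.5, name="b")
        # var_list limits what is written: only `w` goes into the checkpoint.
        saver = tf.compat.v1.train.Saver(var_list=[w], max_to_keep=2)
        update = tf.group(w.assign([10.0, 20.0], read_value=False),
                          b.assign(7.0, read_value=False))
        init = tf.compat.v1.global_variables_initializer()

        with tf.compat.v1.Session() as sess:
            sess.run(init)
            sess.run(update)  # w = [10, 20], b = 7
            # global_step is appended to the prefix: <checkpoint_dir>/tf_ckpt-3
            saver.save(sess, os.path.join(checkpoint_dir, "tf_ckpt"),
                       global_step=iteration_number)

        with tf.compat.v1.Session() as sess:
            sess.run(init)  # back to w = [1, 2], b = 0.5
            saver.restore(sess, os.path.join(
                checkpoint_dir, "tf_ckpt-{}".format(iteration_number)))
            w_val, b_val = sess.run([w, b])
    # w comes back as [10, 20]; b was not saved, so it stays at 0.5.
    return w_val.tolist(), float(b_val)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        print(saver_round_trip(d))
```

The path given to restore must include the -<global_step> suffix. That is why the excerpt formats 'tf_ckpt-{}' with the iteration number.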
tf.train.Checkpoint
Not understood yet. See https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/Checkpoint
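In short, tf.train.Checkpoint saves objects rather than variable names. Each keyword argument (here step= and weights=) becomes a key, and restore fills a new Checkpoint built with the same keywords, whatever its variables are called. tf.train.CheckpointManager plays the role of Saver's max_to_keep. The following is a minimal sketch assuming TF 2.x with eager execution. The helper checkpoint_round_trip, the keyword names and checkpoint number 5 are illustrative.

```python
import tempfile

import tensorflow as tf


def checkpoint_round_trip(directory):
    """Save two variables with tf.train.Checkpoint and restore them into new objects."""
    step = tf.Variable(0, dtype=tf.int64)
    weights = tf.Variable([1.0, 2.0])
    # The keyword names (step, weights) are what the checkpoint is keyed on.
    ckpt = tf.train.Checkpoint(step=step, weights=weights)
    # CheckpointManager numbers the files and deletes old ones beyond max_to_keep.
    manager = tf.train.CheckpointManager(ckpt, directory, max_to_keep=3)

    step.assign(5)
    weights.assign([10.0, 20.0])
    manager.save(checkpoint_number=5)  # writes <directory>/ckpt-5.*

    # Fresh variables with different values, matched by keyword name on restore.
    new_step = tf.Variable(0, dtype=tf.int64)
    new_weights = tf.Variable([0.0, 0.0])
    restored = tf.train.Checkpoint(step=new_step, weights=new_weights)
    status = restored.restore(manager.latest_checkpoint)
    status.assert_existing_objects_matched()  # raises if an object was not found
    return int(new_step.numpy()), new_weights.numpy().tolist()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        print(checkpoint_round_trip(d))
```

Because matching goes by object structure, the restoring code needs the same keyword names, not the same variable names. This is the main difference from tf.train.Saver.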