Saving and loading training progress in TensorFlow

tf.train.Checkpoint: used to create checkpoints

tf.train.Checkpoint:

tf.train.Checkpoint(
    root=None, **kwargs
)

TensorFlow objects may contain trackable state, such as tf.Variables, tf.keras.optimizers.Optimizer implementations, tf.data.Dataset iterators, tf.keras.Layer implementations, or tf.keras.Model implementations. These are called trackable objects.

# Example
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
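# Several objects can also be tracked by keyword (this assumes an `optimizer` has already been created).
# The keyword names are arbitrary, but they must be the same when saving and when restoring.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)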
checkpoint.restore(save_path)
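The snippet above leaves the model and optimizer unspecified. As a minimal runnable sketch of the same save/restore round trip, the example below tracks two plain variables instead. The names `step` and `weights` and the temporary directory are assumptions made for illustration and do not come from the original notes.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical stand-ins for real training state.
step = tf.Variable(0, dtype=tf.int64)
weights = tf.Variable([1.0, 2.0])

# The keyword names ("step", "weights") identify the objects inside the checkpoint.
checkpoint = tf.train.Checkpoint(step=step, weights=weights)
prefix = os.path.join(tempfile.mkdtemp(), "training_checkpoints")
save_path = checkpoint.save(prefix)

# Simulate further training that overwrites the values.
step.assign(100)
weights.assign([0.0, 0.0])

# Restoring through a Checkpoint built with the same keyword names
# assigns the saved values back to the existing variables.
restored = tf.train.Checkpoint(step=step, weights=weights)
restored.restore(save_path)

print(int(step.numpy()), weights.numpy().tolist())
```

Because both variables already exist when restore is called, their values are assigned immediately.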

Attribute fields:

Attributes
save_counter: Incremented when save() is called. Used to number checkpoints.
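A small sketch of how save_counter behaves, assuming eager execution and a throwaway temporary directory (the variable `v` is a hypothetical placeholder for real state):

```python
import os
import tempfile

import tensorflow as tf

ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")

# Each call to save() increments save_counter and uses it as the suffix.
first_path = ckpt.save(prefix)
second_path = ckpt.save(prefix)

counter_value = int(ckpt.save_counter.numpy())
print(counter_value, os.path.basename(first_path), os.path.basename(second_path))
```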

Member methods:

read

read(
    save_path, options=None
) # Mostly the same as restore, but does not expect a save_counter variable in the checkpoint, so it does not restore save_counter

restore

restore(
    save_path, options=None
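) # If an object already exists, its value is assigned immediately; if the object (e.g. a Variable) does not exist yet, the assignment is deferred until the object is created
# To make sure every value has been assigned at load time, with nothing left to be assigned later, chain the assert_consumed() assertion
# A saved Keras model contains many extra keys; when most of them go unused at load time, many warnings are printed. Use expect_partial() to silence them
checkpoint.restore(path).assert_consumed()  # make sure every value has been assigned
checkpoint.restore(path).expect_partial()  # silence warnings about unused objects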

The returned status object has the following methods:

  • assert_consumed(): Raises an exception if any variables are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with other assertions.
  • assert_existing_objects_matched(): Raises an exception if any existing Python objects in the dependency graph are unmatched. Unlike assert_consumed, this assertion will pass if values in the checkpoint have no corresponding Python objects. For example a tf.keras.Layer object which has not yet been built, and so has not created any variables, will pass this assertion but fail assert_consumed. Useful when loading part of a larger checkpoint into a new Python program, e.g. a training checkpoint with a tf.compat.v1.train.Optimizer was saved but only the state required for inference is being loaded. This method returns the status object, and so may be chained with other assertions.
  • assert_nontrivial_match(): Asserts that something aside from the root object was matched. This is a very weak assertion, but is useful for sanity checking in library code where objects may exist in the checkpoint which haven't been created in Python and some Python objects may not have a checkpointed value.
  • expect_partial(): Silence warnings about incomplete checkpoint restores. Warnings are otherwise printed for unused parts of the checkpoint file or object when the Checkpoint object is deleted (often at program shutdown).
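The difference between these status methods is easier to see with a deferred restore. The sketch below nests tf.train.Checkpoint objects as a stand-in for a layer. The names `bias` and `kernel` and their values are hypothetical and chosen only for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Save a "layer" holding two variables.
saved_layer = tf.train.Checkpoint(
    bias=tf.Variable([1.0, 2.0]),
    kernel=tf.Variable([[3.0], [4.0]]),
)
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")
save_path = tf.train.Checkpoint(layer=saved_layer).save(prefix)

# On the restoring side only `bias` exists so far.
bias = tf.Variable(tf.zeros([2]))
new_layer = tf.train.Checkpoint(bias=bias)
root = tf.train.Checkpoint(layer=new_layer)
status = root.restore(save_path)

# Every existing Python object found a value, so this passes.
status.assert_existing_objects_matched()

# The checkpointed `kernel` has no Python object yet, so this should raise.
try:
    status.assert_consumed()
    consumed_before_attach = True
except AssertionError:
    consumed_before_attach = False

# Attaching the missing variable triggers the deferred assignment.
kernel = tf.Variable(tf.zeros([2, 1]))
new_layer.kernel = kernel
status.assert_consumed()

# Loading only part of the checkpoint on purpose: silence the warnings.
partial_bias = tf.Variable(tf.zeros([2]))
partial_root = tf.train.Checkpoint(layer=tf.train.Checkpoint(bias=partial_bias))
partial_root.restore(save_path).expect_partial()
```

In this sketch `bias` is assigned as soon as restore is called, while `kernel` only receives its saved value once it is attached to the object graph.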

save

save(
    file_prefix, options=None
)
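# checkpoint.save("/tmp/ckpt"). file_prefix is the prefix for the checkpoint filenames (directory plus name prefix), not only a directory; the save counter is appended to it, e.g. /tmp/ckpt-1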
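To illustrate how file_prefix maps to files on disk, here is a small sketch that saves three times into a temporary directory and then looks up the newest checkpoint with tf.train.latest_checkpoint. The directory and the `counter` variable are assumptions for the example.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "ckpt")  # directory + filename prefix

counter = tf.Variable(0)
ckpt = tf.train.Checkpoint(counter=counter)

paths = []
for _ in range(3):
    counter.assign_add(1)
    paths.append(ckpt.save(prefix))  # ckpt-1, ckpt-2, ckpt-3

# save() also maintains the metadata that latest_checkpoint reads.
latest = tf.train.latest_checkpoint(ckpt_dir)
files = sorted(os.listdir(ckpt_dir))
print(latest)
print(files)
```

The directory should end up containing an index file and data shard(s) per save, plus a small `checkpoint` metadata file.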

write

write(
    file_prefix, options=None
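)  # write does not care about save_counter or any bookkeeping about the number of saves (it does not number the checkpoint); a checkpoint saved with write must be loaded with read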
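A minimal sketch of the write/read pair, assuming eager execution. The prefix name "manual" and the variable `v` are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "manual")

v = tf.Variable(3.0)
ckpt = tf.train.Checkpoint(v=v)

# write() uses the prefix as given: no "-N" suffix is appended.
path = ckpt.write(prefix)

# write() does not update the metadata used by latest_checkpoint.
latest = tf.train.latest_checkpoint(ckpt_dir)

# Overwrite the value, then load it back with read().
v.assign(0.0)
ckpt.read(path).assert_existing_objects_matched()

print(os.path.basename(path), latest, float(v.numpy()))
```

Numbering and cleanup of such checkpoints are left to the caller, which is why write is mostly useful for higher-level checkpoint management code.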

 
