使用MindSpore保存模型

本次演示使用MindSpore。所以首先需要安装MindSpore。大家可以参考MindSpore官网页面安装:https://www.mindspore.cn/install

在使用MindSpore进行模型训练过程中,首先添加检查点(CheckPoint)保存模型的参数,这样在后面进行推理及再训练的时候使用很方便。

CheckPoint:MindSpore存储了所有训练参数值的二进制文件。是采用的Google的Protocol Buffers机制,在开发语言方面边界,具有良好的可扩展性。CheckPoint的protocol格式定义在mindspore/ccsrc/utils/checkpoint.proto中。

  1. 保存CheckPoint文件方法:

    在模型训练的过程中,需要使用Callback机制传入回调函数ModelCheckpoint对象,可以保存模型参数,生成CheckPoint文件。

    通过CheckpointConfig对象可以设置CheckPoint的保存策略。保存的参数分为网络参数和优化器参数。

    ModelCheckpoint提供默认配置策略,方便用户快速上手。具体用法如下:

    from mindspore.train.callback import ModelCheckpoint

    ckpoint_cb = ModelCheckpoint()

    model.train(epoch_num, dataset, callbacks=ckpoint_cb)

  2. 可以根据个人的具体需求对CheckPoint进行配置。基本的用法如下:

    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)

    ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)

    model.train(epoch_num, dataset, callbacks=ckpoint_cb)

  3. 以上示例代码中,首先需要初始化一个CheckpointConfig类对象,用来设置保存策略。

    save_checkpoint_steps:表示每隔多少个step保存一次。

    keep_checkpoint_max:表示最多保留CheckPoint文件的数量。

    prefix:表示生成CheckPoint文件的前缀名。

    directory:表示存放文件的目录。

    创建一个ModelCheckpoint对象把它传递给model.train方法,就可以在训练过程中使用CheckPoint功能了。

    生成的CheckPoint文件如下:

    resnet50-graph.meta # 编译后的计算图

    resnet50-1_32.ckpt  # CheckPoint文件后缀名为'.ckpt'

    resnet50-2_32.ckpt  # 文件的命名方式表示保存参数所在的epoch和step数

    resnet50-3_32.ckpt  # 表示保存的是第3个epoch的第32个step的模型参数

    ...

  4. 用户使用相同的前缀名,运行多次训练脚本,可能会生成同名CheckPoint文件。MindSpore为方便用户区分每次生成的文件,会在用户定义的前缀后添加”_”和数字加以区分。如果想要删除.ckpt文件时,请同步删除.meta 文件。

    例:resnet50_3-2_32.ckpt 表示运行第3次脚本生成的第2个epoch的第32个step的CheckPoint文件。

    使用MindSpore保存模型

  5. 5

    CheckPoint配置策略:

    MindSpore有两种保存CheckPoint策略:迭代策略和时间策略,可以通过创建CheckpointConfig对象设置相应策略。 CheckpointConfig中有四个参数可以自定义设置:

    (1)save_checkpoint_steps:表示每隔多少个step保存一个CheckPoint文件,默认值为1。

    (2)save_checkpoint_seconds:表示每隔多少秒保存一个CheckPoint文件,默认值为0。

    (3)keep_checkpoint_max:表示最多保存多少个CheckPoint文件,默认值为5。

    (4)keep_checkpoint_per_n_minutes:表示每隔多少分钟保留一个CheckPoint文件,默认值为0。

    (5)save_checkpoint_steps和keep_checkpoint_max为迭代策略,根据训练迭代的次数进行配置。 save_checkpoint_seconds和keep_checkpoint_per_n_minutes为时间策略,根据训练的时长进行配置。

    两种策略不能同时使用,迭代策略优先级高于时间策略,当同时设置时,只有迭代策略可以生效。当参数显示设置为None时,表示放弃该策略。在迭代策略脚本正常结束的情况下,会默认保存最后一个step的CheckPoint文件。

上一篇:SVN中服务器地址变更


下一篇:从0到1Flink的成长之路(二十一)-Flink+Kafka:End-to-End Exactly-Once