本文主要内容源于:
======================================================================
本地加载模型
用于推理验证
针对仅推理场景可以使用load_checkpoint
把参数直接加载到网络中,以便进行后续的推理验证。
示例代码如下:
resnet = ResNet50() load_checkpoint("resnet50-2_32.ckpt", net=resnet) dateset_eval = create_dataset(os.path.join(mnist_path, "test"), 32, 1) # define the test dataset loss = CrossEntropyLoss() model = Model(resnet, loss, metrics={"accuracy"}) acc = model.eval(dataset_eval)
-
load_checkpoint
方法会把参数文件中的网络参数加载到模型中。加载后,网络中的参数就是CheckPoint保存的。 -
eval
方法会验证训练后模型的精度。
用于迁移学习
针对任务中断再训练及微调(Fine Tune)场景,可以加载网络参数和优化器参数到模型中。
示例代码如下:
# return a parameter dict for model param_dict = load_checkpoint("resnet50-2_32.ckpt") resnet = ResNet50() opt = Momentum(resnet.trainable_params(), 0.01, 0.9) # load the parameter into net load_param_into_net(resnet, param_dict) # load the parameter into optimizer load_param_into_net(opt, param_dict) loss = SoftmaxCrossEntropyWithLogits() model = Model(resnet, loss, opt) model.train(epoch, dataset)
-
load_checkpoint
方法会返回一个参数字典。 -
load_param_into_net
会把参数字典中相应的参数加载到网络或优化器中。