every blog every motto: You can do more than you think.
0. 前言
训练过程中,停止,后续接着训练
1. 正文
1.1 保存信息
每个eopch以后需要保存后续接着训练的信息,信息包括,model、optimizer、epoch
for epoch in range(start_epoch,end_epoch):
for iter ,data in enumerate(dataloader):
pass
# -------------------------------------------------------
# 每个epoch 后保存checkpoint,以便断点继续训练
checkpoint = {
'eopch': epoch,
'model_state_dict': self.net.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()
}
torch.save(checkpoint, os.path.join(self.save_chpt,
'epoch_%d_loss_%3f.pth'.format(epoch, epoch_fuse_loss / ite_num_per_epoch)))
print('保存各参数完成,用于后续继续训练。')
# -------------------------------------------------------
1.2 继续训练
需要先实例化模型和优化器,然后进行如下操作
if self.subsequent_training: # 如果是断点继续上次训练
checkpoints = torch.load(os.path.join(self.save_chpt, 'xxx.pth'))
self.start_epoch = checkpoints['epoch'],
self.optimizer.load_state_dict(checkpoints['optimizer']),
self.net.load_state_dict(checkpoints['model'])
print('继续上次训练,各参数为:', checkpoints)
参考文献
[1] https://zhuanlan.zhihu.com/p/375461811
[2] https://www.zhihu.com/question/313486088?sort=created
[3] https://zhuanlan.zhihu.com/p/133250753
[4] https://www.jianshu.com/p/1cd6333128a1