Pytorch与深度学习自查手册4-训练、可视化、日志输出、保存模型
训练和验证(包含可视化、日志、保存模型)
初始化模型、dataloader都完善以后,正式进入训练部分。
训练部分包括:
- 及时的日志记录
- tensorboard可视化log
- 输入
- 前向传播
- loss计算
- 反向传播
- 权重更新
- 固定步骤进行验证
- 最佳模型的保存(+bad case输出)
日志记录
利用logging模块在控制台实时打印并及时记录运行日志。
from config import *
import logging # 引入logging模块
import os.path
class Logger:
    """Configure the root logger to write to a timestamped file and to the console.

    Records at INFO level and above are emitted to both destinations. The log
    file is ``<log_dir>/<YYYYmmddHHMM>.log``; ``log_dir`` defaults to
    ``Logs`` under the current working directory and is created if missing.

    Attributes:
        logger: the root ``logging.Logger`` the handlers are attached to.
    """

    # Handlers installed by the most recently constructed Logger. They are
    # removed and closed when a new Logger is built, so creating Logger more
    # than once does not duplicate every record on the console / in the file.
    _installed_handlers = []

    def __init__(self, mode='w', log_dir=None):
        """Attach a file handler and a console handler to the root logger.

        Args:
            mode: open mode of the log file ('w' truncates, 'a' appends).
            log_dir: directory for the log file; defaults to ``<cwd>/Logs``.
        """
        import time

        # Step 1: the root logger; its level is the master switch for all handlers.
        self.logger = logging.getLogger()
        self.logger.setLevel(logging.INFO)
        for handler in Logger._installed_handlers:
            self.logger.removeHandler(handler)
            handler.close()

        # Step 2: a file handler writing to a timestamped log file.
        if log_dir is None:
            log_dir = os.path.join(os.getcwd(), 'Logs')
        # FileHandler does not create missing directories.
        os.makedirs(log_dir, exist_ok=True)
        rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        logfile = os.path.join(log_dir, rq + '.log')
        fh = logging.FileHandler(logfile, mode=mode, encoding='utf-8')
        # File threshold; records below INFO are already dropped by the logger itself.
        fh.setLevel(logging.DEBUG)

        # Step 3: shared output format for both handlers.
        formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
        fh.setFormatter(formatter)

        # Step 4: attach the file handler and a console handler to the logger.
        self.logger.addHandler(fh)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)  # console threshold
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)
        Logger._installed_handlers = [fh, ch]
完整的训练流程
import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler
import sys
from tqdm import tqdm
import torch
def train_one_epoch(model, optimizer, data_loader, device, epoch, tb_writer=None):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Args:
        model: classifier whose output is fed to ``CrossEntropyLoss`` and
            reduced with ``max(dim=1)`` (so presumably (batch, num_classes)
            logits — confirm for your model).
        optimizer: optimizer over the trainable parameters of ``model``.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device that inputs and labels are moved to.
        epoch: current epoch index (progress bar text and TensorBoard step).
        tb_writer: optional TensorBoard ``SummaryWriter``; when given, the
            running mean loss, batch accuracy and learning rate are logged
            every iteration against a global step that keeps increasing
            across epochs.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats; ``accuracy`` is the
        fraction of training samples classified correctly in this epoch.

    Exits the process with status 1 if the loss becomes non-finite.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    num_correct = torch.zeros(1).to(device)
    num_seen = 0
    # Only needed for the global TensorBoard step; not every iterable has len().
    steps_per_epoch = len(data_loader) if tb_writer is not None else 0
    optimizer.zero_grad()
    data_loader = tqdm(data_loader)
    for iteration, (batch, labels) in enumerate(data_loader):
        labels = labels.to(device)
        pred = model(batch.to(device))
        loss = loss_function(pred, labels)
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)
        loss.backward()
        # Running mean of the per-batch loss over the iterations seen so far.
        mean_loss = (mean_loss * iteration + loss.detach()) / (iteration + 1)
        pred_cls = torch.max(pred, dim=1)[1]
        batch_correct = torch.eq(pred_cls, labels).sum()
        num_correct += batch_correct
        num_seen += labels.size(0)
        # Refresh the progress-bar text with the mean loss every 50 iterations.
        if iteration % 50 == 0:
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))
        optimizer.step()
        optimizer.zero_grad()
        if tb_writer is not None:
            global_step = epoch * steps_per_epoch + iteration
            batch_acc = batch_correct.item() / max(labels.size(0), 1)
            tags = ["train_loss", "train_accuracy", "learning_rate"]
            values = [mean_loss.item(), batch_acc, optimizer.param_groups[0]["lr"]]
            for tag, value in zip(tags, values):
                tb_writer.add_scalars(tag, {'Train': value}, global_step)
    accuracy = num_correct.item() / num_seen if num_seen else 0.0
    return mean_loss.item(), accuracy
@torch.no_grad()
def evaluate(model, data_loader, device, best_acc=-1, bad_case_path='bad_case.pkl'):
    """Compute the top-1 accuracy of ``model`` on ``data_loader``.

    Misclassified samples are collected per batch as ``(inputs, labels)``
    tuples (moved to CPU; batches with no mistakes are skipped). When the
    accuracy beats ``best_acc``, that list is pickled to ``bad_case_path``.

    Args:
        model: classifier whose outputs are reduced with ``max(dim=1)``.
        data_loader: loader yielding ``(inputs, labels)``; must expose
            ``.dataset`` so the total sample count can be taken from it.
        device: device the inputs/labels are moved to for inference.
        best_acc: previous best accuracy; bad cases are dumped only if beaten.
        bad_case_path: file that the bad cases are pickled to.

    Returns:
        Accuracy as a float (0.0 for an empty dataset).
    """
    import pickle

    model.eval()
    # Number of correctly predicted samples.
    num_correct = torch.zeros(1).to(device)
    # Total number of validation samples.
    num_samples = len(data_loader.dataset)
    data_loader = tqdm(data_loader, desc="validation...")
    bad_case = []
    for batch, labels in data_loader:
        pred = model(batch.to(device))
        pred = torch.max(pred, dim=1)[1]
        correct = torch.eq(pred, labels.to(device))
        num_correct += correct.sum()
        wrong = ~correct
        if wrong.any():
            # The mask lives on `device`; move it to each tensor's own device
            # before indexing, otherwise CPU tensors cannot be masked on CUDA.
            bad_case.append((batch[wrong.to(batch.device)].cpu(),
                             labels[wrong.to(labels.device)].cpu()))
    acc = num_correct.item() / num_samples if num_samples else 0.0
    if best_acc < acc:
        # Plain pickle of CPU tensors (readable with pickle.load / joblib.load).
        with open(bad_case_path, 'wb') as f:
            pickle.dump(bad_case, f)
    return acc
def _save_checkpoint(path, model, optimizer, scheduler, epoch, best_acc):
    """Save model, optimizer and scheduler state plus bookkeeping to ``path``."""
    checkpoint = {
        'model_state_dict': model.state_dict(),  # model parameters
        'optimizer_state_dict': optimizer.state_dict(),  # optimizer state
        'scheduler_state_dict': scheduler.state_dict(),  # lr scheduler state
        'epoch': epoch,
        'best_acc': best_acc,
        'num_params': sum(p.numel() for p in model.parameters()),
    }
    torch.save(checkpoint, path)


def main(args, logger):
    """Train, validate, checkpoint and visualize a classifier.

    ``Model``, ``train_loader`` and ``val_loader`` are not defined in this
    script; they must be provided elsewhere before calling ``main``.

    Args:
        args: parsed command-line namespace. Reads ``device``,
            ``freeze_layers``, ``lr``, ``lrf``, ``epochs``, ``log_dir``,
            ``save_epoch`` and, if present, ``save_dir`` (default
            ``<cwd>/checkpoints``).
        logger: object whose ``logger`` attribute is a ``logging.Logger``.
    """
    device = torch.device(args.device)
    save_dir = getattr(args, 'save_dir', None) or os.path.join(os.getcwd(), 'checkpoints')
    os.makedirs(save_dir, exist_ok=True)

    # Instantiate the model on the target device; add_graph and training both
    # need parameters and inputs on the same device.
    model = Model().to(device)
    # Optionally freeze every weight except the fully-connected head.
    if args.freeze_layers:
        print("freeze layers except fc layer.")
        for name, para in model.named_parameters():
            # Parameters whose name does not contain "fc" are frozen.
            if "fc" not in name:
                para.requires_grad_(False)
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
    # Cosine decay of the lr factor from 1 at epoch 0 down to args.lrf at args.epochs.
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    logger.logger.info('start training......\n')
    tb_writer = SummaryWriter(log_dir=args.log_dir)
    # Write the model graph to TensorBoard (assumes 3x224x224 inputs).
    init_input = torch.zeros((1, 3, 224, 224), device=device)
    tb_writer.add_graph(model, init_input)
    best_acc = 0  # best validation accuracy seen so far
    try:
        for epoch in range(args.epochs):
            mean_loss, mean_acc = train_one_epoch(model=model,
                                                  optimizer=optimizer,
                                                  data_loader=train_loader,
                                                  device=device,
                                                  epoch=epoch)
            # Learning rate used during this epoch, read before the scheduler advances it.
            lr = optimizer.param_groups[0]["lr"]
            scheduler.step()
            # Pass best_acc so bad cases are only dumped when validation improves.
            acc = evaluate(model=model,
                           data_loader=val_loader,
                           device=device,
                           best_acc=best_acc)

            # TensorBoard: per-epoch curves.
            tb_writer.add_scalars("val_accuracy", {'Validation': acc}, epoch)
            tb_writer.add_scalar("epoch/train_loss", mean_loss, epoch)
            tb_writer.add_scalar("epoch/train_accuracy", mean_acc, epoch)
            tb_writer.add_scalar("epoch/learning_rate", lr, epoch)
            logger.logger.info('%d epoch train mean loss: %.2f \n' % (epoch, mean_loss))
            logger.logger.info('%d epoch train mean acc: %.2f \n' % (epoch, mean_acc))
            logger.logger.info('%d epoch validation acc: %.2f \n' % (epoch, acc))

            # Periodic checkpoint every `save_epoch` epochs (0 disables it).
            if args.save_epoch and epoch % args.save_epoch == 0:
                _save_checkpoint(os.path.join(save_dir, 'checkpoint-%d.pt' % epoch),
                                 model, optimizer, scheduler, epoch, best_acc)
                logger.logger.info('save model %d successed......\n' % epoch)

            # Keep the best model according to validation accuracy.
            if best_acc < acc:
                best_acc = acc
                logger.logger.info('best model in %d epoch, train mean acc: %.2f \n' % (epoch, mean_acc))
                logger.logger.info('best model in %d epoch, validation acc: %.2f \n' % (epoch, acc))
                _save_checkpoint(os.path.join(save_dir, 'best_checkpoint.pt'),
                                 model, optimizer, scheduler, epoch, best_acc)
                logger.logger.info('save best model successed......\n')

            # Visualize predictions: build a matplotlib figure here
            # (TODO: placeholder; nothing is logged while fig is None).
            fig = None
            if fig is not None:
                tb_writer.add_figure("predictions vs. actuals",
                                     figure=fig,
                                     global_step=epoch)

            # Weight histograms; these attributes assume a ResNet-like model,
            # so they are only logged when present.
            if hasattr(model, 'conv1'):
                tb_writer.add_histogram(tag="conv1",
                                        values=model.conv1.weight,
                                        global_step=epoch)
            if hasattr(model, 'layer1'):
                tb_writer.add_histogram(tag="layer1/block0/conv1",
                                        values=model.layer1[0].conv1.weight,
                                        global_step=epoch)
    finally:
        tb_writer.close()
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean (``type=bool`` would turn "False" into True)."""
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 'on')

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.1)
    # Save a periodic checkpoint every `save_epoch` epochs.
    parser.add_argument('--save_epoch', type=int, default=3)
    # TensorBoard log directory; None lets SummaryWriter choose its default.
    parser.add_argument('--log_dir', type=str, default=None)
    # Directory for periodic and best checkpoints.
    parser.add_argument('--save_dir', type=str, default='checkpoints')
    # Root directory of the dataset.
    data_root = "/home/data_set/"
    parser.add_argument('--data-path', type=str, default=data_root)
    # --freeze-layers True freezes every layer except the fully-connected head;
    # useful with pretrained weights and speeds up training.
    parser.add_argument('--freeze-layers', type=_str2bool, default=False)
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    opt = parser.parse_args()
    logger = Logger()
    main(opt, logger)
模型保存和断点继续训练
在训练模型过程中,对模型进行保存是很重要的。
核心包括两个内容:
- 最优模型保存机制;
- 如何从断点加载模型继续训练。
最优模型通常通过设置验证间隔,根据选定的验证集指标,选择训练过程中表现最优的模型进行保存。需要保存以下内容到checkpoint:
- 核心:模型参数、优化器参数、scheduler参数;
- 其他:训练epoch、模型超参数、模型评价指标。
只保存和加载模型的话,还有其他方式可参考:PyTorch之保存加载模型 - 简书 (jianshu.com)
# Everything needed to resume training: model/optimizer/scheduler state plus bookkeeping.
checkpoint = {
    'model_state_dict': model.state_dict(),  # model parameters
    'optimizer_state_dict': optimizer.state_dict(),  # optimizer state
    'scheduler_state_dict': scheduler.state_dict(),  # learning-rate scheduler state
    'epoch': epoch,  # epoch index at save time
    'best_val_mae': best_valid_mae,  # best validation metric so far, tracked by the training loop
    'num_params': sum(p.numel() for p in model.parameters()),  # parameter count, for reference
}
torch.save(checkpoint, os.path.join(args.save_dir, 'checkpoint.pt'))
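要实现断点继续训练,只需要将上次保存的checkpoint加载进来,然后继续训练即可。注意:加载时使用的键名必须与保存时完全一致(如 'optimizer_state_dict'、'scheduler_state_dict'),并且由于保存的 epoch 已经训练完成,应从下一个 epoch 开始: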
path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # checkpoint path
# map_location lets a checkpoint saved on GPU be restored on the current device.
checkpoint = torch.load(path_checkpoint, map_location=device)
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
lr_schedule = ThelrscheduleClass(*args, **kwargs)
# Keys must match the ones used when the checkpoint dict was saved.
model.load_state_dict(checkpoint['model_state_dict'])  # learnable parameters
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # optimizer state
lr_schedule.load_state_dict(checkpoint['scheduler_state_dict'])  # lr scheduler state
# The saved epoch was already completed, so resume from the next one.
start_epoch = checkpoint['epoch'] + 1
# Same upper bound as a fresh run (range(args.epochs)), so no extra epoch is trained.
for epoch in range(start_epoch, args.epochs):
    # ...
    train_mae = train(model, device, train_loader, optimizer, lr_schedule, criterion_fn)
    # ...
预测
@torch.no_grad()
def predict(model, data_loader, device):
    """Return the predicted class index for every sample in ``data_loader``.

    Args:
        model: classifier whose outputs are reduced with ``max(dim=1)``.
        data_loader: iterable yielding input batches only (no labels).
        device: device that each batch is moved to for inference.

    Returns:
        A flat list with one 0-d index tensor per sample, in loader order.
    """
    model.eval()
    res = []
    for batch in tqdm(data_loader, desc="prediction..."):
        logits = model(batch.to(device))
        # Index of the highest score per row = predicted class.
        res.extend(torch.max(logits, dim=1)[1])
    return res