pytorch计算分类验证精度acc1,acc5代码

def accuracy(output, label, topk=(1,)):
    maxk = max(topk)
    batch_size = output.size(0)
	
	# 在输出结果中取前maxk个最大概率作为预测结果,并获取其下标,当topk=(1, 5)时取5就可以了。
    _, pred = torch.topk(output, k=maxk, dim=1, largest=True, sorted=True)
    
    # 将得到的k个预测结果的矩阵进行转置,方便后续和label作比较
    pred = pred.T
    # 将label先拓展成为和pred相同的形状,和pred进行对比,输出结果
    correct = torch.eq(pred, label.contiguous().view(1,-1).expand_as(pred))
	# 例:
	# 若label为:[1,2,3,4], topk = (1, 5)时
	# 则label.contiguous().view(1,-1).expand_as(pred)为:
	# [[1, 2, 3, 4],
	#  [1, 2, 3, 4],
	#  [1, 2, 3, 4],
	#  [1, 2, 3, 4],
	#  [1, 2, 3, 4]]
	
    res = []

    for k in topk:
    	# 取前k个预测正确的结果进行求和
        correct_k = correct[:k].contiguous().view(-1).float().sum(dim=0, keepdim=True)
        # 计算平均精度, 将结果加入res中
        res.append(correct_k*100/batch_size)
    
    return res

当topk=(1, 5)时同时返回acc1, 和acc5

在验证时的实例代码:

class AverageMeter():
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0
        self.sum = 0
        self.avg = 0
        self.count = 0
    
    def update(self, val, n):
        self.sum += float(val)*n
        self.count += n
        self.avg = self.sum / self.count

def validation(val_dataloader,
               num_batch_val,
               criterion,
               model,
               device,
               total_epochs,
               logger,
               debug_step = 100):
            
            model.eval()
            acc1_val = AverageMeter()
            acc5_val = AverageMeter()
            loss_val = AverageMeter()

            start_time = time.time()
            with torch.no_grad():
                for batch_id, data in enumerate(val_dataloader):
                    image = data[0]
                    label = data[1]
                    image = Variable(image.to(device), requires_grad=False)
                    label = Variable(label.to(device), requires_grad=False)
                    image = image.flatten(1)
                    # logger.info(f"the image size :{image.size()}")
                    # logger.info(f"the label size :{label.size()}")
                    output = model(image)
                    loss = criterion(output, label)
                    acc1, acc5 = accuracy(output=output, label=label, topk=(1, 5))

                    loss_val.update(loss.data, image.size(0))
                    acc1_val.update(acc1[0], image.size(0))
                    acc5_val.update(acc5, image.size(0))

                    if batch_id % debug_step == 0 or batch_id == num_batch_val:
                        logger.info(f"Val Step:[{batch_id:03d}/{num_batch_val:03d}], "+
                                    f"Avg Loss:{loss_val.avg:.4f}, "+
                                    f"Avg Acc1:{acc1_val.avg:.4f}, "+
                                    f"Avg Acc5:{acc5_val.avg:.4f}")
            end_time = time.time()

            return acc1_val.avg, acc5_val.avg, loss_val.avg, end_time-start_time
上一篇:操作sql - 类型初始值设定项引发异常


下一篇:webpack 支持的模块方法