def accuracy(output, label, topk=(1,)):
    """Compute top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: 2-D tensor of scores; dim 0 indexes the samples of the batch
            and dim 1 indexes the classes (the top-k search runs over dim 1).
        label: Tensor holding one target class index per sample. It is
            flattened to a single row before being compared, so both ``(N,)``
            and ``(N, 1)`` layouts work.
        topk: Iterable of k values, e.g. ``(1, 5)`` to get top-1 and top-5
            accuracy in one call.

    Returns:
        A list with one 1-element float tensor per entry of ``topk`` (same
        order). Each holds the percentage of samples whose label is among the
        k highest-scoring classes. An empty batch yields 0 rather than NaN.
    """
    batch_size = output.size(0)
    # torch.topk raises when k exceeds the size of the searched dimension, so
    # cap k at the number of classes (top-5 on a 3-class problem is top-3).
    maxk = min(max(topk), output.size(1))

    # Indices of the maxk highest scores per sample, best first. Only the
    # largest requested k is computed; smaller k reuse a prefix of it.
    _, pred = torch.topk(output, k=maxk, dim=1, largest=True, sorted=True)
    # Transpose to (maxk, batch) so that row i holds the i-th best guess of
    # every sample, which makes "the first k guesses" a plain row slice.
    pred = pred.T

    # Broadcast the labels to the shape of pred and compare element-wise.
    # Example: label = [1, 2, 3, 4] and maxk = 5 expands to
    #   [[1, 2, 3, 4],
    #    [1, 2, 3, 4],
    #    [1, 2, 3, 4],
    #    [1, 2, 3, 4],
    #    [1, 2, 3, 4]]
    correct = torch.eq(pred, label.reshape(1, -1).expand_as(pred))

    # Guard the division so an empty batch gives 0 instead of NaN.
    denom = max(batch_size, 1)
    res = []
    for k in topk:
        # Count the samples whose label appears among their first k guesses.
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k * 100 / denom)
    return res
# With topk=(1, 5), accuracy() returns both acc1 and acc5 in a single call.
# Example code showing its use during validation:
class AverageMeter:
    """Track the most recent value and the weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0    # most recent value passed to update()
        self.sum = 0    # sum of value * weight over all updates
        self.avg = 0    # sum / count, or 0 while count is 0
        self.count = 0  # total weight (e.g. number of samples) seen so far

    def update(self, val, n=1):
        """Fold a new observation into the running average.

        Args:
            val: The observed value; anything ``float()`` accepts (a Python
                number or a one-element tensor).
            n: Weight of the observation, typically the number of samples the
                value was averaged over. Defaults to 1.
        """
        val = float(val)
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been seen.
        self.avg = self.sum / self.count if self.count else 0
def validation(val_dataloader,
               num_batch_val,
               criterion,
               model,
               device,
               total_epochs,
               logger,
               debug_step=100):
    """Run one evaluation pass over ``val_dataloader`` and return averaged metrics.

    Args:
        val_dataloader: Iterable of batches; element 0 of each batch is the
            input tensor and element 1 is the label tensor.
        num_batch_val: Batch count shown in the progress log and used to
            detect the final step. Presumably ``len(val_dataloader)`` --
            confirm against the caller.
        criterion: Callable ``criterion(output, label)`` returning a scalar
            loss tensor.
        model: Network to evaluate. It is switched to eval mode and is fed
            the inputs flattened to ``(batch, -1)``.
        device: Device the inputs and labels are moved to.
        total_epochs: Not used by this function; kept so existing callers
            keep working.
        logger: Object exposing ``info(msg)``.
        debug_step: Log progress every ``debug_step`` batches.

    Returns:
        Tuple ``(avg_top1_acc, avg_top5_acc, avg_loss, elapsed_seconds)``;
        the averages are weighted by batch size.
    """
    model.eval()
    acc1_val = AverageMeter()
    acc5_val = AverageMeter()
    loss_val = AverageMeter()
    start_time = time.time()
    with torch.no_grad():
        for batch_id, data in enumerate(val_dataloader):
            # Plain tensors suffice: the Variable wrapper is deprecated and
            # gradients are already disabled by the no_grad() context.
            image = data[0].to(device)
            label = data[1].to(device)
            # Collapse every non-batch dimension into one feature vector.
            image = image.flatten(1)

            output = model(image)
            loss = criterion(output, label)
            acc1, acc5 = accuracy(output=output, label=label, topk=(1, 5))

            n = image.size(0)
            loss_val.update(loss.item(), n)
            acc1_val.update(acc1.item(), n)
            acc5_val.update(acc5.item(), n)

            # batch_id is 0-based, so the last batch is num_batch_val - 1;
            # comparing against num_batch_val itself would never log it.
            is_last = batch_id + 1 >= num_batch_val
            if batch_id % debug_step == 0 or is_last:
                logger.info(f"Val Step:[{batch_id:03d}/{num_batch_val:03d}], "
                            f"Avg Loss:{loss_val.avg:.4f}, "
                            f"Avg Acc1:{acc1_val.avg:.4f}, "
                            f"Avg Acc5:{acc5_val.avg:.4f}")
    end_time = time.time()
    return acc1_val.avg, acc5_val.avg, loss_val.avg, end_time - start_time