极简CenterNet(二)核心代码

  本节给出网络结构、损失函数、训练和验证部分等主要代码,并使用几种简单数据集进行了训练验证。

1,resnet.py

  原论文中centernet的主网络部分分别使用了hourglass,DLA,resnet三种网络,其中resnet是最简单的,我们的极简代码当然先从resnet18结构入手。
  代码见https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45/blob/master/nets/resnet.py,原封不动,我就不贴出来了。这是一个简单的网络结构:输入图像是B x 3 x 512 x 512的(1)然后在resnet基础上去掉最后的全连接层,经过layer1~4之后得到的特征图尺寸B x 512 x 16 x 16;(2)然后连接上三层反卷积层构成上采样层,使特征图上采样到B x 256 x 128 x 128的特征图;(3)然后连接三个分支,分别输出heatmap(热点图),regs(中心点偏移量),w_h_(宽高)。每个分支都是一个两层卷积结构,其中heatmap分支的输出是B x C x 128 x 128,C表示num_classes即检测的目标类别数,例如coco是80,它的每样本每通道的128 x 128图可以理解为对应类别在每个像素上的置信度;regs分支的输出是B x 2 x 128 x 128,表示预测的中心点的x方向和y方向偏移量,由于我们对原图片进行了4倍的缩小,所以再取整后会造成截断误差,这个regs就是为了补偿这个截断误差的,不是特别重要,不要它也不影响多少精度;w_h_分支的输出也是B x 2 x 128 x 128,表示检测框的宽高,这个当然是很重要的,再具体解释一下更好理解,其中每个128 x 128图中的点的两个通道的数值可以理解为“假设该点是目标中心时的检测框宽高”,至于这个点到底是不是真的目标中心,则由heatmap中该点的置信度来确定。
  懒得画图了,网上介绍这个原理的图很多,下面借用https://www.jianshu.com/p/d5d7cd7ad200上的一张,看看就明白了:
极简CenterNet(二)核心代码

图1. centernet主网络结构示意图,转自https://www.jianshu.com/p/d5d7cd7ad200

2, train.py

我在https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45基础上简化修改,使它更简洁:

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from resnet import get_pose_net
from data_loader import CustomizeDataset
import cv2
import torch.nn.functional as F
from utils.utils import _tranpose_and_gather_feature
from utils.post_process import ctdet_decode, _nms
import matplotlib.pyplot as plt
import time
t0 = time.time()

train_dataset = CustomizeDataset(mode='train',num_classes=2)
val_dataset = CustomizeDataset(mode='val',num_classes=2)
kwargs = {"num_workers": 0, "pin_memory": True}
train_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=20, **kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=1, **kwargs)

def _neg_loss(preds, targets):
  pos_inds = targets.eq(1).float()
  neg_inds = targets.lt(1).float()

  neg_weights = torch.pow(1 - targets, 4)
  preds = torch.clamp(preds, min=1e-4, max=1 - 1e-4)
  pos_loss = torch.log(preds) * torch.pow(1 - preds, 2) * pos_inds
  neg_loss = torch.log(1 - preds) * torch.pow(preds, 2) * neg_weights * neg_inds

  num_pos = pos_inds.float().sum()
  pos_loss = pos_loss.sum()
  neg_loss = neg_loss.sum()
  loss = - (pos_loss + neg_loss) / num_pos
  return loss / len(preds)

def _reg_loss(regs, gt_regs, mask):
    mask = mask[:, :, None].expand_as(gt_regs).float()
    loss = sum(F.l1_loss(r * mask, gt_regs * mask, reduction='sum') / (mask.sum() + 1e-4) for r in regs)
    return loss / len(regs)

net = get_pose_net(num_layers=18, head_conv=64, num_classes=2)
net = net.cuda()
net.train()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)
losses_record = []
for epoch in range(10):
    for idx,data in enumerate(train_loader):
        img, heatmap, labels, gt_regs, gt_wh, inds, masks, bbox = data
        img, heatmap, gt_regs, gt_wh, inds, masks = \
            img.cuda(), heatmap.cuda(), gt_regs.cuda(), gt_wh.cuda(), inds.cuda(), masks.cuda()
        hmap, regs, wh = net(img)[0]
        hmap = torch.sigmoid(hmap)
        hmap_loss = _neg_loss(hmap, heatmap)
        regs = _tranpose_and_gather_feature(regs, inds)
        wh = _tranpose_and_gather_feature(wh, inds)
        reg_loss = _reg_loss(regs, gt_regs, masks)
        w_h_loss = _reg_loss(wh, gt_wh, masks)
        loss = 10*hmap_loss + 1 * reg_loss + 0.1 * w_h_loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(idx,'hmap_loss:%.4f, reg_loss:%.4f, w_h_loss:%.4f, loss:%.4f'%(
            hmap_loss.item(), reg_loss.item(),w_h_loss.item(),loss.item()),'time:%.1f'%(time.time()-t0))
        losses_record.append([hmap_loss.item(), reg_loss.item(),w_h_loss.item(),loss.item()])
        if idx%10==0:
            plt.figure();plt.imshow(_nms(hmap)[0,0].data.cpu().numpy())
            plt.figure();plt.imshow(heatmap[0,0].data.cpu().numpy())
            plt.figure();plt.imshow(_nms(hmap)[0,1].data.cpu().numpy())
            plt.figure();plt.imshow(heatmap[0,1].data.cpu().numpy())
losses_record = np.array(losses_record)
plt.figure();plt.semilogy(losses_record);plt.legend(['hmap loss','regs loss','wh loss','total loss'])

def IOU(box1,box2):
    xA = max(box1[0], box2[0])
    yA = max(box1[1], box2[1])
    xB = min(box1[2], box2[2])
    yB = min(box1[3], box2[3])
    
    interArea = max(0,(xB - xA + 1)) * max(0,(yB - yA + 1))
    box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
    box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
    iou = interArea / float(box1Area + box2Area - interArea)
    return iou
    
net.eval()
ious = []
with torch.no_grad():
    for idx,data in enumerate(val_loader):
        img, heatmap, labels, gt_regs, gt_wh, inds, masks, bbox = data
        img, heatmap, gt_regs, gt_wh, inds, masks = \
            img.cuda(), heatmap.cuda(), gt_regs.cuda(), gt_wh.cuda(), inds.cuda(), masks.cuda()
        hmap, regs, wh  = net(img)[0]

        dets = ctdet_decode(hmap, regs, wh )
        dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])[0]
        
        image = img[0].permute(1,2,0).data.cpu().numpy()
        image = (image/2 + 0.5)*255
        image = image.astype('uint8')[:,:,::-1]
        image = image.copy()
        
        bbox = bbox[0].data.cpu().numpy()
        labels = labels[0].data.cpu().numpy()
        pred_box = []
        for label in range(heatmap.shape[1]):
            bbox_l = bbox[labels==label].copy()
            det_l = dets[dets[:,5]==label].copy()
            box_num = bbox_l.shape[0]
            for n in range(box_num):
                det = det_l[n]
                det[4] = det[4]*100
                det[:4] = det[:4]*4
                det = det.round().astype('int')
                pred_box.append(det[:4])
                box = bbox_l[n]
                image = cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 220, 0), 2)
                image = cv2.rectangle(image, (det[0], det[1]), (det[2], det[3]), (0, 0, 220), 2)
                iou = IOU(box,det[:4])   #TODO:此计算方法对每类多目标时不正确,待后续修改
                ious.append(iou)
        cv2.imwrite('fig2/%06d.jpg'%(999-idx),image)        
mean_iou = np.mean(ious)
print('mean iou: %.4f'%mean_iou)

  其他一些utils中的函数也请参见https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45,不做修改。
  这其中关于heatmap loss,也就是代码中的_neg_loss函数需要说明一下:这个损失是centernet的关键,开始不是很容易理解,实际上它就是对focal loss的再升级,在focal loss的基础上加上了对标注框热点图附近的衰减。但需要注意的是它的目的不是引导preds和targets一致(这和通常的损失函数不一样),它的目的是使preds趋向于单个中心点为1,其他点为0的输出图。不信可以试一下,如果preds和targets都是heatmap时,heatmap loss并不是0,只有当preds是中心单点为1其他为0,而targets是heatmap时,heatmap loss才为0。个人认为这是一个大坎,理解了这一点之后就很容易理解别的了。

3,效果

  我们先用几种简单的数据集来检验,训练集都用800张图片,验证集用200张图片。由于数据非常简单,我们不用mAP指标(将会太高),我们用mIOU指标来验证效果。其中单个五星数据集和单个正方+单个四芒星数据集验证结果中的检测框(红色)和标注框(绿色)的情况如下图(也有个别IOU低一些的这里没有画)。
极简CenterNet(二)核心代码

图2.两种自定义数据集的标注框(绿色)和检测框(红色)示例

  具体试验数据对比见下表:

数据集 lr 训练速度 mIOU 备注
单个正方形,不旋转 1e-3 42fps 0.8740
单个正方形,随机旋转 1e-3 42fps 0.0000 不收敛
单个正方形,随机旋转 1e-4 42fps 0.9525
单个五角星,随机旋转 1e-4 42fps 0.9750
单个五角星+单个五边形,随机旋转 1e-4 40fps 0.9638
两个五角星,随机旋转 1e-4 40fps 0.9668

极简CenterNet(二)核心代码

图3.几种情况下损失函数各部分随训练步数(每步20个样本)的变化

从损失函数变化图中我们可以看出:

  • regs loss只在训练初期会下降,后期一直不再变化,说明中心点偏移损失regs loss确实起点作用,但作用不大;
  • hmap loss和wh loss表现出轮动的现象:hmap loss初期就开始下降,但wh loss初期不动,当hmap loss下降到一定程度后,也就是说当中心点找的比较靠谱后wh loss才开始下降;而wh loss后期基本没法再优化后,wh loss还会持续下降,这是因为此时输出的heatmap图开始继续提高中心点的聚焦度,并提高中心点的置信度,最终目的是趋向于中心点1,而其他点0。
  • 对于单正方形旋转数据集与不旋转数据集相比,wh损失会高很多,这是因为我们的标注框是根据外接矩形标注的,而外接矩形这个几何概念网络较难掌握,对于一个倾斜的正方形,网络计算它的外接矩形会比较困难。
  • 学习率很重要,过大学习率会导致heatmap损失难以继续下降,可能是输出的hmap已经和输入的heatmap接近了,此时必须用更小的学习率才能学习。

通过以上分析,其实我们可以得出几种网络的改进思路:

  • regs loss初期可以不用,最后的这个网络分支都可以锁死不用,到最后几轮再打开训练一下即可,这样可以节省训练时间。
  • heatmap在训练初期可使用,后期可以去掉高斯圆,就只用中心单点,效果应该更好一些。

关于heatmap的更多讨论,以及其他一些讨论,请见下节。
这几天腰疼病犯了,没法久坐,所以匆匆写了写,文字没有润色修剪,可能说的不是很好懂,各位将就看看吧。

上一篇:matlab画热力网格图


下一篇:mapbox热力图属性