将YOLOv3及以上的网络中的BCE loss更改为Focal loss
loss函数分为三部分,位置损失、置信度损失、类别损失,此处只需要将置信度损失更换为Focal loss,具体原理请仔细理解置信度损失的含义。
YOLOX链接:https://link.zhihu.com/?target=https%3A//github.com/Megvii-BaseDetection/YOLOX
1 找到置信度预测损失计算位置loss_obj,并进行替换(位置在386-405行左右)
loss_iou:定位损失;loss_obj:置信度预测损失;loss_cls:预测损失
loss_iou = (
self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
).sum() / num_fg
#loss_obj = (
# self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
#).sum() / num_fg
loss_obj = (
self.focal_loss(obj_preds.sigmoid().view(-1, 1), obj_targets)
).sum() / num_fg
loss_cls = (
self.bcewithlog_loss(
cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
)
).sum() / num_fg
2 创建focal_loss方法,放到def get_l1_target(…)之前即可,代码如下:
def focal_loss(self, pred, gt):
pos_inds = gt.eq(1).float()
neg_inds = gt.eq(0).float()
pos_loss = torch.log(pred+1e-5) * torch.pow(1 - pred, 2) * pos_inds * 0.75
neg_loss = torch.log(1 - pred+1e-5) * torch.pow(pred, 2) * neg_inds * 0.25
loss = -(pos_loss + neg_loss)
return loss