文章内容:如何在YOLOX官网代码中修改–定位损失
环境:pytorch1.8
修改内容:
(1)置信度预测损失更换:二元交叉熵损失替换为FocalLoss或者VariFocalLoss
(2)定位损失更换:IOU损失替换为GIOU、CIOU、EIOU以及α-IOU(Alpha-IoU,下文简写为a-IOU)系列
提示:使用之前可以先了解YOLOX及上述损失函数原理
参考链接:
YOLOX官网链接:https://github.com/Megvii-BaseDetection/YOLOX
YOLOX原理解析(Bubbliiiing大佬版):https://blog.csdn.net/weixin_44791964/article/details/120476949
FocalLoss损失解析:https://cyqhn.blog.csdn.net/article/details/87343004
VariFocalLoss损失解析:https://blog.csdn.net/weixin_42096202/article/details/108567189
GIOU、CIOU、EIOU等:https://blog.csdn.net/neil3611244/article/details/113794197
a-IOU:https://blog.csdn.net/wjytbest/article/details/121513560
使用方法:直接替换即可
代码修改过程:
1、IOUloss等其他系列更改
修改位置:只需要在YOLOX-main/yolox/models/losses.py中更改IOUloss的loss_type参数,如loss_type="ciou"(可选值:"iou"、"giou"、"ciou"、"eiou")
注意:没有DIOU与focal_EIOU(出现精度大降,直至为0,所以先删除了,后续补上!)
class IOUloss(nn.Module):
    """IoU-family box regression loss.

    Boxes are given as ``(x_center, y_center, w, h)``; both inputs are
    flattened to ``(-1, 4)`` before the loss is computed.

    Args:
        reduction: ``"mean"`` or ``"sum"`` reduce the per-box losses to a
            scalar; any other value (default ``"none"``) returns one loss
            value per box.
        loss_type: one of ``"iou"``, ``"giou"``, ``"ciou"`` or ``"eiou"``.
    """

    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """Return the localisation loss between ``pred`` and ``target``.

        Args:
            pred: predicted boxes, reshaped to ``(-1, 4)``.
            target: ground-truth boxes, reshaped to ``(-1, 4)``.

        Returns:
            A 1-D tensor with one loss per box, or a scalar tensor when
            ``reduction`` is ``"mean"`` or ``"sum"``.

        Raises:
            AssertionError: if the leading dimensions of the inputs differ.
            ValueError: if ``loss_type`` is not a supported variant.
        """
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Convert centre/size boxes to corner coordinates once.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Corners of the intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # `en` is 1 only when the boxes overlap on both axes, 0 otherwise.
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en  # intersection area
        area_u = area_p + area_g - area_i  # union area
        iou = area_i / (area_u + 1e-16)

        if self.loss_type == "iou":
            # NOTE: the IoU is squared here, i.e. this is not the plain
            # 1 - IoU loss (it resembles alpha-IoU with alpha = 2).
            loss = 1 - iou ** 2
        elif self.loss_type in ("giou", "ciou", "eiou"):
            # Corners of the smallest box enclosing both boxes.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)

            if self.loss_type == "giou":
                area_c = torch.prod(c_br - c_tl, 1)
                giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
                loss = 1 - giou.clamp(min=-1.0, max=1.0)
            else:
                # Squared width / height of the enclosing box; their sum is
                # the squared length of its diagonal.
                c_w2 = torch.pow(c_br[:, 0] - c_tl[:, 0], 2)
                c_h2 = torch.pow(c_br[:, 1] - c_tl[:, 1], 2)
                convex_dis = c_w2 + c_h2
                # Squared distance between the two box centres.
                center_dis = (
                    torch.pow(pred[:, 0] - target[:, 0], 2)
                    + torch.pow(pred[:, 1] - target[:, 1], 2)
                )

                if self.loss_type == "ciou":
                    # Aspect-ratio consistency term.
                    v = (4 / math.pi ** 2) * torch.pow(
                        torch.atan(target[:, 2] / torch.clamp(target[:, 3], min=1e-6))
                        - torch.atan(pred[:, 2] / torch.clamp(pred[:, 3], min=1e-6)),
                        2,
                    )
                    with torch.no_grad():
                        # The epsilon avoids 0/0 (NaN) when the boxes match
                        # exactly (iou == 1 and v == 0).
                        alpha = v / (v - iou + 1 + 1e-16)
                    ciou = iou - (center_dis / convex_dis.clamp(1e-16) + alpha * v)
                    loss = 1 - ciou.clamp(min=-1.0, max=1.0)
                else:  # "eiou"
                    dis_w = torch.pow(pred[:, 2] - target[:, 2], 2)  # squared width gap
                    dis_h = torch.pow(pred[:, 3] - target[:, 3], 2)  # squared height gap
                    eiou = (
                        iou
                        - (center_dis / convex_dis.clamp(1e-16))
                        - (dis_w / c_w2.clamp(1e-16))
                        - (dis_h / c_h2.clamp(1e-16))
                    )
                    loss = 1 - eiou.clamp(min=-1.0, max=1.0)
        else:
            raise ValueError(
                "Unsupported loss_type: {!r}".format(self.loss_type)
            )

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
2、a-IOU系列修改(其实a-iou就是在iou的基础上,对IOU项及各惩罚项分别取a次方)
注意:为了不影响原代码,重新在losses.py中续写一个类似IOUloss的alpha_IOUloss
步骤一:在losses.py中续写一个alpha_IOUloss的类(这里只有IOU+GIOU+CIOU,可以参考IOUloss自行补上)
class alpha_IOUloss(nn.Module):
    """Alpha-IoU box regression loss (IoU / GIoU / CIoU variants).

    Same computation as the plain IoU-family losses, except that the IoU
    term and each penalty term are raised to the power ``alpha``. Boxes are
    given as ``(x_center, y_center, w, h)`` and flattened to ``(-1, 4)``.

    Args:
        reduction: ``"mean"`` or ``"sum"`` reduce the per-box losses to a
            scalar; any other value (default ``"none"``) returns one loss
            value per box.
        loss_type: one of ``"iou"``, ``"giou"`` or ``"ciou"``.
        alpha: exponent applied to the IoU and penalty terms.
    """

    def __init__(self, reduction="none", loss_type="ciou", alpha=3):
        super(alpha_IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type
        self.alpha = alpha

    def forward(self, pred, target):
        """Return the alpha-IoU loss between ``pred`` and ``target``.

        Args:
            pred: predicted boxes, reshaped to ``(-1, 4)``.
            target: ground-truth boxes, reshaped to ``(-1, 4)``.

        Returns:
            A 1-D tensor with one loss per box, or a scalar tensor when
            ``reduction`` is ``"mean"`` or ``"sum"``.

        Raises:
            AssertionError: if the leading dimensions of the inputs differ.
            ValueError: if ``loss_type`` is not a supported variant.
        """
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Convert centre/size boxes to corner coordinates once.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Corners of the intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # `en` is 1 only when the boxes overlap on both axes, 0 otherwise.
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en  # intersection area
        area_u = area_p + area_g - area_i  # union area
        iou = area_i / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** self.alpha
        elif self.loss_type in ("giou", "ciou"):
            # Corners of the smallest box enclosing both boxes.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)

            if self.loss_type == "giou":
                area_c = torch.prod(c_br - c_tl, 1)
                giou = (
                    iou ** self.alpha
                    - ((area_c - area_u) / area_c.clamp(1e-16)) ** self.alpha
                )
                loss = 1 - giou.clamp(min=-1.0, max=1.0)
            else:  # "ciou"
                # Squared diagonal of the enclosing box (epsilon keeps the
                # later division finite).
                convex_dis = (
                    torch.pow(c_br[:, 0] - c_tl[:, 0], 2)
                    + torch.pow(c_br[:, 1] - c_tl[:, 1], 2)
                    + 1e-16
                )
                # Squared distance between the two box centres.
                center_dis = (
                    torch.pow(pred[:, 0] - target[:, 0], 2)
                    + torch.pow(pred[:, 1] - target[:, 1], 2)
                )
                # Aspect-ratio consistency term.
                v = (4 / math.pi ** 2) * torch.pow(
                    torch.atan(target[:, 2] / torch.clamp(target[:, 3], min=1e-6))
                    - torch.atan(pred[:, 2] / torch.clamp(pred[:, 3], min=1e-6)),
                    2,
                )
                with torch.no_grad():
                    # The epsilon avoids 0/0 (NaN) when the boxes match
                    # exactly (iou == 1 and v == 0).
                    beta = v / (v - iou + 1 + 1e-16)
                ciou = iou ** self.alpha - (
                    center_dis ** self.alpha / convex_dis ** self.alpha
                    + (beta * v) ** self.alpha
                )
                loss = 1 - ciou.clamp(min=-1.0, max=1.0)
        else:
            raise ValueError(
                "Unsupported loss_type: {!r}".format(self.loss_type)
            )

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
步骤二:在yolo_head.py中追加导入alpha_IOUloss
from .losses import IOUloss, alpha_IOUloss
步骤三:实例化alpha_IOUloss,并注释之前的IOUloss
# Previous localisation loss, kept commented out for reference:
# self.iou_loss = IOUloss(reduction="none")
# Use the alpha-IoU loss instead (its defaults are loss_type="ciou", alpha=3).
self.iou_loss = alpha_IOUloss(reduction="none")
效果:根据个人数据集而定。不同场景不同效果,对我的数据集基本上都没有原来的好,但也差距不大,降了一点点。
以上代码链接:
链接:https://pan.baidu.com/s/1XQAvT2VdtMMLoUo4FY9v8A
提取码:hd9f