文章内容:如何在YOLOX官网代码中添加ASFF模块
环境:pytorch1.8
修改内容:
(1)在PAFPN尾部添加ASFF模块(YOLOX-s等版本)
(2)在FPN尾部添加ASFF模块(YOLOX-Darknet53版本)
参考链接:
论文链接:https://arxiv.org/pdf/1911.09516v2.pdf
ASFF原理及代码参考:https://blog.csdn.net/weixin_44119362/article/details/114289607
示意图如下:
使用方法:直接在PAFPN或FPN尾部添加即可(PAFPN版本通过multiplier=width自动匹配各层通道数,无需手动修改;Darknet53版本的通道数固定为512/256/128)
代码修改过程:
1、在YOLOX-s版本的PAFPN后添加ASFF模块
(注意:该版本的ASFF适用于YOLOv5风格的PAFPN,不能用于YOLOv3的FPN)
步骤一:在YOLOX-main/yolox/models文件夹下创建ASFF.py文件,内容如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
def autopad(k, p=None):  # kernel, padding
    """Return the padding to use for kernel size ``k``.

    An explicit ``p`` is returned unchanged. Otherwise the padding is half the
    kernel size (floor division): an int for an int kernel, or a list with one
    entry per element when ``k`` is a sequence.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [side // 2 for side in k]
class Conv(nn.Module):
    """Convolution block: ``Conv2d`` (no bias) -> ``BatchNorm2d`` -> activation."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        """Build the block.

        Args:
            c1: input channels.
            c2: output channels.
            k: kernel size.
            s: stride.
            p: padding; when None it is derived from ``k`` via ``autopad``.
            g: convolution groups.
            act: True selects ``nn.SiLU``; an ``nn.Module`` instance is used
                as given; anything else disables the activation.
        """
        super().__init__()
        padding = autopad(k, p)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and activation, in that order."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def forward_fuse(self, x):
        """Apply convolution and activation only, skipping batch norm."""
        y = self.conv(x)
        return self.act(y)
class ASFF(nn.Module):
    """Adaptively Spatial Feature Fusion (ASFF) for a three-level PAFPN.

    One instance produces the fused feature map of a single pyramid level:
    the other two levels are resized to that level's resolution and channel
    count, and the three maps are blended with per-pixel softmax weights.

    Level 0 is the deepest level (most channels, smallest feature map);
    level 2 is the shallowest (fewest channels, largest feature map).
    """

    def __init__(self, level, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """Build the resize, weight and output convolutions for ``level``.

        Args:
            level: index of the pyramid level this module outputs (0, 1 or 2).
            multiplier: width multiplier applied to the base channel counts
                (1024, 512, 256); e.g. 0.5 gives (512, 256, 128).
            rfb: when True the weight branches use 8 channels instead of 16
                to save memory.
            vis: when True ``forward`` also returns the fusion weights and
                the channel-summed fused map.
            act_cfg: unused; kept so existing callers keep working.

        Raises:
            ValueError: if ``level`` is not 0, 1 or 2.
        """
        # Fail fast: a negative level would silently wrap around when indexing
        # ``self.dim`` and match no branch below, leaving a module that only
        # breaks later inside ``forward``.
        if level not in (0, 1, 2):
            raise ValueError(f"ASFF level must be 0, 1 or 2, got {level!r}")
        super().__init__()
        self.level = level
        self.dim = [int(1024 * multiplier), int(512 * multiplier),
                    int(256 * multiplier)]
        self.inter_dim = self.dim[self.level]

        # Layers that bring the two *other* levels to this level's channel
        # count (and, for shallower inputs, halve their resolution).
        if level == 0:
            self.stride_level_1 = Conv(self.dim[1], self.inter_dim, 3, 2)
            self.stride_level_2 = Conv(self.dim[2], self.inter_dim, 3, 2)
        elif level == 1:
            self.compress_level_0 = Conv(self.dim[0], self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(self.dim[2], self.inter_dim, 3, 2)
        else:
            self.compress_level_0 = Conv(self.dim[0], self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(self.dim[1], self.inter_dim, 1, 1)
        # Output conv keeps this level's channel count unchanged.
        self.expand = Conv(self.inter_dim, self.dim[level], 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)
        # Maps the concatenated weight features to one logit per level.
        self.weight_levels = Conv(compress_c * 3, 3, 1, 1)
        self.vis = vis

    def forward(self, x):  # l, m, s
        """Fuse the three pyramid levels into this module's output level.

        Args:
            x: indexable of three feature maps ordered shallow to deep:
                ``x[0]`` is the shallowest level (256 * multiplier channels)
                and ``x[2]`` the deepest (1024 * multiplier channels).

        Returns:
            The fused feature map, or when ``self.vis`` is true the tuple
            ``(out, levels_weight, fused_out_reduced.sum(dim=1))``.
        """
        x_level_0 = x[2]  # deepest level: most channels, smallest map
        x_level_1 = x[1]  # middle level
        x_level_2 = x[0]  # shallowest level: fewest channels, largest map

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            # Level 2 is downsampled twice: max-pool first, then the
            # stride-2 convolution.
            level_2_downsampled_inter = F.max_pool2d(
                x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=4, mode='nearest')
            x_level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(
                x_level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2

        # One weight map per level, normalised across levels at every pixel.
        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat(
            (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        # Weighted sum; each single-channel weight broadcasts over channels.
        fused_out_reduced = (
            level_0_resized * levels_weight[:, 0:1, :, :]
            + level_1_resized * levels_weight[:, 1:2, :, :]
            + level_2_resized * levels_weight[:, 2:, :, :]
        )
        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        return out
步骤二:在YOLOX-main/yolox/models/yolo_pafpn.py中调用ASFF模块
(1)导入
from .ASFF import ASFF
(2)在init中实例化
# 2. Instantiate one ASFF module per output level; `width` is passed as the channel multiplier
self.asff_1 = ASFF(level = 0, multiplier = width)
self.asff_2 = ASFF(level = 1, multiplier = width)
self.asff_3 = ASFF(level = 2, multiplier = width)
def forward(self, input):
(3)直接在PAFPN输出outputs后接上ASFF模块
outputs = (pan_out2, pan_out1, pan_out0)
# ASFF: each module receives the whole PAFPN output tuple and returns the fused map for its own level
pan_out0 = self.asff_1(outputs)
pan_out1 = self.asff_2(outputs)
pan_out2 = self.asff_3(outputs)
# Rebuild the output tuple in the same order, now holding the fused maps
outputs = (pan_out2, pan_out1, pan_out0)
return outputs
2、在YOLOX-Darknet53的FPN后添加ASFF模块
(注意:这里是用于YOLOv3的FPN)
步骤一:在YOLOX-main/yolox/models文件夹下创建ASFF_darknet.py文件,内容如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
from .network_blocks import BaseConv
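# Input: the three feature maps produced by the FPN.
# Output: the feature map passed to the head, with the same shape as before. A single ASFF instance outputs one level only; level=0 is the deepest level (512 channels, e.g. 512*20*20), and spatial size grows as level_0 < level_1 < level_2.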
class ASFF(nn.Module):
    """Adaptively Spatial Feature Fusion (ASFF) for the three-level Darknet53 FPN.

    One instance produces the fused feature map of a single pyramid level:
    the other two levels are resized to that level's resolution and channel
    count, and the three maps are blended with per-pixel softmax weights.

    Channel counts are fixed at 512 / 256 / 128 for levels 0 / 1 / 2. Level 0
    is the deepest level (smallest feature map), level 2 the shallowest.
    """

    def __init__(self, level, rfb=False, vis=False):
        """Build the resize, weight and output convolutions for ``level``.

        Args:
            level: index of the pyramid level this module outputs (0, 1 or 2).
            rfb: when True the weight branches use 8 channels instead of 16
                to save memory.
            vis: when True ``forward`` also returns the fusion weights and
                the channel-summed fused map.

        Raises:
            ValueError: if ``level`` is not 0, 1 or 2.
        """
        # Fail fast: a negative level would silently wrap around when indexing
        # ``self.dim`` and match no branch below, leaving a module that only
        # breaks later inside ``forward``.
        if level not in (0, 1, 2):
            raise ValueError(f"ASFF level must be 0, 1 or 2, got {level!r}")
        super().__init__()
        self.level = level
        self.dim = [512, 256, 128]
        self.inter_dim = self.dim[self.level]

        # Layers that bring the two *other* levels to this level's channel
        # count (and, for shallower inputs, halve their resolution).
        if level == 0:
            self.stride_level_1 = self._make_cbl(256, self.inter_dim, 3, 2)
            self.stride_level_2 = self._make_cbl(128, self.inter_dim, 3, 2)
        elif level == 1:
            self.compress_level_0 = self._make_cbl(512, self.inter_dim, 1, 1)
            self.stride_level_2 = self._make_cbl(128, self.inter_dim, 3, 2)
        else:
            self.compress_level_0 = self._make_cbl(512, self.inter_dim, 1, 1)
            self.compress_level_1 = self._make_cbl(256, self.inter_dim, 1, 1)
        # The output handed to the head keeps this level's channel count
        # (512->512, 256->256 or 128->128).
        self.expand = self._make_cbl(self.inter_dim, self.dim[level], 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = self._make_cbl(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = self._make_cbl(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = self._make_cbl(self.inter_dim, compress_c, 1, 1)
        # Plain 1x1 convolution producing one logit per level.
        self.weight_levels = nn.Conv2d(
            compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.vis = vis

    def _make_cbl(self, _in, _out, ks, stride):
        """Return a ``BaseConv`` block with "lrelu" activation."""
        return BaseConv(_in, _out, ks, stride, act="lrelu")

    def forward(self, x_level_0, x_level_1, x_level_2):
        """Fuse the three pyramid levels into this module's output level.

        Args:
            x_level_0: deepest feature map, 512 channels (e.g. 512*20*20).
            x_level_1: middle feature map, 256 channels (e.g. 256*40*40).
            x_level_2: shallowest feature map, 128 channels (e.g. 128*80*80).

        Returns:
            The fused feature map, or when ``self.vis`` is true the tuple
            ``(out, levels_weight, fused_out_reduced.sum(dim=1))``.
        """
        if self.level == 0:
            level_0_resized = x_level_0
            # 256 channels -> 512, resolution halved by the stride-2 conv
            level_1_resized = self.stride_level_1(x_level_1)
            # Level 2 is downsampled twice: max-pool first, then the
            # stride-2 convolution (128 channels -> 512).
            level_2_downsampled_inter = F.max_pool2d(
                x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            # 512 channels -> 256, then 2x nearest upsampling
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            # 128 channels -> 256, resolution halved
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            # 512 channels -> 128, then 4x nearest upsampling
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=4, mode='nearest')
            # 256 channels -> 128, then 2x nearest upsampling
            level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(
                level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2

        # One weight map per level, normalised across levels at every pixel.
        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat(
            (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        # Weighted sum; each single-channel weight broadcasts over channels.
        fused_out_reduced = (
            level_0_resized * levels_weight[:, 0:1, :, :]
            + level_1_resized * levels_weight[:, 1:2, :, :]
            + level_2_resized * levels_weight[:, 2:, :, :]
        )
        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        return out
步骤二:在YOLOX-main/yolox/models/yolo_fpn.py中调用ASFF模块
(1)导入
from .ASFF_darknet import ASFF ### 1、导入
(2)实例化ASFF对象
# ---- 2. Instantiate one ASFF module per FPN output level (level 0 = deepest) ----
self.assf_5 = ASFF(level = 0)
self.assf_4 = ASFF(level = 1)
self.assf_3 = ASFF(level = 2)
# ---- end of ASFF instantiation ----
def _make_cbl(self, _in, _out, ks):
    """Return a stride-1 ``BaseConv`` block with "lrelu" activation."""
    block = BaseConv(_in, _out, ks, stride=1, act="lrelu")
    return block
(3)在outputs后直接添加asff
outputs = (out_dark3, out_dark4, x0) # feature-map size decreases left to right (128, 256, 512 channels); this is the original FPN output - comment it out when using ASFF
# ------------------------------------------------
# 3. Apply ASFF to the FPN feature pyramid; the original FPN `outputs` above is replaced
out_assf_5 = self.assf_5(x0, out_dark4, out_dark3)
out_assf_4 = self.assf_4(x0, out_dark4, out_dark3)
out_assf_3 = self.assf_3(x0, out_dark4, out_dark3)
outputs = (out_assf_3, out_assf_4, out_assf_5)
# ------------------------------------------------
return outputs
效果:根据个人数据集而定。对我的数据集没变化。
权重大小变化:yoloxs(68.8M->110M)
速度变化:有所下降
上述代码链接:
链接:https://pan.baidu.com/s/1ykfb-YHpJaLj4sQpMsCIKw
提取码:qrvg