细粒度:MC_Loss源码笔记——The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification
综述
论文题目:《The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification》
期刊与时间:IEEE Transactions on Image Processing 2020 (TIP 2020)
论文地址:https://arxiv.org/pdf/2002.04264
源码地址(PyTorch版本):https://github.com/dongliangchang/Mutual-Channel-Loss
针对领域:细粒度图像分类(FGVC)
网络结构
以resnet为例
class model_bn(nn.Module):
def __init__(self, feature_size=512,classes_num=200):
super(model_bn, self).__init__()
# 定义特征提取网络,删掉原resnet中的全局平均池化和全连接层
# 注意,作者将最后一层输出的特征图通道数改了,改成了分类数*ζ,作者源码中以600为例(200*3)
# 只改变了layer4输出特征图的通道数,其他部分与原始的resnet相同
self.features = nn.Sequential(*list(net.children())[:-2])
# 全局最大池化
self.max = nn.MaxPool2d(kernel_size=14, stride=14)
# 特征图通道数。这里做了修改,与原通道数不一样
self.num_ftrs = 600*1*1
# 分类器,依次为批量标准化、线性回归、批量标准化ELU激活函数、线性回归
self.classifier = nn.Sequential(
nn.BatchNorm1d(self.num_ftrs),
#nn.Dropout(0.5),
nn.Linear(self.num_ftrs, feature_size),
nn.BatchNorm1d(feature_size),
nn.ELU(inplace=True),
#nn.Dropout(0.5),
nn.Linear(feature_size, classes_num),
)
def forward(self, x, targets):
# 首先图片经过特征提取,得到特征图
x = self.features(x)
# 之后经过MC_Loss模块,得到MC损失
if self.training:
MC_loss = supervisor(x, targets, height=14, cnum=3)
# 特征图依次经过全局最大池化,得到特征向量
x = self.max(x)
x = x.view(x.size(0), -1)
# 再经过分类器,得到网络的预测值
x = self.classifier(x)
# 求交叉熵损失
loss = criterion(x, targets)
# 如果是训练阶段,则返回预测值和预测损失的同时,还需要返回MC损失
if self.training:
return x, loss, MC_loss
# 如果是测试阶段,则只需要返回预测值与损失
else:
return x, loss
MC_Loss
def supervisor(x, targets, height, cnum):
# 首先得到掩模图
mask = Mask(x.size(0), cnum).cpu()
branch = x
# 将特征图改变形状,变成(batch,200*ζ,h*w),ζ表示多少特征图代表一类,作者以ζ=3为例
# 第二维度(dim=2)表示特征图中的特征数据
branch = branch.reshape(branch.size(0),branch.size(1), branch.size(2) * branch.size(3))
# 将特征数据放入softmax,沿第二维度进行归一化操作(相当沿原来特征图上的数据进行扫描),对应论文中公式(7)后半段
branch = F.softmax(branch, 2)
# 再将特征图变回原来的形状
branch = branch.reshape(branch.size(0), branch.size(1), x.size(2), x.size(2))
# 将归一化后的数据传入CCMP模块,对应论文中公式(7)前半段
branch = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))(branch)
# 特征图经过CCMP之后,通道数变为分类数,之后再转化一下形状
# 转化为(batch,200,w*h)
branch = branch.reshape(branch.size(0),branch.size(1), branch.size(2) * branch.size(3))
# 之后首先对branch中的元素按第二维度求和,即对特征数据求和
# 之后再对所有通道取平均值,对应论文公式(6)
loss_2 = 1.0 - 1.0 * torch.mean(torch.sum(branch, 2)) / cnum# set margin = 3.0
# CWA模块:掩模图M与特征图相乘
branch_1 = x * mask
# CCMP模块,将所有特征图取相应类别的最大值,(对于每一类,3张压缩成1张),得到的特征图尺寸为(batch,200,h,w)
branch_1 = my_MaxPool2d(kernel_size=(1,cnum), stride=(1,cnum))(branch_1)
# 全局平均化,得到每一类的预测分数(h*w个值压缩成1个数),最终得到论文中公式(5)的结果
branch_1 = nn.AvgPool2d(kernel_size=(height,height))(branch_1)
# 压扁,便于后续取交叉熵损失
branch_1 = branch_1.view(branch_1.size(0), -1)
# 取交叉熵损失,对于论文中公式(4)
loss_1 = criterion(branch_1, targets)
# 返回损失
return [loss_1, loss_2]
计算CWA模块中的掩模图M:
# 得到CWA模块中的掩模图M
def Mask(nb_batch, channels):
# 假设三张特征图表示一个类别,即论文中的参数ζ为3
# 此时一组掩模M_i中由两个1,一个0组成
foo = [1] * 2 + [0] * 1
# 初始化总的M列表
bar = []
# 这里的200表示分类数
for i in range(200):
# 打乱初始化后M_i中的元素,表示随机生成M_i
random.shuffle(foo)
# 与总列表合并
bar += foo
# 按批次(batch)复制
bar = [bar for i in range(nb_batch)]
# 转换成array格式
bar = np.array(bar).astype("float32")
# 转换形状,转换成(batch,200*ζ,1,1),前两个维度中,掩模和特征图大小相同,便于后续的点乘操作
bar = bar.reshape(nb_batch, 200 * channels, 1, 1)
# 转换成tensor格式,之后放入显卡,再令其可求导
bar = torch.from_numpy(bar)
bar = bar.cuda()
bar = Variable(bar)
# 最后返回掩模M
return bar
CCMP模块:
class my_MaxPool2d(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False):
super(my_MaxPool2d, self).__init__()
# 最大池化的一系列参数,可以在定义my_MaxPool2d的同时引入
self.kernel_size = kernel_size
self.stride = stride or kernel_size
self.padding = padding
self.dilation = dilation
self.return_indices = return_indices
self.ceil_mode = ceil_mode
def forward(self, input):
# 将输入的1,3维度进行交换,即将通道维度与图片的宽w交换,得到(batch,w,h,600)的数据(以CUB数据集为例)
input = input.transpose(3,1)
# 最大池化,注意,此时池化核为(1, cnum)
# 相当于在原始的三张特征图中沿通道选择最大值,对应论文中公式(5)的中间部分(CCMP)
input = F.max_pool2d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode,
self.return_indices)
# 再将特征图维度变回去,变成正常的尺寸,即(1,200,h,w)
input = input.transpose(3,1).contiguous()
# 最后返回特征图
return input
以上内容仅是笔者的个人观点,若有错误,欢迎大家批评指正。
笔记原创,未经同意禁止转载!