CBAM:通道注意力+空间注意力【附Pytorch实现】

论文地址:https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf

CBAM:通道注意力+空间注意力【附Pytorch实现】

1、动机

卷积操作是通过混合通道和空间两个维度的信息来间特征提取的。在注意力方面,SE仅关注了通道注意力,没考虑空间方面的注意力。因此,本文提出了 CBAM——一种同时关注通道和空间注意力的卷积模块,可以用于CNNs架构中,以提升feature map的特征表达能力。

2、方法

CBAM的整体架构如上面图1所示,其包括两块内容:Channel Attention Module、Spatial Attention Module,也即通道注意力CAM、空间注意力SAM。假设CABM的输入feature map为CBAM:通道注意力+空间注意力【附Pytorch实现】,CBAM先用CAM得到1D通道注意力map CBAM:通道注意力+空间注意力【附Pytorch实现】,再用SAM得到2D空间注意力map CBAM:通道注意力+空间注意力【附Pytorch实现】,该过程公式表示如下:

CBAM:通道注意力+空间注意力【附Pytorch实现】   (1)

CAM和SAM的架构如图2所示:

CBAM:通道注意力+空间注意力【附Pytorch实现】

CAM:对于输入的feature map F,首先在每个空间位置上应用MaxPooling、AvgPooling,得到两个C*1*1的向量,然后分别送入一个共享的包含两层FC的MLP,最后最像素相加融合,经过一个激活函数,得到通道注意力map,其公式表达为:

CBAM:通道注意力+空间注意力【附Pytorch实现】   (2)

SAM:CAM输出的feature map,将送入SAM。首先在每个通道上应用MaxPooling、AvgPooling,得到两个1*H*W的feature map,然后按通道concat起来,送入一个标准卷积层,经过激活函数之后就得到了空间注意力map,其公式表达为:

CBAM:通道注意力+空间注意力【附Pytorch实现】   (3)

CAM和SAM,一个关注“what”,一个关注“where”,两者可以并行或者串行使用。

3、Pytorch实现

CBAM包含了个子模块:CAM和SAM,使用时,分别实例化它们,然后顺序应用在某个feature map之后即可。下面给出其Pytorch代码:

import torch
from torch import nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1   = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.register_buffer()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

 

上一篇:[SAM学习笔记]


下一篇:loj#6041. 「雅礼集训 2017 Day7」事情的相似度(SAM set启发式合并 二维数点)