takeaway:
OCNet是与DANet同期的文章,同样是self-attention的应用;不同之处在于它探索了self-attention与PSP、ASPP等模块结合的效果
https://github.com/openseg-group/OCNet.pytorch
这里的ocp就是self-attention
ocp就是self-attention;Pyramid-OC类似于PSP的做法,先把特征图划分成不同的区域,再在每个区域内单独做self-attention,后面的结果显示这样没有明显的提升,但也没有下降;ASP-OC是用ocp代替了原来的GAP(global average pooling),效果还有提升
class _SelfAttentionBlock(nn.Module):
    """
    The basic implementation for self-attention block/non-local block.

    Input:
        N X C X H X W
    Parameters:
        in_channels    : the dimension of the input feature map
        key_channels   : the dimension after the key/query transform
        value_channels : the dimension after the value transform
        out_channels   : the dimension of the output feature map
                         (falls back to in_channels when None)
        scale          : max-pooling factor applied to the input before the
                         attention is computed (save memory cost); when > 1
                         the result is resized back to H X W
    Return:
        N X out_channels X H X W
        position-aware context features (w/o concat or add with the input).
    """

    def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1):
        super(_SelfAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        # Identity check instead of `== None`.
        self.out_channels = in_channels if out_channels is None else out_channels

        # Only applied in forward() when scale > 1.
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            InPlaceABNSync(self.key_channels),
        )
        # Query and key deliberately share one transform (same module object,
        # same weights), so the un-normalised similarity map is symmetric.
        self.f_query = self.f_key
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
                                 kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels,
                           kernel_size=1, stride=1, padding=0)

        # Zero-initialise the output projection: the block emits all zeros
        # until W has been trained. Prefer the in-place initialiser name and
        # fall back to the deprecated alias where the new name does not exist.
        init_constant = getattr(nn.init, 'constant_', None) or nn.init.constant
        init_constant(self.W.weight, 0)
        init_constant(self.W.bias, 0)

    def forward(self, x):
        """Return the attention-weighted context for `x` (N X C X H X W)."""
        # Remember the incoming resolution; x may be pooled below.
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        # Flatten the spatial dims: value/query -> (N, H'*W', C), key -> (N, C, H'*W').
        value = self.f_value(x).view(batch_size, self.value_channels, -1)
        value = value.permute(0, 2, 1)
        query = self.f_query(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_key(x).view(batch_size, self.key_channels, -1)

        # (N, H'*W', H'*W') pairwise similarity, scaled by 1/sqrt(key_channels)
        # and normalised over the last dimension.
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # Aggregate the values and restore the (N, C, H', W') layout.
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, *x.size()[2:])
        context = self.W(context)

        if self.scale > 1:
            # Always resize back to the input resolution. Previously this only
            # happened when torch_ver was exactly '0.4' or '0.3'; any other
            # version string silently returned the pooled-size map.
            if torch_ver == '0.3':
                # Legacy call kept as-is: no align_corners argument is passed.
                context = F.upsample(input=context, size=(h, w), mode='bilinear')
            else:
                # F.upsample is deprecated in favour of F.interpolate; fall
                # back to it only where interpolate is not available.
                resize = getattr(F, 'interpolate', None) or F.upsample
                context = resize(input=context, size=(h, w), mode='bilinear', align_corners=True)
        return context