PPM: Pyramid Pooling Module

The pyramid pooling module comes from PSPNet (Pyramid Scene Parsing Network, CVPR 2017). Paper: https://arxiv.org/pdf/1612.01105.pdf
1. Structure of the PPM
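The original paper uses four pyramid scales, but both the number of levels in the pyramid pooling module and the bin size of each level can be changed. In the paper the module has 4 levels, with bin sizes 1×1, 2×2, 3×3 and 6×6.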
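To see what the four levels produce, here is a minimal sketch that pools a feature map to each bin size with PyTorch's F.adaptive_avg_pool2d, which splits the map into an s×s grid of (nearly) equal regions whatever the input resolution. The batch size, the 8 channels and the 60×60 and 45×61 spatial sizes are arbitrary illustrative values, not taken from the paper.

```python
import torch
import torch.nn.functional as F

# Illustrative inputs: batch of 2, 8 channels; one square and one
# non-square spatial size (both chosen arbitrarily for the demo).
bin_sizes = [1, 2, 3, 6]  # pyramid levels used in the paper
shapes = {}

for hw in [(60, 60), (45, 61)]:
    x = torch.randn(2, 8, *hw)
    for s in bin_sizes:
        # Average each of the s x s sub-regions -> output is always s x s
        pooled = F.adaptive_avg_pool2d(x, output_size=s)
        shapes[(hw, s)] = tuple(pooled.shape)
        print(hw, s, shapes[(hw, s)])

# The 1x1 level is simply global average pooling over height and width.
x = torch.randn(2, 8, 60, 60)
global_pool = F.adaptive_avg_pool2d(x, 1)
matches_global_mean = torch.allclose(global_pool, x.mean(dim=(2, 3), keepdim=True), atol=1e-6)
print(matches_global_mean)
```

The coarsest level summarizes the whole image in one value per channel, while the 6×6 level keeps a rough layout of sub-regions; together they give the module context at four granularities.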

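First, the feature map is pooled to each target size. A 1×1 convolution is then applied to each pooled result to reduce its channel count to 1/N of the original, where N is the number of levels (here N = 4). Next, each of these maps is upsampled by bilinear interpolation to the size of the original feature map, and the original feature map is concatenated with the upsampled maps along the channel dimension. The concatenation has twice as many channels as the original feature map, so a final 1×1 convolution reduces it back to the original channel count. The resulting feature map therefore has the same spatial size and number of channels as the input.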
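The channel bookkeeping in these steps, written out for C input channels and N pyramid levels, with the numbers from the instantiation used later in this post (C = 2048, N = 4):

$$
\begin{aligned}
\text{per branch: } & C \xrightarrow{1\times 1\ \text{conv}} \frac{C}{N} = \frac{2048}{4} = 512 \\
\text{after concatenation: } & C + N \cdot \frac{C}{N} = 2C, \qquad 2048 + 4 \times 512 = 4096 \\
\text{after the final } 1\times 1 \text{ conv: } & 2C \rightarrow C = 2048
\end{aligned}
$$

Spatially, every branch is upsampled back to H×W before concatenation, so all concatenated tensors share the input's height and width.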
2. PPM code

import torch
import torch.nn as nn
import torch.nn.functional as F


class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU."""

    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1):
        super(conv2DBatchNormRelu, self).__init__()

        # dilation=1 is an ordinary convolution, so one constructor covers both cases
        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)

        self.cbr_unit = nn.Sequential(conv_mod,
                                      nn.BatchNorm2d(int(n_filters)),
                                      nn.ReLU(inplace=True))

    def forward(self, inputs):
        return self.cbr_unit(inputs)


class pyramidPooling(nn.Module):

    def __init__(self, in_channels, pool_sizes):
        super(pyramidPooling, self).__init__()

        # One 1x1 conv-BN-ReLU branch per pyramid level, reducing channels to C/N
        self.path_module_list = nn.ModuleList([
            conv2DBatchNormRelu(in_channels, int(in_channels / len(pool_sizes)), 1, 1, 0, bias=False)
            for _ in pool_sizes
        ])
        self.pool_sizes = pool_sizes

    def forward(self, x):
        output_slices = [x]
        h, w = x.shape[2:]

        for module, pool_size in zip(self.path_module_list, self.pool_sizes):
            # Pool to exactly pool_size x pool_size bins; this also works for
            # non-square maps and sizes not divisible by pool_size
            out = F.adaptive_avg_pool2d(x, pool_size)
            out = module(out)
            # Bilinear upsampling back to the input's spatial size
            out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)
            output_slices.append(out)

        # Input + N branches of C/N channels each -> 2C channels.
        # The final 1x1 conv back to C channels is not part of this module.
        return torch.cat(output_slices, dim=1)

The module can be created with the following line (here for a 2048-channel input feature map):

self.pyramid_pooling = pyramidPooling(2048, [6, 3, 2, 1])
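The pyramidPooling module above stops at the concatenation, so its output has 2C channels (4096 for this instantiation); the final 1×1 convolution described in Section 1 is left to the surrounding model. Below is a minimal, self-contained sketch of the full pipeline including that fusion step. The class name PPMHead and the helper conv_bn_relu are hypothetical, and using conv-BN-ReLU for the fusion layer is an assumption (the post only says a 1×1 convolution).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_relu(in_ch: int, out_ch: int) -> nn.Sequential:
    """Hypothetical helper: 1x1 convolution followed by BatchNorm and ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class PPMHead(nn.Module):
    """Hypothetical PPM with fusion: output has the input's channels and size."""

    def __init__(self, in_channels: int, pool_sizes=(1, 2, 3, 6)):
        super().__init__()
        reduced = in_channels // len(pool_sizes)  # C / N channels per branch
        self.pool_sizes = pool_sizes
        self.branches = nn.ModuleList([conv_bn_relu(in_channels, reduced) for _ in pool_sizes])
        # Concatenation has C + N * (C / N) channels; fuse back to C (assumed conv-BN-ReLU)
        self.fuse = conv_bn_relu(in_channels + reduced * len(pool_sizes), in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, w = x.shape[2:]
        outs = [x]
        for size, branch in zip(self.pool_sizes, self.branches):
            y = F.adaptive_avg_pool2d(x, size)   # pool to size x size bins
            y = branch(y)                        # reduce channels to C / N
            y = F.interpolate(y, size=(h, w), mode="bilinear", align_corners=False)
            outs.append(y)
        return self.fuse(torch.cat(outs, dim=1))  # 2C -> C


# eval(): BatchNorm uses running statistics, so a batch of 1 is fine
head = PPMHead(2048, pool_sizes=(6, 3, 2, 1)).eval()
with torch.no_grad():
    out = head(torch.randn(1, 2048, 30, 30))
print(tuple(out.shape))  # (1, 2048, 30, 30)
```

Calling .eval() matters for tiny batches: in training mode, the BatchNorm in the 1×1 branch sees only one value per channel when the batch size is 1 and raises an error.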