『论文笔记』Swin Transformer

https://zhuanlan.zhihu.com/p/361366090

目前transform的两个非常严峻的问题

  1. 受限于图像的矩阵性质,一个能表达信息的图片往往至少需要几百个像素点,而建模这种几百个长序列的数据恰恰是Transformer的天生缺陷;
  2. 目前的基于Transformer框架更多的是用来进行图像分类,对实例分割这种密集预测的场景Transformer并不擅长解决。

在Swin Transformer之前的ViT和iGPT,它们都使用了小尺寸的图像作为输入,这种直接resize的策略无疑会损失很多信息。与它们不同的是,Swin Transformer的输入是图像的原始尺寸另外Swin Transformer使用的是CNN中最常用的层次的网络结构,在CNN中一个特别重要的一点是随着网络层次的加深,节点的感受野也在不断扩大,这个特征在Swin Transformer中也是满足的。Swin Transformer的这种层次结构,也赋予了它可以像FPN,U-Net等结构实现可以进行分割或者检测的任务。


『论文笔记』Swin Transformer图1:Swin Transformer和ViT的对比

『论文笔记』Swin Transformer

 

 

图2:Swin-T的网络结构 

 在图2中,输入图像之后是一个Patch Partition,再之后是一个Linear Embedding层,这两个加在一起其实就是一个Patch Merging层(至少上面的源码中是这么实现的)。这一部分的源码如下:

class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        x = self.patch_merge(x) # (1, 48, 3136)
        x = x.view(b, -1, new_h, new_w).permute(0, 2, 3, 1) # (1, 56, 56, 48)
        x = self.linear(x) # (1, 56, 56, 96)
        return x

Patch Merging的作用是对图像进行降采样,类似于CNN中Pooling层。Patch Merging是主要是通过nn.Unfold函数实现降采样的,nn.Unfold的功能是对图像进行滑窗,相当于卷积操作的第一步,因此它的参数包括窗口的大小和滑窗的步长。根据源码中给出的超参我们知道这一步降采样的比例是 
『论文笔记』Swin Transformer
 ,因此经过nn.Unfold之后会得到 
『论文笔记』Swin Transformer
 个长度为 
『论文笔记』Swin Transformer 的特征向量,其中 『论文笔记』Swin Transformer 是输入到这个stage的Feature Map的通道数,第一个stage的输入是RGB图像,因此通道数为3,表示为式(1)。

『论文笔记』Swin Transformer

接着的viewpermute是将得到的向量序列还原到 『论文笔记』Swin Transformer 的二维矩阵,linear是将长度是 『论文笔记』Swin Transformer 的特征向量映射到out_channels的长度,因此stage-1的Patch Merging的输出向量维度是 『论文笔记』Swin Transformer ,对比源码的注释,这里省略了第一个batch为 『论文笔记』Swin Transformer 的维度。

可以看出Patch Partition/Patch Merging起到的作用像是CNN中通过带有步长的滑窗来降低分辨率,再通过 『论文笔记』Swin Transformer 卷积来调整通道数。不同的是在CNN中最常使用的降采样的最大池化或者平均池化往往会丢弃一些信息,例如最大池化会丢弃一个窗口内的地响应值,而Patch Merging的策略并不会丢弃其它响应,但它的缺点是带来运算量的增加。在一些需要提升模型容量的场景中,我们其实可以考虑使用Patch Merging来替代CNN中的池化。

   
上一篇:还在被R语言中的因子factor毒打吗


下一篇:【点宽专栏】期货多因子(二)——各因子描述