https://zhuanlan.zhihu.com/p/361366090
目前transform的两个非常严峻的问题
- 受限于图像的矩阵性质,一个能表达信息的图片往往至少需要几百个像素点,而建模这种几百个长序列的数据恰恰是Transformer的天生缺陷;
- 目前的基于Transformer框架更多的是用来进行图像分类,对实例分割这种密集预测的场景Transformer并不擅长解决。
在Swin Transformer之前的ViT和iGPT,它们都使用了小尺寸的图像作为输入,这种直接resize的策略无疑会损失很多信息。与它们不同的是,Swin Transformer的输入是图像的原始尺寸另外Swin Transformer使用的是CNN中最常用的层次的网络结构,在CNN中一个特别重要的一点是随着网络层次的加深,节点的感受野也在不断扩大,这个特征在Swin Transformer中也是满足的。Swin Transformer的这种层次结构,也赋予了它可以像FPN,U-Net等结构实现可以进行分割或者检测的任务。
图1:Swin Transformer和ViT的对比
图2:Swin-T的网络结构
在图2中,输入图像之后是一个Patch Partition,再之后是一个Linear Embedding层,这两个加在一起其实就是一个Patch Merging层(至少上面的源码中是这么实现的)。这一部分的源码如下:
class PatchMerging(nn.Module): def __init__(self, in_channels, out_channels, downscaling_factor): super().__init__() self.downscaling_factor = downscaling_factor self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0) self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels) def forward(self, x): b, c, h, w = x.shape new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor x = self.patch_merge(x) # (1, 48, 3136) x = x.view(b, -1, new_h, new_w).permute(0, 2, 3, 1) # (1, 56, 56, 48) x = self.linear(x) # (1, 56, 56, 96) return x
Patch Merging的作用是对图像进行降采样,类似于CNN中Pooling层。Patch Merging是主要是通过nn.Unfold
函数实现降采样的,nn.Unfold
的功能是对图像进行滑窗,相当于卷积操作的第一步,因此它的参数包括窗口的大小和滑窗的步长。根据源码中给出的超参我们知道这一步降采样的比例是
,因此经过nn.Unfold
之后会得到
个长度为的特征向量,其中 是输入到这个stage的Feature Map的通道数,第一个stage的输入是RGB图像,因此通道数为3,表示为式(1)。
接着的view
和permute
是将得到的向量序列还原到 的二维矩阵,linear
是将长度是 的特征向量映射到out_channels
的长度,因此stage-1的Patch Merging的输出向量维度是 ,对比源码的注释,这里省略了第一个batch为 的维度。
可以看出Patch Partition/Patch Merging起到的作用像是CNN中通过带有步长的滑窗来降低分辨率,再通过 卷积来调整通道数。不同的是在CNN中最常使用的降采样的最大池化或者平均池化往往会丢弃一些信息,例如最大池化会丢弃一个窗口内的地响应值,而Patch Merging的策略并不会丢弃其它响应,但它的缺点是带来运算量的增加。在一些需要提升模型容量的场景中,我们其实可以考虑使用Patch Merging来替代CNN中的池化。