【深度学习 十】swin transformer:屠榜各大cv比赛

概要

 Swin transformer: Hierarchical Vision Transformer using Shifted Windows,是微软2021.03.25公布的一篇利用transformer架构处理计算机视觉任务的论文。在图像分割,目标检测各个领域已经霸榜,让很多人看到了transformer完全替代卷积的可能。而且它的设计思想吸取了resnet的精华,从局部到全局,将transformer设计成逐步扩大感受野的工具。

论文链接https://arxiv.org/abs/2103.14030

swin transformer的降采样过程

【深度学习 十】swin transformer:屠榜各大cv比赛

假设图片的大小是224×224的,窗口大小是固定的,7×7。这里每个方框都是一个窗口,每个窗口是固定有7×7个patch,但是patch的大小是不固定的,它会随着patch merging的操作而发生变化。我们把周边4个窗口的patch拼在一起,相当于patch扩大了2×2倍,从而得到8×8大小的patch。

经过这一系列的操作之后,patch的数目在变少,最后整张图只有一个窗口,7x7个patch。所以我们可以认为降采样是指让patch的数量减少,但是patch的大小在变大。

【深度学习 十】swin transformer:屠榜各大cv比赛

这是对ViT的一个改进,ViT从头至尾都是对全局做self-attention,而swin-transformer是一个窗口在放大的过程,然后self-attention的计算是以窗口为单位去计算的,这样相当于引入了局部聚合的信息,和CNN的卷积过程很相似,就像是CNN的步长和卷积核大小一样,这样就做到了窗口的不重合,区别在于CNN在每个窗口做的是卷积的计算,每个窗口最后得到一个值,这个值代表着这个窗口的特征。而swin transformer在每个窗口做的是self-attention的计算,得到的是一个更新过的窗口,然后通过patch merging的操作,把窗口做了个合并,再继续对这个合并后的窗口做self-attention的计算。

【深度学习 十】swin transformer:屠榜各大cv比赛

 Swin-transformer是怎么把复杂度降低的呢? Swin Transformer Block这个模块和普通的transformer的区别就在于W-MSA,而它就是降低复杂度计算的大功臣。

我们假设已知MSA的复杂度是图像大小的平方,根据MSA的复杂度,我们可以得出A的复杂度是(3×3)²,最后复杂度是81。Swin transformer是在每个local windows(红色部分)计算self-attention,根据MSA的复杂度我们可以得出每个红色窗口的复杂度是1×1的平方,也就是1的四次方。然后9个窗口,这些窗口的复杂度加和,最后B的复杂度为9。

整体架构

【深度学习 十】swin transformer:屠榜各大cv比赛

 整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块,并嵌入到Embedding
  • 在每个Stage里,由Patch Merging和多个Block组成。
  • 其中Patch Merging模块主要在每个Stage一开始降低图片分辨率。
  • 而Block具体结构如右图所示,主要是LayerNormMLPWindow AttentionShifted Window Attention组成

Window Attention

这是这篇文章的关键。传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

先看下公式

【深度学习 十】swin transformer:屠榜各大cv比赛

 主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。

 Shifted Window Attention

前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。

【深度学习 十】swin transformer:屠榜各大cv比赛

 左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本四个窗口变成了9个窗口。

在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。

【深度学习 十】swin transformer:屠榜各大cv比赛

 左图是Shift Window Attention,将左面一列和上面一行窗口移到右下拼接在一起,通过mask遮掩,就拼接成了4个window。

Attention Mask

我认为这是Swin Transformer的精华,通过设置合理的mask,让Shifted Window Attention在与Window Attention相同的窗口个数下,达到等价的计算结果。

【深度学习 十】swin transformer:屠榜各大cv比赛

 上右图中的灰线就是被mask的地方。

Transformer Block整体架构

【深度学习 十】swin transformer:屠榜各大cv比赛

 两个连续的Block架构如上图所示,需要注意的是一个Stage包含的Block个数必须是偶数,因为需要交替包含一个含有Window Attention的Layer和含有Shifted Window Attention的Layer。

整体流程如下

  • 先对特征图进行LayerNorm
  • 通过self.shift_size决定是否需要对特征图进行shift
  • 然后将特征图切成一个个窗口
  • 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
  • 将各个窗口合并回来
  • 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
  • 做dropout和残差连接
  • 再通过一层LayerNorm+全连接层,以及dropout和残差连接

 实验结果

 1.ImageNet-1K的图像分类

 【深度学习 十】swin transformer:屠榜各大cv比赛

 2. COCO的目标检测

【深度学习 十】swin transformer:屠榜各大cv比赛

上一篇:函数的嵌套+nonlocal和global关键字(重点)


下一篇:bin 文件读取储存