swin transformer 核心代码记录

目前更新部分包括swin的基本setting,基本模块,相对位置坐标理解和部分代码展示。

swin 包含了四种setting,依次是tiny,small, base 和 large。可以类比resnet。
swin transformer 核心代码记录

Swin-b 主体部分网络结构 BasicLayer

swin transformer 核心代码记录

结构展示

BasicLayer(
  (blocks): ModuleList(

    (0): SwinTransformerBlock(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      # WindowAttention
      (attn): WindowAttention(
        (qkv): Linear(in_features=128, out_features=384, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
        (softmax): Softmax(dim=-1)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    
    
    (1): SwinTransformerBlock(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): WindowAttention(
        (qkv): Linear(in_features=128, out_features=384, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
        (softmax): Softmax(dim=-1)
      )
      (drop_path): DropPath()
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  
  
  (downsample): PatchMerging(
    (reduction): Linear(in_features=512, out_features=256, bias=False)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

整个流程图

swin transformer 核心代码记录

Vit方式的 non-overlap patch partition 模块

先padding到patch尺寸的整数倍

if W % self.patch_size[1] != 0:
    x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
    x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

最核心的就是使用一个有 stride的 conv代替分 patch操作。

#using a nxn (s=n) conv is equivalent to splitting nxn (no overlap) patches.
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
    self.norm = norm_layer(embed_dim)
else:
    self.norm = None

编码得到的feature就是patch编码得到的了。

划分 windows 模块

得到了patch编码得到的embedding之后,以下面的方式,用window的方式进行划分,不同window放到batch轴上方便快速计算。

# partition windows, nW means number of windows
x_windows = window_partition(
    shifted_x, self.window_size
)  # nW*B, window_size, window_size, C    [392, 12, 12, 128]
x_windows = x_windows.view(
    -1, self.window_size * self.window_size, C
)  # nW*B, window_size*window_size, C 

WindowAttention 模块

swin transformer 核心代码记录

这里是swin 的精髓,作者也是对比了 global 计算affinity和 滑窗计算的复杂度
swin transformer 核心代码记录
这里M是window size,常数。接下来讲解代码模块。

假设一些超参数

self.window_size = window_size  # Wh, Ww, (12, 12)
self.num_heads = num_heads # 4
head_dim = dim // num_heads # 32
self.scale = qk_scale or head_dim ** -0.5 # 0.17

可学习的相对位置编码

创建一个 可学习的embedding, 尺寸为 [(2* Wh-1) * (2* Ww-1), nH] 。为什么尺寸是这样?是因为要look up table 也就是查表法得到某个位置的权重。这里解释一下,因为table需要囊括一个解空间,解空间(2* Wh-1) * (2* Ww-1)这么大,然后作为index,也就是下标,索引 这个relative_position_bias_table。比如两个embedding空间上相距10,那么就需要 找relative_position_bias_table [10], 相距-10,就是relative_position_bias_table [-10]。

# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
)  # 2*Wh-1 * 2*Ww-1, nH (denotes num_heads)

相对位置表

需要产生一个 12 x 12 的window 的相对坐标编码,思考多大的解空间可以cover住相对位置那?当12x12的window,对于每一行简单用坐标位置差来描述的话,是 [-11, 11],也就是2w-1个值,正负是因为前后的相对位置不是无向的。目标矩阵尺寸是 (2w-1)x(2h-1)。好了,知道我们想干啥了就看代码了。
先得到一个 coords_flatten,尺寸是(2, Wh * Ww),W表示window。

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
          3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
          6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,
          7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
          9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,
          6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
          0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,
          6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
          0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,
          6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
          0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,
          6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])

然后使用

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww

coords_flatten[:, :, None] 维度是 [2, 144, 1], coords_flatten[:, None, :] 维度是[2, 1, 144]
两个矩阵对应相减, 根据广播规则得到相对的postion。relative_coords 尺寸 [2, 144,144]。广播规则可以这么看,固定coords_flatten[0]的第一个元素0,然后依次与coords_flatten[1] 的每一个元素相减。

tensor([[[  0,   0,   0,  ..., -11, -11, -11],
         [  0,   0,   0,  ..., -11, -11, -11],
         [  0,   0,   0,  ..., -11, -11, -11],
         ...,
         [ 11,  11,  11,  ...,   0,   0,   0],
         [ 11,  11,  11,  ...,   0,   0,   0],
         [ 11,  11,  11,  ...,   0,   0,   0]],

        [[  0,  -1,  -2,  ...,  -9, -10, -11],
         [  1,   0,  -1,  ...,  -8,  -9, -10],
         [  2,   1,   0,  ...,  -7,  -8,  -9],
         ...,
         [  9,   8,   7,  ...,   0,  -1,  -2],
         [ 10,   9,   8,  ...,   1,   0,  -1],
         [ 11,  10,   9,  ...,   2,   1,   0]]])

我们可以粗率计算下,最小是 0-11=-11,最大11-0=11,符合我们的预期,此时索引上面包含了很多负值。此时我们可以通过在每一个方向上+11抵消掉所有负数, 此时最大值是 22,最小值0.

relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1

在swin中,相对位置编码充当了B,也就是计算相似度时候的 bias。为了把上述的二维相对位置矩阵变成一维
swin transformer 核心代码记录,最简单的做法就是 i* (2w-1)+j 的编码方式。swin采用了一种高效的实现。

relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
# 下面的轴经过了 permute(1, 2, 0),把 2 放到了最后
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

所以最后的最大值就是 22 x (2x12-1) + 22 = 528, 最小值是0。注意这里w没变是12,而i, j 加了一个偏执11,所以变成了最大值22。

小结

再利用循环移位,就可以达到滑窗的目的。
swin transformer 核心代码记录

上一篇:专访 Swin Transformer 作者胡瀚:面向计算机视觉中的「开放问题」 原创


下一篇:你认为CNN的归纳偏差,Transformer它没有吗?