一 VIT模型
1 代码和模型基础
以timm包为代码基础,VIT模型以vit_base_patch16_224作为模型基础
2 模型结构
2.1 输入的图像
B
∗
3
∗
224
∗
224
B*3*224*224
B∗3∗224∗224,第一步patch_embeding,这里一个patch的对应的像素大小是
16
∗
16
16*16
16∗16,也就是对输入图像作conv2d,对应的kernel_size=16,stride=16,以及升维为768,最终得到输出feature为
B
∗
14
∗
14
∗
768
B*14*14*768
B∗14∗14∗768,然后转化为
B
∗
196
∗
768
B*196*768
B∗196∗768,这里196个patchs其实对应了类似nlp就是196个tokens;
2.2 这里类似nlp,添加了一个起始token,这里用一个可训练的参数torch.nn.Parameter,对应的特征
B
∗
1
∗
768
B*1*768
B∗1∗768,然后和上一步生成的196个tokens合并成197个tokens,对应的特征
B
∗
197
∗
768
B*197*768
B∗197∗768;然后再加上一个位置编码,可训练的参数torch.nn.Parameter,对应的特征
B
∗
197
∗
768
B*197*768
B∗197∗768,相加之后得到后续Block的输入
2.3 这里每个Block对应两块,一个是attention模块,一个是mlp模块;先是attention模块,就是对应的multi-head self attention,输入为
B
∗
197
∗
768
B*197*768
B∗197∗768,先经过Layer_norm,在经过torch.nn.Linear升维为
768
∗
3
768*3
768∗3,这里采用heads为12个,然后reshape成
3
∗
B
∗
12
∗
197
∗
64
3*B*12*197*64
3∗B∗12∗197∗64,然后分别分成q,k,v,每个对应的特征
B
∗
12
∗
197
∗
64
B*12*197*64
B∗12∗197∗64
得到最终attention之后的特征在通过short_cut,加上初始输入的特征,得到最终的输出
B
∗
197
∗
768
B*197*768
B∗197∗768;
2.4 对应的mlp模块,这里主要是对输入先通过Layer_norm,在通过Linear进行升维768*4,然后通过gelu激活函数,加dropout,之后在通过Linear降维成768,在通过dropout,然后将该输出通过short_cut,与初始输入相加得到最终输出
B
∗
197
∗
768
B*197*768
B∗197∗768
2.5 经过多个上述的Block之后,得到输出
B
∗
197
∗
768
B*197*768
B∗197∗768,然后经过Layer_norm,作为最终分类,选取了第一个token作为分类的特征
B
∗
1
∗
768
B*1*768
B∗1∗768,然后进入head阶段,通过Linear得到最终1000类分类
二 Swin Transformer
1 代码和模型基础
以timm包为代码基础,Swin Transformer模型以swin_base_patch4_window7_224作为模型基础;该文章解析可以参https://zhuanlan.zhihu.com/p/360513527
2 模型设计思想
2.1 对于transformer从nlp到cv中的应用,主要调整是视觉图像的scale以及高分辨率问题;针对VIT模型,token数量多,计算self-attention,对应的计算量非常大,所以该模型设计window,只计算该window内部的所有token的self attention降低计算量
对于其中的复杂度计算,这里可以参考卷积的flops计算,第一个计算q,k,v的复杂度,其实就是个Linear的升维操作(参照上一部VIT中计算q,k,v方式),对应的flops就是
c
∗
1
∗
1
∗
3
c
∗
h
∗
w
c*1*1*3c*h*w
c∗1∗1∗3c∗h∗w
2.2 基于window计算的,虽然减少了计算量,但是这样就造成了每个window的视野局限,只能看到当前window内部的token,看不到全局信息,而且每个window之间信息也不能进行交流;针对这两个问题,作者提出了2个解决方案:
a. 第一个就是类似resnet的层级结构 Hierarchical,每个stage后对
2
∗
2
2*2
2∗2组的特征进行merge,同时进行升维(特征空间尺度大小
h
∗
w
→
h
2
∗
h
w
h*w\rightarrow \frac {h}{2}*\frac {h}{w}
h∗w→2h∗wh,特征维度大小
C
→
4
C
→
2
C
C\rightarrow 4C \rightarrow 2C
C→4C→2C),这样每个window感受野就越来越大
b. 就是采用shift window,加强window之间的信息交流
对于shift之后的计算方式可以参考前面的知乎链接,具体的代码实现参考下面代码解析
3 模型结构
3.0 具体代码结构可以参考https://zhuanlan.zhihu.com/p/384514268
3.1 输入的图像
B
∗
3
∗
224
∗
224
B*3*224*224
B∗3∗224∗224,第一步patch_embeding,这里一个patch的对应的像素大小是
4
∗
4
4*4
4∗4,也就是对输入图像作conv2d,对应的kernel_size=4,stride=4,以及升维为128,最终得到输出feature为
B
∗
56
∗
56
∗
128
B*56*56*128
B∗56∗56∗128,然后转化为
B
∗
3136
∗
128
B*3136*128
B∗3136∗128,这里3136个patchs其实对应了类似nlp就是3136个tokens;
3.2 这里没有用到position embeding,因为这里作者采用了relative position bias,发现添加postion embeding对效果有一点损失,所以去掉了这一块,相应的对比试验
3.3 从上一步输入
B
∗
3136
∗
128
B*3136*128
B∗3136∗128,进入Stage1中的SwinTransformerBlock,这里是2个Block交替进行,分别是W-MSA(Window based Self Attention)和SW-MSA(Shift Window based Self Attention)
3.3.1 第一个W-MSA和VIT中的MSA模块基本类似,只是这里面添加了个relative position bias,如下公式
代码如下:
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
对于relative position bias计算方式,对于一个window内部(window_size是M),大小是
M
∗
M
M*M
M∗M,先去计算每个window内每个patch的相对坐标位置;这里生成的relative_coords:
2
∗
M
2
∗
M
2
2*M^2*M^2
2∗M2∗M2,2是分别代表y坐标差(对应行)和x坐标差(对应列),会发现每个维度的坐标差的范围是
[
−
M
+
1
,
M
−
1
]
[-M+1, M-1]
[−M+1,M−1],这里将坐标差转化为正数,所以对于每一个值加上
M
−
1
M-1
M−1,这样对应的坐标差范围是
[
0
,
2
M
−
2
]
[0, 2M-2]
[0,2M−2],刚好是
2
M
−
1
2M-1
2M−1个数,同时对y坐标乘以
2
M
−
1
2M-1
2M−1 ,这样在对x和y坐标差相加之后的范围是
[
0
,
(
2
M
−
2
)
∗
(
2
M
−
1
)
+
(
2
M
−
2
)
]
[0, (2M-2)*(2M-1)+(2M-2)]
[0,(2M−2)∗(2M−1)+(2M−2)],一共是
(
2
M
−
1
)
∗
(
2
M
−
1
)
(2M-1)*(2M-1)
(2M−1)∗(2M−1)个数,刚好对应生成的relative_position_bias_table特征大小是
(
2
M
−
1
)
∗
(
2
M
−
1
)
(2M-1)*(2M-1)
(2M−1)∗(2M−1),可以在这个特征里面找到所有相对位置relative_position_index的值;这里为什么要乘以
2
M
−
1
2M-1
2M−1,应该是个trick,个人猜测,第一个是如果乘以的数太小,会导致圆点坐标的patch与其他patch的坐标差有重复,第二个是乘以
2
M
−
1
2M-1
2M−1刚好可以使生成的特征大小是
(
2
M
−
1
)
∗
(
2
M
−
1
)
(2M-1)*(2M-1)
(2M−1)∗(2M−1),当然乘以2M应该也可以,对应生成的特征大小是
(
2
M
−
2
)
∗
(
2
M
−
1
)
+
2
M
+
1
(2M-2)*(2M-1)+2M+1
(2M−2)∗(2M−1)+2M+1,好像也能满足,具体原因还是不太明白;其中生成的特征relative_position_bias_table,是均值为0,标准差0.02的一组向量;最终的计算,就是在q和k计算attention矩阵之后,在加上根据relative_position_index的位置查找对应在relative_position_bias_table中的值,组成了最终的relative position bias,得到最终的attention矩阵;剩余的步骤与VIT中的一致,swin transformer主要增加了一项relative position bias 替换了VIT原有的position embeding
3.4 进入SW-MSA模块,这里主要是增加了一个shift操作,其余与W-MSA基本操作一样,如图:
在shift之后,从原来的
2
∗
2
2*2
2∗2变成了
3
∗
3
3*3
3∗3个window,为了批量计算,一般想法是padding成每个window同样大小,但是这样就增加了window数量,增加计算量,这里就把a,c,b三块进行移动到右下角,组成了新的
2
∗
2
2*2
2∗2window,但是这样除了左上角第一个的window是完整的不需要改变,其余三个是组成的混合window,不需要使用window内部所有patch的attention,只需要以前划分的
3
∗
3
3*3
3∗3对应window内的patch的attention,举例右下角的window是有A,B的下半部分,C的右半部分,及其他4部分组成,这里只需要计算A内部patch之间的attention,不需要计算A与B或者C的attention,因为A是移动过来的,计算类似图像上下边缘或者左右边缘的关系作用不大,所以加上了一个mask,去选取需要的attention,之后在将移动之后的window在移回去,达到批量快速计算的效果;这里第一步是进行shift操作,主要是通过torch.roll实现;第二步是计算相应的mask,可以参考https://zhuanlan.zhihu.com/p/360513527,如下图
这里计算代码,如下:
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
假设输入的是
1
∗
56
∗
56
∗
1
1*56*56*1
1∗56∗56∗1的mask,window_size为7,shift_size为3,产生h_slices和w_slices都是按照三个(0,-7),(-7,-3),(-3,None)进行划分,一共组成了9个块,并分别按照0-8进行标记,然后再进行相减,如果是结果是0,就保留(这表示是同一块,需要计算之间的attention),其余的赋值为-100,不保留;
3.5 在每个stage之后,会先进行Patch Merge操作,对特征进行下采样然后升维的过程;假设是输入特征大小
B
∗
56
∗
56
∗
128
B*56*56*128
B∗56∗56∗128,然后沿着x和y方向间隔一个取特征,分成了个4个
B
∗
28
∗
28
∗
128
B*28*28*128
B∗28∗28∗128的特征,然后cat到一起,得到
B
∗
28
∗
28
∗
512
B*28*28*512
B∗28∗28∗512的特征,之后reshape,加layer norm 在加一个linear 进行降维成256,其实整个Pathc Merge过程就是减小了特征的空间大小,同时增大维度
3.6 经过4个stage(每个stage对应的block数量[2,2,18,2])之后,得到特征
B
∗
49
∗
1024
B*49*1024
B∗49∗1024,经过layernorm,在经过一个平均池化,得到
B
∗
1024
B*1024
B∗1024,然后后面是head阶段,跟着一个linear 分类成1000类,得到最终的结果