Vision Transformer Pruning简记
文章目录
参考
剪枝流程
- 稀疏正则训练
- 剪枝,减去不重要的部分
- finetune微调
剪什么?
- 有关于稀疏训练虽然重要,但是首要还是确定剪什么,在Vision Transformer Pruning中作者剪枝的是Dimension,那么什么是Dimension呢?
- 我的理解是:Dimension的长度其实就是FC层的输入长度
怎么剪?
- 首先回顾一下Transformer
回顾Transformer
-
我的理解是:以Vision Transformer为例(当然作者用的好像是DeiT-base),我们首先将图片分成16x16个patch,然后进行编码embedding,会出来1xd的向量,对每个patch都这么做然后concat一下就得到了Attention部分的输入, X ∈ R n × d X \in R^{n×d} X∈Rn×d,这里的n就是我们分出来的patch数量
-
上文提到了输入 X ∈ R n × d X \in R^{n×d} X∈Rn×d,然后一般我的理解是各✖️ W k 、 W q 、 W v W^k、 W^q 、W^v Wk、Wq、Wv权重矩阵得到KQV,然后进入Attention计算(李宏毅老师的说法),然而作者得到KQV的方法好像是FC,这点我查证了一下VIT好像在to_kqv的时候是这么做的,这是没看VIT论文的锅(确信)
-
A t t e n t i o n ( Q K V ) = S o f t m a x ( Q K T d ) V Attention(QKV) = Softmax(\frac{QK^T}{\sqrt d})V Attention(QKV)=Softmax(d QKT)V
-
对输出进行进行处理:Layer Norm+Residual
-
Y = X + F C o u t ( A t t e n t i o n ( F C q ( X ) , F C v ( X ) , F C v ( X ) ) Y = X + FC_{out}(Attention(FC_q(X),FC_v(X),FC_v(X)) Y=X+FCout(Attention(FCq(X),FCv(X),FCv(X))
-
接下来就是MLP
-
Z = Y + F C 2 ( F C 1 ( Y ) ) Z = Y + FC_2(FC_1(Y)) Z=Y+FC2(FC1(Y))
那么剪哪里?
- 作者给出了一张图,其中右侧是一个Transformer模块,会发现有多个Dimension Pruning,这里我放一张Transformer的encoder做对比就清楚了
-
我么可以知道两个跳跃链接分别对应encoder中的残差边,于是知道到在MSA(Multi Self-Attention)处有两个Dimension Pruning,MLP处有两个Dimension Pruning,那么这些分别是啥呢,这时候可以参考一下VIT的实现:VIT
-
先看MLP部分:
-
class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout = 0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x)
-
会发现刚好有两个nn.Linear,对应图中MLP部分的两个Linear(就是FC),然后我么就可以合理推测,所谓Dimension Pruning其实就是减少Linear的输入数量,即减少FC的输入参数
-
然后回过头来看MSA部分:
-
class Attention(nn.Module): def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim = -1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) out = torch.matmul(attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out)
-
在init部分中也刚好有两个nn.Linear,这也印证了上文的猜想
-
那么怎么剪?
-
从这张图其实就很好看出来
- 对于FC的输入,作者做了一个Gate,对于那些小于阈值的节点,直接设为0,接下来是怎么得到这些分数,这是剪枝中的核心问题——如何判定某个节点的重要程度。
-
作者定义如下:
- 设分数 α ∗ ∈ { 0 , 1 } d \alpha^* \in \{0,1\}^d α∗∈{0,1}d
- 一个节点是否剪枝可以表示为: X ∗ = X d i a g ( α ∗ ) X^*=Xdiag(\alpha^*) X∗=Xdiag(α∗)
-
但是这样的离散值无法优化,所以作者将分数松弛了一下,变成了一个连续的数,这样就可以随着梯度下降进行优化:
- α ^ ∈ R d \hat \alpha \in R^d α^∈Rd
- 于是X就可以变成: X ^ = X d i a g ( α ^ ) \hat X = Xdiag(\hat \alpha) X^=Xdiag(α^)
- 然后作者设置一个阈值: α ∗ = α ^ ≥ ζ \alpha^* = \hat \alpha \geq \zeta α∗=α^≥ζ,这样就可以实现评分然后根据评分来剪枝的效果
-
于是剪枝公式可以表示为
-
X ∗ = P r u n e ( X ) X^* = Prune(X) X∗=Prune(X)
-
结合上文说到多处剪枝,Transformer的公式可以变为
- Q , K , V = F C q ′ ( P r u n e ( X ) ) , F C k ′ ( P r u n e ( X ) ) , F C v ′ ( P r u n e ( X ) ) Q,K,V = FC^{'}_q(Prune(X)),FC^{'}_k(Prune(X)),FC^{'}_v(Prune(X)) Q,K,V=FCq′(Prune(X)),FCk′(Prune(X)),FCv′(Prune(X))
- Y = X + F C o u t ′ ( P r u n e ( A t t e n t i o n ( F C q ( X ) , F C v ( X ) , F C v ( X ) ) ) Y = X + FC^{'}_{out}(Prune(Attention(FC_q(X),FC_v(X),FC_v(X))) Y=X+FCout′(Prune(Attention(FCq(X),FCv(X),FCv(X)))
- Z = Y + F C 2 ′ ( P r u n e ( F C 1 ′ ( P r u n e ( Y ) ) ) ) Z = Y + FC^{'}_2({Prune(FC^{'}_1(Prune(Y)))}) Z=Y+FC2′(Prune(FC1′(Prune(Y))))
-
同时公式也侧面说明了剪枝剪哪里
实验部分
- 这里就不细说了,总之就是效果还行
作者的总结
- 作者在文章结尾预测了MSA的M(即head的数量)也可以剪,这和我的预期是相符的,不过目前暂时没有看到这方面的工作,有看到的可以踢我一下。
思考
- 对于这篇我认为实质上是将FC的剪枝方法用到Transformer中,不可否认效果还成,未来还能剪哪里呢?
- 关于这点:
- 首先head的数量确实可以
- Patch Slimming for Efficient Vision Transformers这篇中提出剪枝Patch想法很新奇
- 然后我觉得可以参考EfficientNet的思想(多个维度考虑模型大小,这里参考对网络深度的思考)是否可以讨论剪掉整个Endoer,或者只剪掉MSA和MLP保留其中一部分这样?