Vision Transformer Pruning简记

Vision Transformer Pruning简记

文章目录

参考

剪枝流程

  • 稀疏正则训练
  • 剪枝,减去不重要的部分
  • finetune微调

剪什么?

  • 有关于稀疏训练虽然重要,但是首要还是确定剪什么,在Vision Transformer Pruning中作者剪枝的是Dimension,那么什么是Dimension呢?
    • 我的理解是:Dimension的长度其实就是FC层的输入长度
    • Vision Transformer Pruning简记

怎么剪?

  • 首先回顾一下Transformer
回顾Transformer
  • 我的理解是:以Vision Transformer为例(当然作者用的好像是DeiT-base),我们首先将图片分成16x16个patch,然后进行编码embedding,会出来1xd的向量,对每个patch都这么做然后concat一下就得到了Attention部分的输入, X ∈ R n × d X \in R^{n×d} X∈Rn×d,这里的n就是我们分出来的patch数量

  • 上文提到了输入 X ∈ R n × d X \in R^{n×d} X∈Rn×d,然后一般我的理解是各✖️ W k 、 W q 、 W v W^k、 W^q 、W^v Wk、Wq、Wv权重矩阵得到KQV,然后进入Attention计算(李宏毅老师的说法),然而作者得到KQV的方法好像是FC,这点我查证了一下VIT好像在to_kqv的时候是这么做的,这是没看VIT论文的锅(确信)

  • A t t e n t i o n ( Q K V ) = S o f t m a x ( Q K T d ) V Attention(QKV) = Softmax(\frac{QK^T}{\sqrt d})V Attention(QKV)=Softmax(d ​QKT​)V

  • 对输出进行进行处理:Layer Norm+Residual

  • Y = X + F C o u t ( A t t e n t i o n ( F C q ( X ) , F C v ( X ) , F C v ( X ) ) Y = X + FC_{out}(Attention(FC_q(X),FC_v(X),FC_v(X)) Y=X+FCout​(Attention(FCq​(X),FCv​(X),FCv​(X))

  • 接下来就是MLP

  • Z = Y + F C 2 ( F C 1 ( Y ) ) Z = Y + FC_2(FC_1(Y)) Z=Y+FC2​(FC1​(Y))

那么剪哪里?
  • 作者给出了一张图,其中右侧是一个Transformer模块,会发现有多个Dimension Pruning,这里我放一张Transformer的encoder做对比就清楚了
  • Vision Transformer Pruning简记

Vision Transformer Pruning简记

  • 我么可以知道两个跳跃链接分别对应encoder中的残差边,于是知道到在MSA(Multi Self-Attention)处有两个Dimension Pruning,MLP处有两个Dimension Pruning,那么这些分别是啥呢,这时候可以参考一下VIT的实现:VIT

    • 先看MLP部分:

    • class FeedForward(nn.Module):
          def __init__(self, dim, hidden_dim, dropout = 0.):
              super().__init__()
              self.net = nn.Sequential(
                  nn.Linear(dim, hidden_dim),
                  nn.GELU(),
                  nn.Dropout(dropout),
                  nn.Linear(hidden_dim, dim),
                  nn.Dropout(dropout)
              )
          def forward(self, x):
              return self.net(x)
      
    • 会发现刚好有两个nn.Linear,对应图中MLP部分的两个Linear(就是FC),然后我么就可以合理推测,所谓Dimension Pruning其实就是减少Linear的输入数量,即减少FC的输入参数

    • 然后回过头来看MSA部分:

    • class Attention(nn.Module):
          def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
              super().__init__()
              inner_dim = dim_head *  heads
              project_out = not (heads == 1 and dim_head == dim)
      
              self.heads = heads
              self.scale = dim_head ** -0.5
      
              self.attend = nn.Softmax(dim = -1)
              self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
      
              self.to_out = nn.Sequential(
                  nn.Linear(inner_dim, dim),
                  nn.Dropout(dropout)
              ) if project_out else nn.Identity()
      
          def forward(self, x):
              qkv = self.to_qkv(x).chunk(3, dim = -1)
              q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
      
              dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
      
              attn = self.attend(dots)
      
              out = torch.matmul(attn, v)
              out = rearrange(out, 'b h n d -> b n (h d)')
              return self.to_out(out)
      
    • 在init部分中也刚好有两个nn.Linear,这也印证了上文的猜想

那么怎么剪?
  • 从这张图其实就很好看出来

    • Vision Transformer Pruning简记
    • 对于FC的输入,作者做了一个Gate,对于那些小于阈值的节点,直接设为0,接下来是怎么得到这些分数,这是剪枝中的核心问题——如何判定某个节点的重要程度。
  • 作者定义如下:

    • 设分数 α ∗ ∈ { 0 , 1 } d \alpha^* \in \{0,1\}^d α∗∈{0,1}d
    • 一个节点是否剪枝可以表示为: X ∗ = X d i a g ( α ∗ ) X^*=Xdiag(\alpha^*) X∗=Xdiag(α∗)
  • 但是这样的离散值无法优化,所以作者将分数松弛了一下,变成了一个连续的数,这样就可以随着梯度下降进行优化:

    • α ^ ∈ R d \hat \alpha \in R^d α^∈Rd
    • 于是X就可以变成: X ^ = X d i a g ( α ^ ) \hat X = Xdiag(\hat \alpha) X^=Xdiag(α^)
    • 然后作者设置一个阈值: α ∗ = α ^ ≥ ζ \alpha^* = \hat \alpha \geq \zeta α∗=α^≥ζ,这样就可以实现评分然后根据评分来剪枝的效果
  • 于是剪枝公式可以表示为

  • X ∗ = P r u n e ( X ) X^* = Prune(X) X∗=Prune(X)

  • 结合上文说到多处剪枝,Transformer的公式可以变为

    • Q , K , V = F C q ′ ( P r u n e ( X ) ) , F C k ′ ( P r u n e ( X ) ) , F C v ′ ( P r u n e ( X ) ) Q,K,V = FC^{'}_q(Prune(X)),FC^{'}_k(Prune(X)),FC^{'}_v(Prune(X)) Q,K,V=FCq′​(Prune(X)),FCk′​(Prune(X)),FCv′​(Prune(X))
    • Y = X + F C o u t ′ ( P r u n e ( A t t e n t i o n ( F C q ( X ) , F C v ( X ) , F C v ( X ) ) ) Y = X + FC^{'}_{out}(Prune(Attention(FC_q(X),FC_v(X),FC_v(X))) Y=X+FCout′​(Prune(Attention(FCq​(X),FCv​(X),FCv​(X)))
    • Z = Y + F C 2 ′ ( P r u n e ( F C 1 ′ ( P r u n e ( Y ) ) ) ) Z = Y + FC^{'}_2({Prune(FC^{'}_1(Prune(Y)))}) Z=Y+FC2′​(Prune(FC1′​(Prune(Y))))
  • 同时公式也侧面说明了剪枝剪哪里

实验部分

  • 这里就不细说了,总之就是效果还行
  • Vision Transformer Pruning简记

作者的总结

  • 作者在文章结尾预测了MSA的M(即head的数量)也可以剪,这和我的预期是相符的,不过目前暂时没有看到这方面的工作,有看到的可以踢我一下。

思考

  • 对于这篇我认为实质上是将FC的剪枝方法用到Transformer中,不可否认效果还成,未来还能剪哪里呢?
  • 关于这点:
    • 首先head的数量确实可以
    • Patch Slimming for Efficient Vision Transformers这篇中提出剪枝Patch想法很新奇
    • 然后我觉得可以参考EfficientNet的思想(多个维度考虑模型大小,这里参考对网络深度的思考)是否可以讨论剪掉整个Endoer,或者只剪掉MSA和MLP保留其中一部分这样?
上一篇:腾讯五十题 No.35 相交链表


下一篇:1016 部分A+B