Transformer变体(Routing Transformer,Linformer,Big Bird)

本篇博文继续前两篇文章进行整理,前两篇文章传送门:

Transformer变体(Routing Transformer,Linformer,Big Bird)
Efficient Content-Based Sparse Attention with RoutingTransformers
和前两篇博文的目标一样,如何使标准Transformer的时间复杂度降低。Routing Transformer将该问题建模为一个路由问题,目的是让模型学会选择词例的稀疏聚类,所谓的聚类簇是关于每个键和查询的内容的函数,而不仅仅与它们的绝对或相对位置相关。简单来说就是,作用差不多的词是可以变成聚类到一个表示的,这样来加速计算。

如上图与其他模型的对比,图中的每一行代表输出,每一列代表输入,对于a和b图来说,着色的方块代表每一个输出行注意到的元素。对于路由注意力机制来说,不同的颜色代表输出词例的聚类中的成员。具体做法是先用一种公共的随机权重矩阵对键和查询的值进行投影: R = [ Q , K ] [ W R , W R ] T R=[Q,K][W_R,W_R]^T R=[Q,K][WR​,WR​]T然后把R中的向量用k-means聚类成k个簇,然后在每个簇C_i 中加权求和上下文得到嵌入: X i ′ = ∑ j ∈ C k A i j V j X'_i=\sum_{j \in C_k} A_{ij}V_j Xi′​=j∈Ck​∑​Aij​Vj​

最后作者使用了 n \sqrt{n} n ​个簇,所以时间复杂度降维 O ( n n ) O(n \sqrt{n}) O(nn ​)。详细可以看原论文和代码实现:

  • paper:https://arxiv.org/abs/2003.05997
  • code:https://storage.googleapis.com/routing_transformers_iclr20/

Transformer变体(Routing Transformer,Linformer,Big Bird)
Linformer: Self-Attention with Linear Complexity
从O(n^2)到O(n)!首先作者从理论和经验上证明了自注意机制所形成的随机矩阵可以近似为低秩矩阵,所以直接多引入线性投影将原始的缩放点积关注分解为多个较小的关注,也就是说这些小关注组合是标准注意力的低秩因数分解。即如上图,在计算键K和值V时添加两个线性投影矩阵E和F,即 h e a d = A t t e n t i o n ( Q W i Q , E i K W i K , F i V W i V ) = s o f t m a x ( Q W i Q ( E i K W i K ) T d k ⋅ F i V W i V ) head=Attention(QW^Q_i,E_iKW^K_i,F_iVW^V_i)=softmax(\frac{QW^Q_i(E_iKW^K_i)^T}{\sqrt{d_k}}\cdot F_iVW^V_i) head=Attention(QWiQ​,Ei​KWiK​,Fi​VWiV​)=softmax(dk​ ​QWiQ​(Ei​KWiK​)T​⋅Fi​VWiV​)

同时还提供三种层级的参数共享:

  • Headwise: 所有注意力头共享投影句子参数,即Ei=E,Fi=F。
  • Key-Value: 所有的注意力头的键值映射矩阵共享参数同一参数 ,即Ei=Fi=E。
  • Layerwise: 所有层参数都共享。即对于所有层,都共享投射矩阵 E。

完整内容可以看原文,原文有理论证明低秩和分析:

  • paper:https://arxiv.org/abs/2006.04768
  • code:https://github.com/tatp22/linformer-pytorch

Transformer变体(Routing Transformer,Linformer,Big Bird)
Big Bird: Transformers for Longer Sequences
也是采用稀疏注意力机制,将复杂度下降到线性,即 O(N)。如上图,big bird主要包括三个部分的注意力:

  • Random Attention (随机注意力)。如图a,对于每一个 token i,随机选择 r 个 token 计算注意力。
  • Window Attention (局部注意力)。如图b,用滑动窗口表示注意力计算token的局部信息。
  • Global Attention (全局注意力)。如图c,计算全局信息。这些在Longformer中也讲过,可以参考对应论文。

最后把这三部分注意力结合在一起得到注意力矩阵 A,如图 d就是BIGBIRD的结果了,计算公式为: A T T N ( X ) i = x i + ∑ h = 1 H σ ( Q h ( x i ) K h ( X N ( i ) ) T ) ⋅ V h ( X N ( i ) ) ATTN(X)_i=x_i+\sum^H_{h=1} \sigma(Q_h(x_i)K_h(X_{N(i)})^T)\cdot V_h(X_{N(i)}) ATTN(X)i​=xi​+h=1∑H​σ(Qh​(xi​)Kh​(XN(i)​)T)⋅Vh​(XN(i)​)H是头数,N(i)是所有需要计算的token,这里就是由三部分得来的稀疏部分,QKV则是老伙伴了。

  • paper:https://arxiv.org/abs/2007.14062
上一篇:华为HCIA-Routing & Switching之路-5


下一篇:ES基础(四十九)集群内部安全通信