文章目录
前言
论文名:Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context
论文作者:Zihang Dai et al.
机构:
Carnegie Mellon University(CMU)
Google Brain
期刊/会议名:NAACL 2018
本文作者:XMU_MIAO
日期:2021/1/18
摘要
Transformer有能够学习长依赖的能力,但受限于语言模型中定长上下文设置。我们提出一种新的神经架构
T
r
a
n
s
f
o
r
m
e
r
−
X
L
Transformer{-}XL
Transformer−XL,其能够克服这种定长上下文的限制并且不破坏时间连贯性。 其包含一种
s
e
g
m
e
n
t
−
l
e
v
e
l
segment{-}level
segment−level循环机制和一种新的位置编码方案。本文提出的方法不仅能够捕捉长距离依赖,而且能够解决
c
o
n
t
e
x
t
f
r
a
g
m
e
n
t
i
o
n
context\,\,fragmention
contextfragmention问题。
T
r
a
n
s
f
o
r
m
e
r
−
X
L
Transformer{-}XL
Transformer−XL能够学习的依赖长度比RNNs长80%,比
v
a
n
i
l
l
a
T
r
a
n
s
f
o
r
m
e
r
vanilla\,\,Transformer
vanillaTransformer长450%,在短文本和长文本上都取得了更好性能,并且推理性能比
v
a
n
i
l
l
a
T
r
a
n
s
f
o
r
m
e
r
vanilla\,\,Transformer
vanillaTransformer快1800倍以上。
值的注意的是,我们将在enwiki上的最先进的bpc/perplexity性能提升到了0.99,在text-8上的性能提升到了1.08,在WikiText-103上的性能提升到了18.3,在One Billion Word上的性能提升到了21.8,在Penn Treebank上的性能提升到了54.5(未微调)。 只在WikiText-103上训练时,
T
r
a
n
s
f
o
r
m
e
r
−
X
L
Transformer{-}XL
Transformer−XL能够生成具有数千个tokens的连贯且新颖的文本文章。
1、Introduction & Motivation
RNN-based模型在进行字符级别的语言建模时,大概能够捕捉200个tokens的依赖。vanilla transformer通过在损失函数设计额外的损失项训练深度transformer网络用于字符级别的语言建模,其大幅度优于RNN-based模型。但vanilla transformer存在
c
o
n
t
e
x
t
f
r
a
g
m
e
n
t
a
t
i
o
n
context \,\,fragmentation
contextfragmentation的问题。(vanilla transformer论文讲解参考)另外,vanilla transformer语言模型以划分的文本片段作为输入进行训练,在推理时候则是step-by-step形式,示意图如下:
c
o
n
t
e
x
t
f
r
a
g
m
e
n
t
a
t
i
o
n
context \,\,fragmentation
contextfragmentation:由于vanilla transformer在训练时将长文本划分成包含数百个字符的定长片段并且片段之间没有信息交互,这就造成了transformer在捕捉长距离依赖时受限于片段的长度。除此之外,定长的文本片段在划分时没有尊重文本的语义。
总的来说,transformer-XL解决vanilla transformer在语言建模时候的以下问题:
- 文本片段之间无信息交互造成捕捉长距离依赖的能力不足
- 模型step-by-step的推理模型不够高效
2、How to do ?
本文提出的transformer-XL通过增加片段之间的信息交互来提升模型语言建模的能力。 transformer-XL主要增加了片段级别的循环机制(Segment-Level Recurrence Mechanism)与相对位置编码机制(Relative Positional Encoding),二者的配合不仅增加了模型的语言建模能力更是大大提升了推理的速度,
t
r
a
n
s
f
o
r
m
e
r
−
X
L
transformer{-}XL
transformer−XL训练和推理时示意图如下:
2.1 Segment-Level Recurrence Mechanism
vanilla transformer在进行字符级别的语言模型时,将长文本划分成不进行信息交互的文本片段,并且划分时仅按照固定长度进行划分,没有尊重语义划分时的语义边界。本文提出片段级别的循环机制,增加文本片段之间的信息交互,以捕捉更长距离的依赖。片段级别的循环机制会保存之前文本片段的隐藏层状态并在处理当前文本片段时使用。
h
τ
+
1
n
=
T
r
a
n
s
f
o
r
m
e
r
−
L
a
y
e
r
(
q
τ
+
1
n
,
k
τ
+
1
n
,
v
τ
+
1
n
)
h_{\tau+1}^n=Transformer{-}Layer(q^{n}_{\tau+1},k_{\tau+1}^n,v_{\tau+1}^n)
hτ+1n=Transformer−Layer(qτ+1n,kτ+1n,vτ+1n)
q
τ
+
1
n
,
k
τ
+
1
n
,
v
τ
+
1
n
=
h
τ
+
1
n
−
1
W
q
T
,
h
~
τ
+
1
n
−
1
W
k
T
,
h
~
τ
+
1
n
−
1
W
v
T
q_{\tau+1}^n,k_{\tau+1}^{n},v_{\tau+1}^{n}=h^{n-1}_{\tau+1}W_q^{T},\widetilde{h}^{n-1}_{\tau+1}W_k^T,\widetilde{h}^{n-1}_{\tau+1}W_v^T
qτ+1n,kτ+1n,vτ+1n=hτ+1n−1WqT,h
τ+1n−1WkT,h
τ+1n−1WvT
h
~
τ
+
1
n
−
1
=
[
S
G
(
h
τ
n
−
1
)
∘
h
τ
+
1
n
−
1
]
\widetilde{h}^{n-1}_{\tau+1}=[SG(h^{n-1}_{\tau})\circ h^{n-1}_{\tau+1}]
h
τ+1n−1=[SG(hτn−1)∘hτ+1n−1] 其中
S
G
(
∗
)
SG(*)
SG(∗)表示停止计算梯度,
∘
\circ
∘表示按照长度的方向进行拼接。
2.2 Relative Positional Encoding
为了区分片段之间的位置信息,需要对位置编码做一些修改,本文提出相对位置编码克服这一问题。位置编码信息主要在计算
a
t
t
n
s
c
o
r
e
s
attn\,\,scores
attnscores时用到,计算
a
t
t
n
s
c
o
r
e
s
attn\,\,scores
attnscores时的主要计算式为:
A
i
,
j
a
b
s
=
(
W
q
(
E
x
i
+
U
i
)
)
T
W
k
(
E
x
j
+
U
j
)
=
(a)
E
x
i
T
W
q
T
W
k
E
x
j
+
(b)
E
x
i
T
W
q
T
W
k
U
j
+
(c)
U
i
T
W
q
T
W
k
E
x
j
+
(d)
U
i
T
W
q
T
W
k
U
j
A_{i,j}^{abs}=(W_q(E_{x_i}+U_i))^TW_k(E_{x_j}+U_j)\\=\textbf{(a)}E_{x_i}^TW_q^TW_kE_{x_j}+\textbf{(b)}E_{x_i}^TW_q^TW_kU_j+\textbf{(c)}U_i^TW_q^TW_kE_{x_j}+\textbf{(d)}U_i^TW_q^TW_kU_j
Ai,jabs=(Wq(Exi+Ui))TWk(Exj+Uj)=(a)ExiTWqTWkExj+(b)ExiTWqTWkUj+(c)UiTWqTWkExj+(d)UiTWqTWkUj
修改后的相对位置编码为:
A
i
,
j
r
e
l
=
(a)
E
x
i
T
W
q
T
W
k
E
x
j
+
(b)
E
x
i
T
W
q
T
W
k
,
R
R
i
−
j
+
(c)
u
T
W
k
,
E
E
x
j
+
(d)
v
T
W
k
,
R
R
i
−
j
A_{i,j}^{rel}=\textbf{(a)}E_{x_i}^TW_q^TW_kE_{x_j}+\textbf{(b)}E_{x_i}^TW_q^TW_{k,R}R_{i-j}+\textbf{(c)}u^TW_{k,E}E_{x_j}+\textbf{(d)}v^TW_{k,R}R_{i-j}
Ai,jrel=(a)ExiTWqTWkExj+(b)ExiTWqTWk,RRi−j+(c)uTWk,EExj+(d)vTWk,RRi−j
相对于原始transformer中的计算方式,本文提出方法在计算式上主要做了以下几点修改:
- 将(b)和(d)中的 U j U_j Uj变成相对位置编码 R i − j R_{i-j} Ri−j,作者假设这是一种先验,即只有相对位置对注意力机制有影响。 R i − j R_{i-j} Ri−j是original transformer中的固定三角函数编码
- 由于每个查询向量对所有的查询位置是一样的,因而引入可训练的变量 u u u和 v v v代替(c)和(d)中的 U i T W q T U_i^TW_q^T UiTWqT
- 分别设置基于内容的权重矩阵 W k , E W_{k,E} Wk,E和基于位置的权重矩阵 W k , R W_{k,R} Wk,R
在这种修改下,每一项都赋予了一定的可解释意义:(a)基于内容的寻址;(b)与内容相关的位置偏差;(c)全局内容偏差;(d)全局位置偏差;
3、Experiments Analysis(main)
实验结果为在多个字符级别和单词级别的数据集上取得了SOTA的效果,包括WikiText-103、enwik8、text8、One Billion Word和Penn Treebank等数据集。
总结
t r a n s f o r m e r − X L transformer{-}XL transformer−XL改进了 v a n i l l a t r a n s f o r m e r vanilla\,\,transformer vanillatransformer作为语言模型编码时将长文本划分成固定文本片段时造成的文本段之间无信息交互的问题。在 v a n i l l a t r a n s f o r m e r vanilla \,\,transformer vanillatransformer的基础上增加了片段级别的循环机制( S e g m e n t − L e v e l R e c u r r e n c e M e c h a n i s m Segment{-}Level\,\,Recurrence\,\,Mechanism Segment−LevelRecurrenceMechanism)和相对位置编码( R e l a t i v e P o s i t i o n a l E n c o d i n g Relative\,\,Positional\,\,Encoding RelativePositionalEncoding),使得文本片段之间的信息交互增加,并且能够提高模型推理的速度。