摘要
从头开始训练深层 transformers需要大型数据集是一个普遍观点。因此,对于小型数据集,人们通常在微调期间,在预训练模型上使用较浅和简单的额外层。本项工作表明,这种情况并不是常见的:只需通过正确的初始化和优化,非常深的transformers的优势就可以转移到具有小型数据集的小型任务,包括Text-to-SQL语义解析和阅读理解。特别是,我们成功训练了48层的transformers,包括来自预训练RoBERT的24层网络和需从头开始训练的24层网络。通过较少的训练步骤,无需特定于任务的预训练,我们就可以在具有挑战性的跨领域的Text-to-SQL解析benchmark Spider上达到最好的性能。受prior T-Fixup工作的启发,我们通过Data-dependent Transformer Fixed-update初始化技术(DT-Fixup)来达到该性能。进一步的误差分析表明,增加网络深度可以帮助改善小型数据集中需要依赖推理和结构化理解的难例的泛化能力。
1.介绍
近年来,使用transformers训练的大型预训练语言模型已成为现代NLP系统的标准构建块,以帮助改善在特定任务标注的数据集限制上模型的泛化能力。在实践中,已经发现,更深的transformers通常会产生更好的数据训练结果,特别是涉及到关于推理和结构化理解的任务。这表明额外的transformers层应与预训练的模型一起使用,而不是仅添加较浅和简单的神经网络层,例如分类器头,其目前被用在各种NLP任务的模型中。然而,文献中的共同观点是,从头开始训练深层transformers需要大量数据集,并且据我们所知,研究人员在小数据集中仅进行了很少的尝试。一种观点是,虽然原则上在预训练模型上使用额外transformers层应该有助于具有挑战性的问题,但由于训练数据有限,它在实践中起不到多大作用。我们展示通过使用本工作中所提出的方法,能够解决若干优化问题,即使在小型数据集上也可以训练非常深的transformers以改进泛化能力。
预训练模型的一个优点是在小型数据集上微调时所需的计算资源较少。例如,它允许使用者在单个GPU上进行finetune,并在下游任务上获得强大的性能。然而,大尺寸的预训练模型限制了可以用于训练新transformer层的batch size。尽管应用泛,但训练transformer模型是很困难的。标准transformer训练方法中,通过利用学习率warm-up,层归一化和大batch size,并且模型通常无法在缺少这些组件中的任何一个情况下还能很好地学习。受限制的batch size加剧了训练的困难程度。甚至使用大的batch size,通常会观察到较差的泛化结果,特别是当数据集大小仅为比batch size的几倍时。此外,许多最近的工作注意到这种训练方法中会由于使用层归一化而造成性能差距。
受Huang et al. (2020) 最近 T-Fixup 工作的启发,该工作消除了训练vanilla transformers需要学习速率warm-up和层归一化的需求,我们通过应用不同的分析来解决T-Fixup中的几个关键限制来得到data-dependent的初始化策略。我们将提出的方法称为Data-dependent Transformer Fixed-update初始化方案,即DT-Fixup。在预训练模型顶部使用额外的尚未训练的transformers层,DT-Fixup能够训练明显更深的transformers,并且也适用于不同的神经结构架构。我们的结果也从Vanilla transformers扩展到具有关系编码的transformers,这允许我们将结果应用于一个名为关系感知transformers的一个变体。通过在不同的任务上应用DT-Fixp,我们表明深度transformers不适用于小型数据集,是因为优化过程,而不是网络结果。通过适当的初始化和优化,训练额外的transformers层能促进数据中复杂关系和结构的学习。
我们在Spider和ReColr benchmark上验证了DT-Fixup的有效性。虽然Text-to-SQL语义解析本质上与阅读理解有所不同,但它们共享类似的特征,即这些数据都需要某些推理和结构化理解能力。同时,两个数据集的尺寸均小于10K的训练样本。
在两个数据集中,DT-Fixup始终优于使用更好泛化能力的标准方法,并允许训练更深的transformer模型。对于Spider,我们成功地应用DT-Fixup,以训练包含48个transformer层的Text-to-SQL解析器,其中前24层使用的是RoBERT预训练模型,后24层则是relation-aware层。我们的解析器在Spider测试集上实现了70.9%的精确匹配准确性,这是目前最好的结果。与现有技术相比,它仅需要较少的训练步骤,并且没有特定于任务的预训练。对于ReClor,我们通过简单地在Roberta上添加4个transformer层,我们在排行榜上排名第二。进一步的误差分析表明,通过增加网络深度导致的性能改善主要来自对需要推理和结构化理解数据的更好的泛化能力。甚至使用深层模型导致的失败预测也比浅层模型更合理。
2.背景
在本节中,为了提出必要的背景知识,我们首先介绍关系感知transformer层,其引入了额外的归纳偏压,以优于具有有限数据的vanilla transformer层。然后,我们介绍了用于优化更深的vanilla transformers的T-Fixup技术,并讨论为什么它不能直接适用于混合transformer的优化设置。
2.1 Relative Position and Relational Encodings in Transformers
考虑一组输入
X
=
[
x
1
,
.
.
.
,
x
n
]
X=[\textbf x_1,...,\textbf x_n]
X=[x1,...,xn],其中
x
i
∈
R
d
x
\textbf xi∈\mathbb R^{d_x}
xi∈Rdx。由Vaswani et al. (2017) 提出的transformer,是多个block组成的堆栈,每个block包括多头自注意力层,层归一化,多层感知器和跳过连接。每个block(为了简单描述,self-attention中只使用单个头)将每个
x
i
\textbf x_i
xi转换为
y
i
∈
R
d
x
\textbf y_i∈\mathbb R^{d_x}
yi∈Rdx,如下所示:
α
i
j
=
s
o
f
t
m
a
x
(
x
i
q
(
x
j
j
)
/
d
z
)
(1)
\alpha_{ij}=softmax(\textbf x_i\textbf q(\textbf x_j\textbf j)/\sqrt{d_z})\tag{1}
αij=softmax(xiq(xjj)/dz
)(1)
z
i
=
∑
j
=
1
n
α
i
j
x
j
v
;
(2)
z_i=\sum^n_{j=1}\alpha_{ij}\textbf x_j\textbf v;\tag{2}
zi=j=1∑nαijxjv;(2)
y
~
i
=
L
a
y
e
r
N
o
r
m
(
x
i
+
z
i
w
T
)
(3)
\tilde {\textbf y}_i=LayerNorm(\textbf x_i+\textbf z_i\textbf w^T)\tag{3}
y~i=LayerNorm(xi+ziwT)(3)
y
~
i
=
L
a
y
e
r
N
o
r
m
(
y
~
+
M
L
P
(
y
~
i
)
)
(4)
\tilde{\textbf y}_i=LayerNorm(\tilde{\textbf y}+MLP(\tilde{\textbf y}_i))\tag{4}
y~i=LayerNorm(y~+MLP(y~i))(4)
其中,在索引
j
j
j上应用softmax操作,MLP是双层的感知器, LayerNorm是层归一化层,并且
q
,
k
,
v
∈
R
d
x
×
d
z
,
w
∈
R
d
x
×
d
z
\textbf q,\textbf k,\textbf v∈\mathbb R^{d_x×d_z},\textbf w∈\mathbb R^{d_x×d_z}
q,k,v∈Rdx×dz,w∈Rdx×dz。
为了使transformer偏向输入之间的一些预先存在的关系特征,Shaw et al. (2018) 提出了通过改变等式1-2来表示自注意力层中相对位置信息的方法,如下所示:
α
i
j
=
s
o
f
t
m
a
x
(
x
i
q
(
x
j
k
+
r
i
j
k
)
d
z
)
z
i
=
∑
j
=
1
n
α
i
j
(
x
j
v
+
r
i
j
v
)
(5)
\begin{array}{cl} \alpha_{ij}=softmax\bigg(\frac{\textbf x_i\textbf q(\textbf x_j\textbf k+\textbf r^k_{ij})}{\sqrt{d_z}}\bigg)\\ \textbf z_i=\sum^n_{j=1}\alpha_{ij}(\textbf x_j\textbf v+\textbf r^{v}_{ij}) \end{array}\tag{5}
αij=softmax(dz
xiq(xjk+rijk))zi=∑j=1nαij(xjv+rijv)(5)
其中,
r
i
j
∈
R
d
z
\textbf r_{ij}∈\mathbb R^{d_z}
rij∈Rdz项在输入中编码两个元素
x
i
x_i
xi和
x
j
x_j
xj之间的已知关系。Wang et al. (2019a) 修改了此框架,以有效地使用
r
i
j
r_{ij}
rij为Text-to-SQL解析器编码模式信息,并称为relation-aware transformer (RAT)。
2.2 T-Fixup and its Limitations
Huang et al. (2020) 发现,在transformers的早期训练期间对warm-up的需求主要来自于Adam优化器中的高方差和通过层标归一化反向传播的综合影响。限制梯度更新将减少方差并进行稳定训练,这可以通过适当地初始化模型权重来实现。
他们为 vanilla transformer 推导得到一个被称为T-Fixup的权重初始化方案,该方案完全消除了对层归一化和学习率warm-up的需求,从而稳定训练以避免泛化性能差的有害影响。T-Fixup 要求输入
x
\textbf x
x是由方差为
d
−
1
2
d^{-\frac{1}{2}}
d−21的高斯随机初始化的嵌入,其中
d
d
d是嵌入维度。然后,输入和编码器参数,即由等式1-4中定义的vanilla自注意力块中的
x
,
v
,
w
\textbf x,\textbf v,\textbf w
x,v,w,以及MLP块中的权重矩阵,都通过乘以
0.67
N
−
1
4
0.67N^{-\frac{1}{4}}
0.67N−41以重新缩放,其中
N
N
N是transformer层的数量。
但是,T-Fixup有两种限制从而缩小了其应用范围。首先,T-Fixup仅为Vanilla transformer设计,而无法用于其他变体,如先前描述的相对位置或关系感知版本。其次,他们做出了一个重要假设,即输入
x
\textbf x
x可以*初始化,然后将其缩放到与
v
,
w
\textbf v,\textbf w
v,w和MLP权重相同的量级。 这使得该方法不可用于混合设置,其中未训练的transformer层的输入取决于预训练模型的输出。第一个问题可以通过重新推导T-Fixup方法的缩放因子来解决,并考虑额外的关系项。然而,为了解决第二个限制需要改变对分析的假设和更剧烈的修改。
3.方法
我们现在遵循T-Fixup的分析框架,并在现有预训练的模型基础上进行推导以限制自注意力块的梯度更新。基于推导,我们提出了一种数据相关的初始化策略,用于在预训练的编码上添加新transformers层的混合设置。
3.1 Applicable Architectures
我们的分析适用于图1中所示的一般架构类型,其中输入通过 pre-transformer,main transformer和post-transformer模块,然后得到输出。pre和post Transformer模块可以是任何可用Adam稳定训练的架构,包括MLP,LSTM,CNN或预训练的深层transformer模块,其可以稳定地fine-tuned,其学习率明显小于于main transformer模块的学习率。对于这项工作,为简单起见,我们将考虑编码器仅包含main transformer的情况,而我们的解码器是LSTM,可以将其视为post-transformer模块的一部分。与Huang et al. (2020) 的框架类似,我们将分析扩展到包括深层transformer解码器的情况。
我们使用
f
e
f_e
fe来表示 pre-transformer 模块(
e
e
e表示预训练的编码器),其参数表示为
θ
e
\textbf θ_e
θe;类似地,具有参数
θ
o
\textbf θ_o
θo的
f
o
f_o
fo表示 post-transformer 模块(
o
o
o表示输出)。main-transformer模块
f
G
f_G
fG是
L
L
L个transformer block组成的堆栈,每个block由自注意力块和MLP块组成。令
G
l
,
l
=
1
,
.
.
.
,
2
N
G_l,l=1,...,2N
Gl,l=1,...,2N表示块中独立的自注意力或MLP层(
G
l
G_l
Gl不包括跳过连接),其参数为
θ
l
\textbf θ_l
θl并令
L
=
2
N
L=2N
L=2N,
f
G
f_G
fG的参数由
θ
G
=
⋃
l
=
1
L
θ
l
\textbf θ_G=\bigcup^L_{l=1}\textbf θ_l
θG=⋃l=1Lθl表示。
3.2 Theoretical Results for Stable Update
略,理论推导可参考论文
3.3 Proposed Method: DT-Fixup
与先前的工作不同,适当的初始化是不足以确保等式7和8在训练的早期阶段稳定。这是由于输入 x \textbf x x通常取决于预训练模型的权重,而不是自己初始化的。经验上,我们观察到输入范数 ∣ ∣ x ∣ ∣ ||\textbf x|| ∣∣x∣∣在整个训练中相对稳定,但难以通过重新缩放直接控制。基于此观察,我们将 ∣ ∣ x ∣ ∣ ||\textbf x|| ∣∣x∣∣视为常数并通过前向通过所有训练y样例来估计,令其为 μ = m a x j [ ∣ ∣ x j ∣ ∣ ] μ=max_j [||\textbf x_j||] μ=maxj[∣∣xj∣∣]。然后,我们在定理3.1的因式中使用此估计 μ μ μ以获得初始化所需的缩放。由于所有层的参数都初始化为相同的比例,因此在本节中,我们将索引 l l l丢弃。在实践中,预训练模型 μ μ μ大约为10,因此 v , w v,w v,w和 r i v r^v_i riv自然是小两个数量级。DT-Fixup描述如下:
- 除了从预训练模型中加载的参数以外,在其余所有参数上应用Xavier初始化;
- 除了预训练的transformer,移除其余层中的学习速率warm-up 和层归一化;
- 在所有训练样例上前向传递得到最大输入范式 μ = m a x j [ ∣ ∣ x j ∣ ∣ ] μ=max_j[||\textbf x_j||] μ=maxj[∣∣xj∣∣];
- 在每个transformer层内部,对于关系感知transformer层,将MLP块中的权重矩阵和注意块中的 v , w , r v v,w,r^v v,w,rv缩放 ( N ∗ ( 4 μ 2 + 2 μ + 2 ) ) − 1 2 (N*(4μ^2+2μ+ 2))^{-\frac{1}{2}} (N∗(4μ2+2μ+2))−21;对于vanilla transformer层,将MLP块中的权重矩阵和注意块中的 v , w , r v v,w,r^v v,w,rv缩放 N − 1 2 / ( 2 μ ) N^{-\frac{1}{2}}/(2μ) N−21/(2μ)。