【笔记】TinyBERT(EMNLP2019)
两阶段蒸馏:预训练阶段+finetune阶段
设计3种损失函数分布来适应bert的不同层级的损失计算
- embedding 层输出
- 来自 transformer 层的隐藏层和注意力矩阵
- 输出 logits 的预测层
1. 知识蒸馏的设计
可以将网络的任何一层称为行为函数(
f
f
f , behavior function),KD就是利用小模型(
S
S
S, student)学习大模型(
T
T
T, teacher)。知识蒸馏的数学表示:
L
K
D
=
∑
x
∈
X
L
(
f
S
(
x
)
,
f
T
(
x
)
)
\mathcal{L}_{KD} = \sum_{x \in \mathcal{X}}L(f^S(x), f^T(x))
LKD=x∈X∑L(fS(x),fT(x))
对于Transformer层的蒸馏任务而言,需要学习的就是1)多头自注意力层(Mulit-head attention)、2)全连接前馈网络(fully feed-forward network)以及3)其他中间表示(例如注意力矩阵)。
因此研究的关键在于如何定义有效的1)行为函数和2)损失函数,包括在预训练和finetune阶段。
2. Methods
2.1 Transformer 蒸馏
设「学生模型」共 M M M 层,「学生模型」 N N N 层。
(一)学生层与教师层的对应关系
公式 n = g ( m ) n=g(m) n=g(m) 表示「学生模型」第 m m m 层映射至「教师模型」第 n n n 层,特别的, 0 = g ( 0 ) 0=g(0) 0=g(0) 表示学生的embedding层映射至教师embedding层; N + 1 = g ( M + 1 ) N+1=g(M+1) N+1=g(M+1) 表示预测层相对应。
此知识蒸馏任务可公式化为
L
model
=
∑
x
∈
X
∑
m
=
0
M
+
1
λ
m
L
layer
(
f
m
S
(
x
)
,
f
g
(
m
)
T
(
x
)
)
\mathcal{L}_\text{model} = \sum_{x \in \mathcal{X}}\sum_{m=0}^{M+1} \lambda_{m} L_\text{layer}(f_m^S(x), f_{g(m)}^T(x))
Lmodel=x∈X∑m=0∑M+1λmLlayer(fmS(x),fg(m)T(x))
(二)具体Transmformer层的蒸馏对象
基于attention、基于hidden states,如Figure 2所示。
- Attention based distillation
原因:BERT的attention层能够捕获丰富的语言学知识,包括句法(syntax)和共现关系(coreference information),这些都是自然语言理解的基础。
更明确地说,就是用学生模型学习教师模型的 attention matrices :
L
attn
=
1
h
∑
i
=
1
h
MSE
(
A
i
S
,
A
i
T
)
\mathcal L_\text{attn} = \frac{1}{h}\sum_{i=1}^{h}\text{MSE}(A_i^S, A_i^T)
Lattn=h1i=1∑hMSE(AiS,AiT)
式(3)是注意力矩阵拟合的损失函数表达式。而且论文之间对矩阵 A i A_i Ai 进行拟合,而不是对 s o f t m a x ( A i ) softmax(A_i) softmax(Ai) 实验表明前者的性能更佳且收敛速度更快。
- Hidden states based distillation
L hidn = M S E ( H S W h , H T ) \mathcal{L}_\text{hidn} = MSE(H^SW_h, H^T) Lhidn=MSE(HSWh,HT)
其中 H S ∈ R l × d ′ H^S\in\mathbb{R}^{l \times d'} HS∈Rl×d′ , H T ∈ R l × d H^T \in \mathbb{R}^{l \times d} HT∈Rl×d 分别是学生模型和教师模型的Transformer FFN的隐藏层参数, W h W_h Wh 是一个可学习的线形层,用来将 S S S 对齐至 T T T 。
- 嵌入层蒸馏
L embd = M S E ( E S W e , E T ) \mathcal{L}_\text{embd} = MSE(E^SW_e, E^T) Lembd=MSE(ESWe,ET)
和隐藏层类似。
- 预测/输出层蒸馏
L pred = C E ( z T / t , z S / t ) \mathcal{L}_\text{pred} = CE(z^T/t, z^S/t) Lpred=CE(zT/t,zS/t)
z S z^S zS 和 z T z^T zT 是学生和教师模型的 logits 向量,并对它进行soft,除以 temperature - t t t 。实验表明, t = 1 t=1 t=1 效果最好(没加更好?)。
因此最后的蒸馏任务损失函数就是以上4个的组合:
L
layer
=
{
L
embd
m
=
0
L
hidn
+
L
attn
M
≥
m
≥
0
L
pred
m
=
M
+
1
\mathcal{L}_\text{layer} = \begin{cases} \mathcal{L}_\text{embd} & m=0\\ \mathcal{L}_\text{hidn} + \mathcal{L}_\text{attn} & M\ge m \ge 0\\ \mathcal{L}_\text{pred} & m=M+1\\ \end{cases}
Llayer=⎩⎪⎨⎪⎧LembdLhidn+LattnLpredm=0M≥m≥0m=M+1
(三)TinyBERT Learning
实验设计了两阶段的蒸馏(学习)任务:通用蒸馏 + 特定任务蒸馏,如图1所示。
- General Distillation
通用蒸馏,即在预训练阶段进行蒸馏,它能帮助「学生模型」学习到丰富的embedding知识,有助于提升模型的泛化能力。
预训练阶段的蒸馏任务损失函数(7)不包含公式中的预测层损失函数 L pred \mathcal{L}_\text{pred} Lpred。 旨在让「学生模型」学习模型的中间结构。并且初步的实验表明,在预训练阶段加入预测层损失函数并不能提升下游任务性能。
- Task-specifific Distillation
研究表明现复杂模型在特定领域的任务中存在 **over-parametrization(过度参数化)**的问题,这会造成模型过拟合,泛化性变差。所以,一些参数量小的模型或许能够达到和原来的BERT差不多的效果。
实验中使用了一个 finetuned的BERT模型 + 数据增强来蒸馏TinyBERT。
3. 总结
TinyBERT设计了不同阶段的损失函数,包括对BERT的Embedding层、预测层以及中间层;以及设计了两阶段的蒸馏任务。