TinyBERT
TinyBERT: Distilling BERT for Natural Language Understanding
- 对Bert encoding中Transformer进行压缩,使用two-step学习框架在精度允许的范围内节约计算资源和推理速度
Transformer蒸馏
Embedding-layer Distillation
L e m b d = M S E ( E S W e , E T ) E S ∈ R l × d 0 , E T ∈ R l × d l : s e q u e n c e l e n g t h d 0 : s t u d e n t e m b e d d i n g 维 度 d : t e a c h e r e m b e d d i n g 维 度 W e : d 0 × d 可 训 练 的 线 性 变 换 矩 阵 \mathcal{L}_{embd}=MSE(E^{S}W_e,E_T)\\ E^S\in R^{l \times d_0},E^T \in R^{l\times d}\\ l:sequence \quad length\\ d0:student\quad embedding维度\\ d:teacher\quad embedding维度\\ W_e:d_0\times d可训练的线性变换矩阵 Lembd=MSE(ESWe,ET)ES∈Rl×d0,ET∈Rl×dl:sequencelengthd0:studentembedding维度d:teacherembedding维度We:d0×d可训练的线性变换矩阵
Transformer-layer Distillation
- TinyBERT的transformer蒸馏采用隔k层蒸馏的方式
- 如:teacher BERT共有12层,student BERT有4层,那就每隔3层计算一个transformer loss,即student的1对teacher的3,2对6,3对9,4对12,映射函数 g ( m ) = 3 ∗ m g(m)=3*m g(m)=3∗m。这个loss分为两个部分,attention based distillation和hidden states based distillation如下图所示
Attention based loss
L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) A i ∈ R l × l h : a t t e n t i o n 的 头 数 l : 输 入 长 度 A i S : s t u d e n t 网 络 第 i 个 a t t e n t i o n 头 的 a t t e n t i o n s c o r e 矩 阵 A i T : t e a c h e r 网 络 第 i 个 a t t e n t i o n 头 的 a t t e n t i o n s c o r e 矩 阵 \mathcal{L}_{attn}=\frac{1}{ h}\sum_{i=1}^h MSE(A_i^S,A_i^T)\\ A_i\in R^{l\times l}\\ h:attention的头数\\ l:输入长度\\ A_i^S:student网络第i个attention头的attention\quad score矩阵\\ A_i^T:teacher网络第i个attention头的attention\quad score矩阵 Lattn=h1i=1∑hMSE(AiS,AiT)Ai∈Rl×lh:attention的头数l:输入长度AiS:student网络第i个attention头的attentionscore矩阵AiT:teacher网络第i个attention头的attentionscore矩阵
hidden states based distillation
L h i d n = M S E ( H S W h , H T ) H S ∈ R l × d 0 H T ∈ R l × d H S : s t u d e n t t r a n s f o r m e r 的 隐 藏 层 输 出 H T : t e a c h e r t r a n s f o r m e r 的 隐 藏 层 输 出 W h : 投 射 矩 阵 \mathcal{L}_{hidn}=MSE(H^SW_h,H^T)\\ H^S\in R^{l\times d_0}\quad H^T\in R^{l\times d}\\ H^S:student\quad transformer的隐藏层输出\\ H^T:teacher\quad transformer的隐藏层输出\\ W_h:投射矩阵 Lhidn=MSE(HSWh,HT)HS∈Rl×d0HT∈Rl×dHS:studenttransformer的隐藏层输出HT:teachertransformer的隐藏层输出Wh:投射矩阵
Prediction-layer DIstillation
计算 teacher 输出的概率分布和 student 输出的概率分布的 softmax 交叉熵,这里用来模拟teacher在predict层的表现
L
p
r
e
d
=
−
s
o
f
t
m
a
x
(
z
T
)
⋅
l
o
g
_
s
o
f
t
m
a
x
(
z
S
/
t
)
\mathcal{L}_{pred}=-softmax(z^T)\cdot log\_softmax(z^S/t)
Lpred=−softmax(zT)⋅log_softmax(zS/t)
-
t:temperature value,温度越高softmax就越平滑
-
prediction loss有很多变化,可以根据情况改变
loss总结
L m o d e l = ∑ m = 0 M + 1 λ m L l a y e r ( S m , T g ( m ) ) L l a y e r ( S m , T g ( m ) ) = { L e m b d ( S 0 , T 0 ) m = 0 L h i d n ( S m , T g ( m ) ) + L a t t n ( S m , T g ( m ) ) M ≥ m > 0 L p r e d ( S M + 1 , T N + 1 ) m = M + 1 \mathcal{L}_{model}=\sum_{m=0}^{M+1}\lambda_m \mathcal{L}_{layer}(S_m,T_{g(m)})\\ \mathcal{L}_{layer}(S_m,T_{g(m)})= \left\{ \begin{aligned} \mathcal{L}_{embd}(S_0,T_0)&& m=0\\ \mathcal{L}_{hidn}(S_m,T_{g(m)})+\mathcal{L}_{attn}(S_m,T_{g(m)}) && M\geq m > 0\\ \mathcal{L}_{pred}(S_{M+1},T_{N+1}) && m=M+1 \end{aligned} \right. Lmodel=m=0∑M+1λmLlayer(Sm,Tg(m))Llayer(Sm,Tg(m))=⎩⎪⎨⎪⎧Lembd(S0,T0)Lhidn(Sm,Tg(m))+Lattn(Sm,Tg(m))Lpred(SM+1,TN+1)m=0M≥m>0m=M+1
two-step
- 现在general domain数据集上用未微调bert充当teacher蒸馏出一个base模型
- 在具体任务上用微调后的bert重新蒸馏
效果
虽然效果略微下降,但影响不大
推理速度有了明显的提升