https://arxiv.org/abs/2107.00967
在一次分享中看到这篇论文,感觉有意思细读了一下
主要是讲基于可微分树的递归transformer来实现具有强解释性的层次预训练语言模型
论文主要章节涉及了三个方面
- 模型算法,讲解借助transformer实现对句子树结构的提取
- 算法复杂度的优化,相比于之前提出的tree-LSTM是 n 3 n^3 n3复杂度降低到了线性复杂度
- 在以上基础上进行大语料的预训练
相关背景知识
乔姆斯基范式(CNF,Chomsky Normal Form)
任何语法都可以转化成一个弱等价的CNF形式,CNF语法都是二分叉
CYK算法
CYK算法(也称为Cocke–Younger–Kasami算法)是一种用来对 上下文无关文法(CFG,Context Free Grammar)进行语法分析(parsing)的算法。该算法最早由John Cocke, Daniel Younger and Tadao Kasami分别独立提出,其中John Cocke还是1987年度的图灵奖得主。CYK算法是基于动态规划思想设计的一种自底向上语法分析算法。
看过最易懂的博文
代码实现
2. Gumbel-Softmax estimation
在自底向上的计算过程中,每个格子会有多种组合方式,在各种组合方式中,选择概率最大的组合,即argmax函数。但是argmax函数是不可导的,没有办法反向传播。
通过reparameterization对logits的输出拟合为onehot,同时保证梯度可以反向传播
对离散变量再参数化
4. 基于大语料的预训练语言模型的大概套路
模型结构设计
Differentiable Tree
该论文定义了一个类似于CKY形式的可微二叉树解析器
句子 S={s1,s2,s3,…sn}
如上图,每一个格子
T
(
i
,
j
)
=
<
e
i
,
j
,
p
i
,
j
,
p
~
i
,
j
>
\Tau(i,j)=<e_{i,j},p_{i,j},\tilde{p}_{i,j}>
T(i,j)=<ei,j,pi,j,p~i,j>
e
i
,
j
e_{i,j}
ei,j 是向量表征
p
i
,
j
p_{i,j}
pi,j 是每一个步所有组合的概率
p
~
i
,
j
\tilde{p}_{i,j}
p~i,j是在[
s
i
s_i
si,
s
j
s_j
sj]的子树的概率
树的末端节点是
T
i
,
i
\Tau_{i,i}
Ti,i,
e
i
,
i
e_{i,i}
ei,i以当前输入
s
i
s_i
si的向量初始化,
p
i
,
j
p_{i,j}
pi,j 和
p
~
i
,
j
\tilde{p}_{i,j}
p~i,j初始化为1。
上述公式的k是指(
s
i
s_i
si,
s
j
−
1
s_{j-1}
sj−1)之间的某一分割点(分割点不同,会对应出不同的组合)
第一个公式
f
(
.
)
f(.)
f(.)是我们下一节Recursive Transformer定义的函数,
p
i
,
j
k
p_{i,j}^k
pi,jk 和
p
~
i
,
j
k
\tilde{p}_{i,j}^k
p~i,jk分别指一步中组合的概率和其子树的概率
第二个公式
以K为分割点的子树的概率,是当前组合的概率和左右子树概率的乘积,这个和CKY算法是一致的
第三个公式
这里放一个链接 Straight Through Gumbel-Softmax ,通过一定方式实现argmax函数的可微??
p
i
,
j
p_{i,j}
pi,j 和
p
~
i
,
j
\tilde{p}_{i,j}
p~i,j是基于所有分割点得到的
p
i
,
j
k
p_{i,j}^k
pi,jk 和
p
~
i
,
j
k
\tilde{p}_{i,j}^k
p~i,jk的组合
output: 计算得出权重
第四个公式
通过当前组合与权重系数的乘积计算出
e
i
,
i
e_{i,i}
ei,i
第五个公式
通过概率向量与权重系数的乘积计算出新的概率向量
Recursive Transformer
这个图对应了上一节第一个公式。
中间shape的转换过程看图,不想转述了,最终输出的
p
i
,
j
p_{i,j}
pi,j是
R
1
R^1
R1,
c
i
,
j
k
c_{i,j}^k
ci,jk是
R
d
R^d
Rd
Tree Recovery
通过Straight-Through Gumbel-Softmax在每一个cell选择最佳的分割点,Tree( T 1 , n \Tau_{1,n} T1,n), 从树的根节点自顶向下递归操作,选择的最佳分割点还原树的结构,类似于CKY算法最后的回溯过程
Complexity Optimization 复杂度优化
上述的
f
(
.
)
f(.)
f(.)是整个模型的核心计算部分,我们可以通过树的剪枝归并算法来实现对
f
(
.
)
f(.)
f(.)O(
n
3
n^3
n3)
复杂度到线性复杂度
算法
寻找最佳的合并点
example
这张图展示了长度为6的句子的处理过程。
m表示设定的剪枝的阈值
T
\Tau
T 是一个二维数组,用来盛放自底向上计算的所有cell。
上上述图示的三个function:
TREEINDUCTION 是前向计算的过程,调用PRUNING进行剪枝,PRUNING调用FIND寻找最佳消并点。
计算m之下的cell,如上图(b)显示。
当cell的row大于等于m时,还原所有以第m行的节点为root节点的子树,调用PRUNING进行剪枝操作,
剪枝的第一步是找到局部最佳的merge点(上图c),剪掉部分的cell(上图d),返回一个新的
T
\Tau
T(上图e)
在FIND中,最佳分割点的候选集合需要满足两个条件
(1)在
T
\Tau
T的第二行
(2)在以第m行的节点为root节点的子树中有被使用到
然后在候选集合中选择(x.p *pl *pr)最高的cell
T
i
,
j
\Tau_{i,j}
Ti,j做为最佳merge点,对应的将
T
i
,
∗
\Tau_{i,*}
Ti,∗和
T
∗
,
j
\Tau_{*,j}
T∗,j剪掉,得到
T
3
\Tau^3
T3
实验
预训练目标:
-
学习词汇表征,在实际实验中是对于word piece的表征,选择WikiText-2数据集,长度在128以内的句子,mask词汇,输入左子树和右子树的embedding进行词汇预测
因为剪枝操作,存在左子树或者右子树为空,以临近的最长子树来替代 -
无监督成分句法分析
在 WSJ and CTB 测试集计算F1
基于word-piece的word、NP等的召回