Match-LSTM with Ans-Ptr论文笔记
《MACHINE COMPREHENSION USING MATCH-LSTM AND ANSWER POINTER》论文笔记
Overview
本文是在SQuAD v1.1数据集出世后第一个采用end-to-end的深度学习方法的paper。模型的主要结构是对已有的两个模型的结合:match-LSTM(Jiang&Wang, 2016)和Pointer Net(Vinyals et al., 2015)。相较于人工feature engineering + LR的传统机器学习方法,本文提出的方法在SQuAD数据集上取得了很大成功,在exact match和F1 score上都有质的飞跃。
Model Architecture
正如Overview中所提到的,模型的核心组件主要有两个:match-LSTM、Pointer Net。match-LSTM主要是用来提取query和passage之间的关系,Pointer Net主要是用来输出answer。具体来说,模型有三层:
- LSTM Preprocessing Layer
- match-LSTM Layer
- Answer Pointer Layer
根据Pointer Net中任务的不同,本文又将模型分为两种:Sequence Model和Boundary Model,后面会细说。
Figure 1
LSTM Preprocessing Layer
首先,文中用一个单向的LSTM分别对passage和query进行embedding,这里的embedding是独立的。
H
P
=
L
S
T
M
→
(
P
)
=
[
h
1
P
→
,
h
2
P
→
,
…
,
h
N
P
→
]
H
Q
=
L
S
T
M
→
(
Q
)
=
[
h
1
Q
→
,
h
2
Q
→
,
…
,
h
N
Q
→
]
H^{P}\ =\ \overrightarrow{LSTM}(P)\ =\ [\overrightarrow{h_1^{P}}, \overrightarrow{h_2^{P}}, \dots,\overrightarrow{h_N^{P}}] \\ H^{Q}\ =\ \overrightarrow{LSTM}(Q)\ =\ [\overrightarrow{h_1^{Q}}, \overrightarrow{h_2^{Q}}, \dots,\overrightarrow{h_N^{Q}}] \\
HP = LSTM
(P) = [h1P
,h2P
,…,hNP
]HQ = LSTM
(Q) = [h1Q
,h2Q
,…,hNQ
]
H
P
H^{P}
HP和
H
Q
H^{Q}
HQ分别代表passage和query的隐状态矩阵。
match-LSTM
match-LSTM模型(Bidirectional)也是本文作者提出的,原来是用于文本蕴含任务,即给定一个premise和一个hypothesis,让模型去判断premise和hypothesis之间的关系(蕴含关系 entailment / 矛盾关系 contradiction)。因此这里作者把query当作是premise,passage当作是hypothesis。而关系的计算使用的是attention mechanism。
G
i
→
=
t
a
n
h
(
W
q
H
q
+
(
W
P
h
i
P
+
W
r
h
i
−
1
r
+
b
p
)
⨂
e
Q
)
α
i
→
=
s
o
f
t
m
a
x
(
w
T
G
i
→
+
b
⨂
e
Q
)
\overrightarrow{G_i}\ =\ tanh(W^{q}H^{q}+(W^{P}h_{i}^{P}+W^{r}h_{i-1}^{r}+b^{p})\bigotimes e_{Q}) \\ \overrightarrow{\alpha_{i}}\ =\ softmax(w^T\overrightarrow{G_i}+b\bigotimes e_{Q})
Gi
= tanh(WqHq+(WPhiP+Wrhi−1r+bp)⨂eQ)αi
= softmax(wTGi
+b⨂eQ)
本文中所使用的attention的计算方式应该是additive,
W
q
、
W
p
、
W
r
W^{q}、W^{p}、W^{r}
Wq、Wp、Wr是三个参数矩阵,
⨂
e
Q
\bigotimes e_{Q}
⨂eQ表示将
Q
Q
Q个左边的式子concat起来。于是我们就得到了对于passage中的每个单词
h
i
P
h^{P}_{i}
hiP,它与query中所有单词之间的关系
α
i
→
\overrightarrow{\alpha_{i}}
αi
。然后用attention计算weighted sum
H
q
α
→
i
T
H^{q} \overrightarrow{\alpha}^{T}_{i}
Hqα
iT,最后将weighted sum和
i
i
i位置上passage中单词的hidden state concat起来得到
i
i
i时刻match-LSTM的输入
z
i
→
=
[
h
i
P
;
H
q
α
→
i
T
]
h
→
i
r
=
L
S
T
M
→
(
z
i
→
,
h
→
i
−
1
r
)
\overrightarrow{z_{i}}\ =\ [h_{i}^{P};H^{q} \overrightarrow{\alpha}^{T}_{i}] \\ \overrightarrow{h}_{i}^{r}\ =\ \overrightarrow{LSTM}(\overrightarrow{z_i},\overrightarrow{h}_{i-1}^{r})
zi
= [hiP;Hqα
iT]h
ir = LSTM
(zi
,h
i−1r)
反方向计算同理。其实到这里就不难发现match-LSTM本质上就是一个Seq2Seq模型。
Answer Pointer Layer
从Figure 1中可以看出,Answer Pointer Layer有两种形式,根据形式的不同模型也分别成为Sequence Model和Boundary Model。
Sequence Model
Sequence Model所作的任务是在passage中找到一个序列
a
=
(
a
1
,
a
2
,
…
)
a\ =\ (a_1,a_2,\dots)
a = (a1,a2,…)作为answer。由于answer的长度不是固定的,因此我们需要在原来的passage中加入一个end token(文中用0向量来表示),于是passage的长度变为
P
+
1
P+1
P+1,并且
H
r
H^{r}
Hr变为
H
^
r
=
[
H
r
;
0
]
\hat{H}^{r}=[H^{r};0]
H^r=[Hr;0]。接下来的流程与match-LSTM有很多相似之处。为了生成答案中的第
k
k
k个单词,我们依然仿照Seq2Seq Decoder的方式进行计算。令
β
k
,
j
\beta_{k,j}
βk,j表示选取passage中第
j
j
j个单词作为答案的第
k
k
k个单词的概率,
β
\beta
β的计算方式与上面相同
F
k
=
t
a
n
h
(
V
H
^
r
+
(
W
a
h
k
−
1
a
+
b
a
)
⨂
e
(
P
+
1
)
)
β
k
=
s
o
f
t
m
a
x
(
v
T
F
k
+
c
⨂
e
(
P
+
1
)
)
F_{k}\ =\ tanh(V\hat{H}^{r}+(W^{a}h^{a}_{k-1}+b^{a})\bigotimes e_{(P+1)})\\ \beta_{k}\ =\ softmax(v^{T}F_{k}+c \bigotimes e_{(P+1)})
Fk = tanh(VH^r+(Wahk−1a+ba)⨂e(P+1))βk = softmax(vTFk+c⨂e(P+1))
然后再将weighted sum
H
^
r
β
k
T
\hat{H}^{r} \beta_{k}^{T}
H^rβkT作为LSTM的输入得到第
k
k
k个answer的hidden state
h
k
a
=
L
S
T
M
(
H
^
r
β
k
T
,
h
k
−
1
a
)
h^{a}_{k}\ =\ LSTM(\hat{H}^{r} \beta_{k}^{T},h_{k-1}^{a})
hka = LSTM(H^rβkT,hk−1a)
本质上Pointer Net是一个语言模型,因此我们可以直接根据链式法则写出目标函数:
p
(
a
∣
H
r
)
=
∏
k
p
(
a
k
∣
a
1
,
a
2
,
…
,
a
k
−
1
,
H
r
)
p
(
a
k
=
j
∣
a
1
,
a
2
,
…
,
a
k
−
1
,
H
r
)
=
β
k
,
j
p(a|H^{r})\ =\ \prod_{k}p(a_k|a_1,a_2,\dots,a_{k-1},H^{r})\\ p(a_k=j|a_1,a_2,\dots,a_{k-1},H^{r})=\beta_{k,j}
p(a∣Hr) = k∏p(ak∣a1,a2,…,ak−1,Hr)p(ak=j∣a1,a2,…,ak−1,Hr)=βk,j
然后用MLE,优化目标函数得到参数即可。
Boundary Model
Boundary Model和Sequence Model唯一的区别就是Boundary Model预测的是answer的start和end,然后取start到end之间的单词作为答案。这种方式叫做span extraction,也是在SQuAD数据集最常用的方式。此时目标函数也就变成了
p
(
a
∣
H
r
)
=
p
(
a
s
∣
H
r
)
p
(
a
e
∣
a
s
,
H
r
)
p(a|H^{r})\ =\ p(a_s|H^{r})p(a_e|a_s,H^{r})
p(a∣Hr) = p(as∣Hr)p(ae∣as,Hr)
Experiment
在实验中,作者用Glove作为word embedding。
实验结果很清晰地说明了本文提出的方法相较于LR的baseline有了巨大的提升,EM 67.9%、 F1 77.0%。
同时作者将attention进行了可视化
也能够看出来attention的确是比较准确的把握到了query与passage中单词之间的联系。