本文记录因子分析机FM算法的推导和理解笔记
论文地址
https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf
FM 推导过程
FM在预测任务是考虑了不同特征之间的交叉情况, 以2阶的交叉为例:
y
^
(
x
)
=
w
0
+
∑
i
=
1
n
w
i
∗
x
i
+
∑
i
=
1
n
∑
j
=
i
+
1
n
W
x
i
x
j
(1)
\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i*x_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n}Wx_ix_j \tag{1}
y^(x)=w0+i=1∑nwi∗xi+i=1∑nj=i+1∑nWxixj(1)
其中的
w
0
w_0
w0,
w
i
w_i
wi,
W
W
W是模型需要学习的内容。由于在实际场景中,
x
i
x_i
xi,
x
j
x_j
xj都是维度很大并且稀疏的one-hot类型的向量,如果直接学习交叉项的权重
W
W
W很容易过拟合。
但是注意到
W
W
W应该是一个实对称的矩阵,由实对称矩阵理论的性质:
每个实对称矩阵
A
A
A可以分解成这样一种形式:
A
=
Q
Λ
Q
T
A=Q\Lambda Q^T
A=QΛQT
进而
W
W
W可以被分解成
W
=
V
V
T
W=VV^T
W=VVT,其中
V
∈
R
n
×
k
V \in R^{n \times k}
V∈Rn×k
所以式子(1)可以化成:
y
^
(
x
)
=
w
0
+
∑
i
=
1
n
w
i
∗
x
i
+
∑
i
=
1
n
∑
j
=
i
+
1
n
<
v
i
,
v
j
>
x
i
x
j
(2)
\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i*x_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n} <v_i, v_j>x_ix_j \tag{2}
y^(x)=w0+i=1∑nwi∗xi+i=1∑nj=i+1∑n<vi,vj>xixj(2)
v
i
v_i
vi和
v
j
v_j
vj 可以用长度为k的向量表示:
<
v
i
,
v
j
>
=
∑
f
=
1
k
v
i
,
f
⋅
v
j
,
f
<v_i, v_j> = \sum_{f=1}^{k}v_{i,f} \cdot v_{j,f}
<vi,vj>=∑f=1kvi,f⋅vj,f
所以有:
y
^
(
x
)
=
w
0
+
∑
i
=
1
n
w
i
∗
x
i
+
∑
i
=
1
n
∑
j
=
i
+
1
n
∑
f
=
1
k
v
i
,
f
⋅
v
j
,
f
x
i
x
j
\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i*x_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n}\sum_{f=1}^{k}v_{i,f} \cdot v_{j,f}x_ix_j
y^(x)=w0+∑i=1nwi∗xi+∑i=1n∑j=i+1n∑f=1kvi,f⋅vj,fxixj
直接求解这个算法的时间复杂度为
O
(
k
n
2
)
O(kn^2)
O(kn2),但是可以通过调整求解方式将复杂度降为
O
(
k
n
)
O(kn)
O(kn)