DAG-GNN: DAG Structure Learning with Graph Neural Networks

Yu Y., Chen J., Gao T. and Yu M. DAG-GNN: DAG structure learning with graph neural networks. In International Conference on Machine Learning (ICML), 2019.

有向无环图 + GNN + VAE.

主要内容

先前已经有工作(NOTEARS)讨论了如何处理线性SEM模型

\[X = A^TX + Z, \]

\(A \in \mathbb{R}^{m \times m}\)为加权的邻接矩阵, \(m\)代表了有向无环图中变量的数目, \(Z\)是独立的noise. 需要特别说明的是, 在本文中, 作者假设每一个结点变量\(X_i\)并非传统的标量而是一个向量 (个人觉得这是很有意思的点, 有点胶囊的感觉), 故\(X \in \mathbb{R}^{m \times d}\), 这里\(X_i\)\(X\)的第\(i\)行.

本文在此基础上更进一步, 考虑非线性的情况:

\[g(X) = A^Tg(X) + f_1(Z), \]

如果\(g\)可逆, 则可以进一步表示为

\[X = f_2((I - A^T)^{-1}f_1(Z)). \]

为了满足这一模型, 作者套用VAE, 进而最大化ELBO:

\[\mathcal{L}_{\mathrm{ELBO}} = \mathbb{E}_{q_{\phi}(Z|X)}[\log p_{\theta}(X|Z)] - \mathbb{D}_{\mathrm{KL}}(q_{\phi}(Z|X)\| p(Z)), \]

整个VAE的流程是这样的:

DAG-GNN: DAG Structure Learning with Graph Neural Networks

  1. encoder:

    \[M_Z, \log S_Z = f_4((I - A^T)f_3(X)), \Z \sim \mathcal{N}(M_Z, S_Z^2). \]

  2. decoder

\[M_X, S_X = f_2((I - A^T)^{-1}f_1(Z)), \\widehat{X} \sim \mathcal{N}(M_X, S_X^2). \]

注: 因为每个结点变量都不是标量, 所以考虑上面的流程还是把\(X, Z\)拉成向量\(md\)再看会比较清楚.

此时

\[\mathbb{D}_{\mathrm{KL}}(q_{\phi}(Z|X)\|p(Z)) = \frac{1}{2} \sum_{i=1}^m \sum_{j=1}^d \{[S_Z]_{ij}^2 + [M_Z]_{ij}^2 - 2\log [S_Z]_{ij} - 1 \}. \]

仅最大化ELBO是不够的, 因为这并不能保证\(A\)反应有向无环图, 所以我们需要增加条件

\[h(A) = \mathrm{tr}[(I+\alpha A \circ A)^m] = m, \]

具体推导看NOTEARS, 这里\(\alpha=\frac{c}{m}\), \(c>0\)是一个超参数, 这个原因是

\[(1 + \alpha |\lambda|)^m \le e^{c|\lambda|}, \]

所以合适的\(c\)能够让条件更加稳定.

最后目标可以总结为:

\[\min_{\phi, \theta, A} \quad -\mathcal{L}_{\mathrm{ELBO}} \\mathrm{s.t.} \quad h(A) = 0. \]

同样的, 作者采用了augmented Lagrangian来求解

\[(A^k, \phi^k, \theta^k) = \arg \min_{A,\phi, \theta} \: -\mathcal{L}_{\mathrm{ELBO}} + \lambda h(A) + \frac{c}{2}|h(A)|^2, \\lambda^{k+1} = \lambda^k + c^k h(A^k), \c^{k+1} = \left \{ \begin{array}{ll} \eta c^k, & \mathrm{if} \: |h(A^k)| > \gamma |h(A^{k-1})|, \c^k, & otherwise. \end{array} \right. \]

这里\(\eta > 1, \gamma < 1\), 作者选择\(\eta=10, \gamma=1/4\).

注: \(c\)逐渐增大的原因是, 显然当\(c = +\infty\)的时候, \(h(A)\)必须为0.

注: 作者关于图神经网络的部分似乎就集中在\(X\)的模型上, 关于图神经网络不是很懂, 就不写了.

代码

原文代码

DAG-GNN: DAG Structure Learning with Graph Neural Networks

上一篇:CSS 浮动及应用,清除浮动


下一篇:spring security