Graph Convolution Network 理解与实现
https://zhuanlan.zhihu.com/p/51990489
Graph Convolution作为Graph Networks的一个分支,可以说几乎所有的图结构网络都是大同小异,详见综述,而Graph Convolution Network又是Graph Networks中最简单的一个分支。理解了它便可以理解很多近年来的图结构网络,比如Scene Graph Generation中的Message Passing机制等。后续打算持续更新一些原始GCN的变体。
【相关文章和网站】:
- Paper: Semi-Supervised Classification with Graph Convolutional Networks, 2016
- Paper: Gated Graph Sequence Neural Networks, 2016
- Website: How powerful are Graph Convolutional Networks?
- Github: 关于Gated Graph Convolution Network的Pytorch实现 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0
- 其实Graph Convolution Network (GCN)可以看作Graph Networks的一个分支(只有Node feature,无Edge feature和global attribute),而Graph Networks则有一篇2018年的综述:Relational inductive biases, deep learning, and graph networks, 2018
【Graph Convolution Network和传统CNN的关系】
我们不妨把传统的CNN的输入图片\(I\)也定义为一个Graph,他包含一堆Pixel集合\({p_i}\)看作是Node, 而graph的边则是通过pixel的连通性定义的,所以每个pixel有至多8个edge和他相连。而Convolution其实就是把他的8个neighbour pixel的feature和他自己的feature乘以一个可学习的参数化kernel,来update这个pixel的feature.
那么由此,就不难理解GCN了。GCN主要的区别在于,他的node间的边,不是通过连通性定义的,而是需要给定了一个edge set,或者说graph的adjacent matrix。而且由于每个node可以有任意数量的neighbour node,所以update feature时,所有node其实是乘以了同一套参数。
【公式化】
这里我们参考Semi-Supervised Classification with Graph Convolutional Networks, 2016给出Graph Convolution的最终公式,忽略了原文的推导过程。
GCN可以定义为如下公式:
\[Z=GCN(X,A) \]- 这里\(X\in R^{N\times C}\)是node输入,包含N个node,每个Node有C维的feature, A是Adjacent matrix,,\(A_{ij}\)定义node i和node j间是否有边edge,\(Z\in R^{N\times F}\)是输出, F表示新特征的维度
详细展开如下:
\[Z=\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}X\Theta \]- 这里\(\Theta\in R^{C\times F}\)就是要学习的参数,\(\hat{A}=A+I\),\(I\)是单位矩阵,\(\hat{D}_{ii}=\sum_{j}\hat{A}_{ij}\)是对角线矩阵,对角线上每个元素表示,node i的neighbor数(包括自身)。所以其实等式右边可以看作\(X\Theta\)把所有node的feature从C映射到F维,而每个node的新feature\(Z_i\in R^F\)等于所有和他相连的node(包括自身)的F维feature的加权和,即average
【伪代码实现】
\[input : X, A, output : Z \] \[Y = f_c(X) \] \[Z = (A+I) * Y / (A\cdot sum(1)+1) \]【Gated Graph Convolution Network】
但是上述Node特征更新的方式比较原始,Gated Graph Sequene Neural Networks, ICLR, 2016将Graph Convolution的X to Z的更新改成了GRU(LSTM)的形式。同时设计了一个Graph-Level的特征。下面实现参考了上文的思想,但做了些简化,比如原文将Incoming Edges和Outgoing Edges区分了这里我就沿用朴素Graph Convolution的A,不做拓展。
【Gated Graph Convolution Network 公式&伪代码】
\[input:X^t,output:X^{t+1}(即Z) \] \[Y=A\ast f_c(X^t) \] \[U=\sigma(W_1Y+W_2X^t) \] \[R=\sigma(W_3Y+W_4X^t) \] \[X^{t+1}_{tem}=tanh(W_5Y+W_6(R\cdot X^t)) \] \[X^{t+1}=(1-U)\cdot X^t+U\cdot X^{t+1}_{tem} \]- 上述W都是可学习的参数
【Graph-Level特征获取】
很多应用需要将一整个graph整合成一个特征,而原始的Graph Convolution则只能生成每个node的特征。graph-level的定义如下:
\[h_G=tanh(\sum_{nodes}\sigma([X^T,X^0]))\cdot tanh(f_{c_2}([X^T,X^0]) \]当然,还有很多文章,采取更为简单的graph-level feature提取方法:
\[h_G=\sum_{nodes}X^T_i,or h_G=\frac{1}{Num(nodes)}\sum_{nodes}X^T_i \]【Code】
关于Gated Graph Convolution Network的代码,可以参考以下Github项目 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0