SAGEConv

SAGEConv

论文名称:Inductive Representation Learning on Large Graphs

论文链接:https://arxiv.org/pdf/1706.02216.pdf

现在存在方法具有内在transductive,不能generalize未见到的节点,。我们提出GraphSAGE是基于inductive开发的,对于未见到数据具备泛化能力。在训练的过程中,对邻居的节点采用抽样的方式,而不是对所有的节点进行训练。

GraphSAGE学习的主要方法是收集局部的邻居信息,例如度特征或者邻居节点的属性特征。接下来,首先描述推理算法,假设参数已经学习,如何生成节点的Embedding。然后,我们在讲一下如何通过随机梯度下降法和后向传播算法学习模型的参数。

1、Embeddin*生算法(前向传播算法)

假设我们已经学习K 聚合函数( A G G R E G A T E k , k ∈ { 1 , … , K } AGGREGATE_k,k\in\{1,\dots,K\} AGGREGATEk​,k∈{1,…,K})的参数和一系列的权重矩阵 W k , ∀ k ∈ { 1 , … , K } W^k,\forall k \in\{1, \ldots, K\} Wk,∀k∈{1,…,K}, 其中K指搜索的深度,这些参数主要用于前向传播。

SAGEConv

Algorithm 1是在一次迭代、深度搜索并aggregate局部邻居的信息,随着迭代次数的增加,能够获取到更深层次的信息。

G = ( V , E ) \mathcal{G}=(\mathcal{V},\mathcal{E}) G=(V,E)表示整张图,

x v , ∀ v ∈ V \text{x}_v, \forall v\in\mathcal{V} xv​,∀v∈V表示节点的特征。

k k k表示当前step下,每个节点 v ∈ V v\in\mathcal{V} v∈V 收集邻居节点特征表示 h u k − 1 , ∀ u ∈ N ( v ) \text{h}_u^{k-1},\forall u\in\mathcal{N(v)} huk−1​,∀u∈N(v),生成单一的节点表示 h N ( v ) k − 1 \text{h}_{N(v)}^{k-1} hN(v)k−1​. 注意本次迭代aggregate取决于前一次迭代的输出。其中bad case k = 0 k=0 k=0是节点的特征输入。将汇总的邻居向量 h N ( v ) k − 1 \text{h}_{\mathcal{N(v)}}^{k-1} hN(v)k−1​和当前的节点特征 h v k − 1 h_v^{k-1} hvk−1​进行拼接, 进行全连接层和非线性激活函数 σ \sigma σ的转换,生成的结果作为下一次迭代的输入。最终输出的特征表示为 z v ≡ h v K , ∀ v ∈ V \mathbf{z}_{v} \equiv \mathbf{h}_{v}^{K}, \forall v \in \mathcal{V} zv​≡hvK​,∀v∈V。

在minibatch的设置中,对邻居节点和边进行采样。相对于全部节点的向量的计算,GraphSAGE只是对必要的节点minibatch集合 B \mathcal{B} B进行计算。

SAGEConv

主要的思想就是抽样出需要的节点进行计算,Algorithm的Line 2-7描述了抽样的过程。

从Line1-6看出: 每个 B k \mathcal{B}^k Bk包含节点 v ∈ B k + 1 v\in \mathcal{B}^{k+1} v∈Bk+1的 表示。

Line9-15 描述聚合过程, N k ( u ) \mathcal{N}_k(u) Nk​(u)采用独立的均匀采样。

Relation to the Weisfeiler-Lehman Isomorphism Test. GraphSAGE的灵感来自同构图检验的经典算法。在Algorithm中,我们 (i)设 K = ∣ V ∣ K=|V| K=∣V∣ (ii) 边权重是相等的。(iii) 使用Hash函数作为aggregator。如果两个子图输出 { z v , ∀ v ∈ V } \{\text{z}_v, \forall v\in\mathcal{V}\} {zv​,∀v∈V}是相同的,我们认为两个子图是同构的。当然,我们的目标学习节点的表示,不是测试是否同构。

Neighborhood defination 均匀采样固定大小的邻居节点的数量, 即 N ( v ) \mathcal{N(v)} N(v)是固定的,每次迭代均匀采样不同的样本。如果不采样,一个Batch的大小为 O ( ∣ V ∣ ) O{(|\mathcal{V}|)} O(∣V∣). 采样后, GraphSAGE的复杂度固定在 O ( ∏ i = 1 K S i ) O\left(\prod_{i=1}^{K} S_{i}\right) O(∏i=1K​Si​), 其中 S i , i ∈ { i , ⋯   , K } S_i, i\in\{i,\cdots,K\} Si​,i∈{i,⋯,K}, K K K用户自定义的。在实际应用中,一般 K = 2 K=2 K=2, S 1 ⋅ S 2 ≤ 500 S_{1} \cdot S_{2} \leq 500 S1​⋅S2​≤500。

2、GraphSAGE参数学习

使用graph-based loss function学习节点的表示, z u , ∀ u ∈ V \text{z}_u, \forall u\in\mathcal{V} zu​,∀u∈V, 学习权重矩阵 W k , ∀ k ∈ { 1 , ⋯   , K } W^k, \forall k\in\{1,\cdots, K\} Wk,∀k∈{1,⋯,K}, 采用随机梯度下降的方法。graph-based loss function会使得相近的节点有相同的表示,同时兼顾不同的节点学习表示不相同的。
J G ( z u ) = − log ⁡ ( σ ( z u ⊤ z v ) ) − Q ⋅ E v n ∼ P n ( v ) log ⁡ ( σ ( − z u ⊤ z v n ) ) (1) J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)\tag{1} JG​(zu​)=−log(σ(zu⊤​zv​))−Q⋅Evn​∼Pn​(v)​log(σ(−zu⊤​zvn​​))(1)
其中,节点v和节点u是在固定长度随机游走过程共现的。 σ \sigma σ是激活函数, P n P_n Pn​是负采样的分布, Q Q Q是定义负样本的数量, 损失函数输入表示 z u \text{z}_u zu​包含来自邻居节点的特征。

3、Aggregator Architectures

节点的邻居和文本、图像不一样,他们的无序、对称的,测试如下三个aggregator functions:

Mean aggregator

用以下公式替换Algorithm1中Line4和Line5
h v k ← σ ( W ⋅ MEAN ⁡ ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) } ) (2) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right.\tag{2} hvk​←σ(W⋅MEAN({hvk−1​}∪{huk−1​,∀u∈N(v)})(2)
计算当前节点 h v k − 1 \mathbf{h}_v^{k-1} hvk−1​和邻居节点拼接起来, 计算均值,这种操作可以将不同深度的节点进行"skip connnection"。

LSTM aggregator

LSTM不具备对称性,简单地将邻居节点处理成无序的序列作为输入。

Pooling aggregator
 AGGREGATE  k pool  = max ⁡ ( { σ ( W pool  h u i k + b ) , ∀ u i ∈ N ( v ) } ) (3) \text { AGGREGATE }_{k}^{\text {pool }}=\max \left(\left\{\sigma\left(\mathbf{W}_{\text {pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right)\tag{3}  AGGREGATE kpool ​=max({σ(Wpool ​hui​k​+b),∀ui​∈N(v)})(3)

上一篇:GantD - 专注于数据密集型业务场景|基于AntD聚合型React组件库


下一篇:RISC-V生态全景解析(十六):YoC组件发布开源操作指南