背景
本文是斯坦福大学发表在2017年nips的一篇文章,不同于deepwalk等通过图结构信息,在训练之前需要所有节点的embedding信息,这种方法对于那些没有见过的node节点是没办法处理的,概括的说,这些方法都是transductive的。此文提出的方法叫GraphSAGE,针对的问题是之前的网络表示学习的transductive,从而提出了一个inductive的GraphSAGE算法。GraphSAGE同时利用节点特征信息和结构信息得到Graph Embedding的映射,相比之前的方法,之前都是保存了映射后的结果,而GraphSAGE保存了生成embedding的映射,可扩展性更强,对于节点分类和链接预测问题的表现也比较突出。
GraphSAGE是为了学习一种节点表示方法,即如何通过从一个顶点的局部邻居采样并聚合顶点特征,而不是为每个顶点训练单独的embedding。这一点就注定了它跟其他方法不同的地方,对于新的节点信息,transdutive结构不能自然地泛化到未见过的顶点,而GraphSAGE算法可以动态的聚合出新节点的embeddinng信息[1]。
GraphSAGE概览
GraphSage是基于GCN的改进策略,它对GCN进行了两点改进:
- 子图随机组合训练:将GCN的全图训练方式改造成以节点为中心的小批量训练方式,使得模型可以在大规模的图数据上进行训练,并且可以在新的图结构上做预测,大大提供了工业场景的应用性
- 邻居聚合操作拓展:提出了替换卷积操作的其他集中提取邻居特征的方式
如图1所示,GraphSAGE是一种归纳式图学习模型,首先对节点进行邻居采样,接着聚合邻居的信息,最后进行最终任务的处理(如节点分类)。与一般基于矩阵分解的图嵌入方法不同,GraphSAGE利用节点的特征(如文本、节点的度、节点自身属性描述等)学习节点嵌入的模式(即函数),而不是直接学习节点的最终嵌入,因此可以利用学习好的这种节点嵌入模式处理未见过的节点甚至是未见过的新的图结构。
GraphSage的小批量算法
在之前的GCN模型中,训练方式是一种全图形式, 也就是一轮迭代,所有节点样本的损失只会贡献一次梯度数据,无法做到DNN中通常用到的小批量式更新,这从梯度更新的次数而言,效率是很低的。另外,对于很多实际的业务场景数据而言,图的规模往往是十分巨大的,单张显卡的显存容量很难达到一整张图训练时所 需的空间,为此采用小批量的训练方法对大规模图数据的训练进行分布式拓展是十分必要的。GraphSAGE 从聚合邻居的操作出发,对邻居进行随机采样来控制实际运算时节点k阶子图的数据规模,在此基础上对采样的子图进行随机组合来完成小批量式的训练。
在GCN模型中,我们知道节点在第(k+1)层的特征只与其邻居在k层的特征有关,这种局部性质使得节点在第k层的特征只与自己的k阶子图有关。对于下图中的中心节点(0号节点),假设GCN模型的层数为2,若要想得到其第2层特征,图中所有的节点都需要参与计算[2]。
小批量学习伪代码
聚合函数的选取
GraphSAGE研究了聚合邻居操作所需要的性质,(1)聚合操作必须要对聚合节点的数量做到自适应。不管节点的邻居数量怎么变化,进行聚合操作后输出的维度必须是一致的。(2)聚合操作对聚合节点具有排列不变性。这就要求不管邻居节点的排列顺序如何,输出的结果必须是一样的,比如Agg(v1,v2) = Agg(v2,v1)。基于这些性质然后提出了几种新的聚合(aggregator)操作。
平均聚合
LSTM聚合
需要注意的是,LSTM在本质上并不是对称的(也就是说,它们不是排列不变的),因为它们是按顺序处理输入的。我们通过简单地将LSTM应用于节点邻居的随机排列,使LSTM在无序集上进行聚合操作。
池化聚合
参考
- [1] https://www.jianshu.com/p/d743e298a2af
- [2] https://blog.csdn.net/qq_36852840/article/details/110259720