torch_geometric使用指南 (作个人纪录)

建议用最新版本的torch_geometric,不同版本的API变动会比较大。

这个包最关键的一个类是MessagePassing
torch_geometric使用指南 (作个人纪录)

  • aggr: 信息传递的方式,默认是add。也就是neighbor的信息聚合是加在center node上的(详见GCN原文)。
  • flow: 信息传递的方向。这个要和后面的edge_index联合起来理解。默认为source_to_target

其他不做解释。

MessagePassing的forward参数是:
torch_geometric使用指南 (作个人纪录)
最重要的是这个edge_index,结合前面的flow参数,edge_index包含了你输入这个图的所有边的信息(start node、end node)。如图(黄色highlight的部分),输入的edge_index一般情况下是LongTensor,此时形状必须为[2, num_messages],第一维存放start node idx, 第二维存放end node idx。比方说:

## 假设图里面有三个节点,node index为 0,1,2
### 有向图:0->1, 1->2, 2->0 
edge_index = [[0,1,2]
			  [1,2,0]]
### 无向图,全连接
edge_index = [[0,1,2,1,2,0]
			  [1,2,0,0,1,2]]

torch_geometric目前实现了很多近几年的GCN变体 (GAT,RGCN,etc.), 都是继承自MessagePassing, 只要理解了这个MessagePassing和他的edge_index,这些变体都可以直接调包用就可以了。

参考:
torch_geometric: MessagePassing

上一篇:ConvertJSONDateToJSDateObject 方法实现json格式时间串转换为 对应的时间格式串


下一篇:pytorch实现回归模型