建议用最新版本的torch_geometric
,不同版本的API变动会比较大。
这个包最关键的一个类是MessagePassing
:
-
aggr
: 信息传递的方式,默认是add。也就是neighbor的信息聚合是加在center node上的(详见GCN原文)。 -
flow
: 信息传递的方向。这个要和后面的edge_index联合起来理解。默认为source_to_target
。
其他不做解释。
MessagePassing
的forward参数是:
最重要的是这个edge_index
,结合前面的flow
参数,edge_index
包含了你输入这个图的所有边的信息(start node、end node)。如图(黄色highlight的部分),输入的edge_index
一般情况下是LongTensor
,此时形状必须为[2, num_messages]
,第一维存放start node idx, 第二维存放end node idx。比方说:
## 假设图里面有三个节点,node index为 0,1,2
### 有向图:0->1, 1->2, 2->0
edge_index = [[0,1,2]
[1,2,0]]
### 无向图,全连接
edge_index = [[0,1,2,1,2,0]
[1,2,0,0,1,2]]
torch_geometric
目前实现了很多近几年的GCN变体 (GAT,RGCN,etc.), 都是继承自MessagePassing
, 只要理解了这个MessagePassing
和他的edge_index
,这些变体都可以直接调包用就可以了。
参考:
torch_geometric: MessagePassing