Graph Convolutional Network
Author: Qi Huang, Minjie Wang, Yu Gai, Quan Gan, Zheng Zhang
This is a gentle introduction to implementing a Graph Convolutional Network (GCN) with DGL (Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks). We build on the earlier tutorial on DGLGraph and demonstrate how DGL combines graphs with deep neural networks to learn structural representations.
Model Overview
GCN from the message-passing perspective
We describe a layer of the graph convolutional network from a message-passing perspective; the math can be found here. For each node $u$, it boils down to the following steps:
1) Aggregate the neighbors' representations $h_v$ to produce an intermediate representation $\hat{h}_u$. 2) Transform the aggregated representation $\hat{h}_u$ with a linear projection followed by a nonlinearity: $h_u = f(W_u \hat{h}_u)$.
We will implement step 1 with DGL message passing, and step 2 with the apply_nodes method, whose node UDF will be a PyTorch nn.Module.
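To make the two steps concrete, here is a tiny standalone sketch with made-up numbers (purely illustrative; it is not part of the DGL implementation that follows):
import torch as th

# Toy example: node u has two neighbors with 4-dimensional features.
h_v1 = th.tensor([1.0, 0.0, 2.0, 1.0])
h_v2 = th.tensor([0.0, 1.0, 1.0, 3.0])
W = th.randn(4, 4)            # the layer's weight matrix

h_hat_u = h_v1 + h_v2         # step 1: aggregate (sum) the neighbors' features
h_u = th.relu(h_hat_u @ W)    # step 2: linear projection + nonlinearity f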
GCN implementation with DGL
We first define the message and reduce functions as usual. Since the aggregation on a node $u$ only involves summing over the neighbors' representations $h_v$, we can simply use the built-in functions:
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')
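For intuition, these built-ins are equivalent to the following hand-written UDFs (a sketch only; the built-in versions are preferred because DGL can batch and optimize them internally):
def gcn_msg_udf(edges):
    # copy each source node's feature 'h' into the outgoing message 'm'
    return {'m': edges.src['h']}

def gcn_reduce_udf(nodes):
    # sum the incoming messages over the neighbor dimension
    return {'h': th.sum(nodes.mailbox['m'], dim=1)}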
We then define the node UDF for apply_nodes, which is a fully-connected layer:
class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        if self.activation is not None:
            h = self.activation(h)
        return {'h' : h}
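As a quick sanity check in isolation (FakeNodeBatch below is a hypothetical stand-in for the node batch object DGL passes to a node UDF; the real module only reads node.data['h']):
class FakeNodeBatch:
    # minimal mock exposing only the .data dict the UDF reads
    def __init__(self, h):
        self.data = {'h': h}

mod = NodeApplyModule(5, 3, F.relu)
out = mod(FakeNodeBatch(th.randn(4, 5)))
print(out['h'].shape)  # torch.Size([4, 3])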
We then proceed to define the GCN module. A GCN layer essentially performs message passing on all the nodes, then applies the NodeApplyModule. Note that we omitted the dropout in the paper for simplicity.
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        g.update_all(gcn_msg, gcn_reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
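As a quick smoke test (a sketch, assuming the same DGL version as the rest of this tutorial), we can run a single GCN layer on a toy graph:
# Toy 3-node cycle graph; every node has one in-edge, so the sum
# aggregation is well defined for all nodes.
toy_g = DGLGraph()
toy_g.add_nodes(3)
toy_g.add_edges([0, 1, 2], [1, 2, 0])

layer = GCN(4, 2, F.relu)
out = layer(toy_g, th.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])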
The forward function is essentially the same as that of any other commonly seen NN model in PyTorch. We can initialize a GCN like any nn.Module. For example, let's define a simple neural network consisting of two GCN layers. Suppose we are training a classifier for the cora dataset (the input feature size is 1433 and the number of classes is 7). The last GCN layer computes node embeddings, so the last layer in general does not apply an activation.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gcn1 = GCN(1433, 16, F.relu)
        self.gcn2 = GCN(16, 7, None)

    def forward(self, g, features):
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return x

net = Net()
print(net)
out:
Net(
  (gcn1): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=1433, out_features=16, bias=True)
    )
  )
  (gcn2): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=16, out_features=7, bias=True)
    )
  )
)
We load the cora dataset using DGL's built-in data module.
from dgl.data import citation_graph as citegrh
import networkx as nx

def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    train_mask = th.BoolTensor(data.train_mask)
    test_mask = th.BoolTensor(data.test_mask)
    g = data.graph
    # add self loop
    g.remove_edges_from(nx.selfloop_edges(g))
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, train_mask, test_mask
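A quick inspection of what the loader returns (a sketch for illustration only; the shapes correspond to the standard cora split):
g, features, labels, train_mask, test_mask = load_cora_data()
print(g.number_of_nodes())    # 2708 nodes
print(features.shape)         # torch.Size([2708, 1433])
print(int(train_mask.sum()))  # 140 labeled training nodes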
After the model is trained, we can use the following method to evaluate its performance on the test dataset:
def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
We then train the network as follows:
import time
import numpy as np

g, features, labels, train_mask, test_mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(50):
    if epoch >= 3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    acc = evaluate(net, g, features, labels, test_mask)
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), acc, np.mean(dur)))
out:
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9444 | Test Acc 0.1470 | Time(s) nan
Epoch 00001 | Loss 1.9181 | Test Acc 0.1610 | Time(s) nan
Epoch 00002 | Loss 1.8911 | Test Acc 0.1900 | Time(s) nan
Epoch 00003 | Loss 1.8613 | Test Acc 0.2360 | Time(s) 0.0839
Epoch 00004 | Loss 1.8310 | Test Acc 0.2630 | Time(s) 0.0840
Epoch 00005 | Loss 1.8007 | Test Acc 0.2850 | Time(s) 0.0837
Epoch 00006 | Loss 1.7705 | Test Acc 0.2980 | Time(s) 0.0836
Epoch 00007 | Loss 1.7408 | Test Acc 0.3100 | Time(s) 0.0838
Epoch 00008 | Loss 1.7109 | Test Acc 0.3170 | Time(s) 0.0837
Epoch 00009 | Loss 1.6810 | Test Acc 0.3280 | Time(s) 0.0839
Epoch 00010 | Loss 1.6513 | Test Acc 0.3550 | Time(s) 0.0841
Epoch 00011 | Loss 1.6219 | Test Acc 0.3760 | Time(s) 0.0841
Epoch 00012 | Loss 1.5942 | Test Acc 0.3910 | Time(s) 0.0840
Epoch 00013 | Loss 1.5674 | Test Acc 0.4030 | Time(s) 0.0841
Epoch 00014 | Loss 1.5413 | Test Acc 0.4140 | Time(s) 0.0841
Epoch 00015 | Loss 1.5157 | Test Acc 0.4270 | Time(s) 0.0842
Epoch 00016 | Loss 1.4912 | Test Acc 0.4430 | Time(s) 0.0842
Epoch 00017 | Loss 1.4676 | Test Acc 0.4520 | Time(s) 0.0843
Epoch 00018 | Loss 1.4451 | Test Acc 0.4600 | Time(s) 0.0842
Epoch 00019 | Loss 1.4233 | Test Acc 0.4640 | Time(s) 0.0842
Epoch 00020 | Loss 1.4021 | Test Acc 0.4730 | Time(s) 0.0842
Epoch 00021 | Loss 1.3815 | Test Acc 0.4760 | Time(s) 0.0842
Epoch 00022 | Loss 1.3616 | Test Acc 0.4810 | Time(s) 0.0842
Epoch 00023 | Loss 1.3423 | Test Acc 0.4890 | Time(s) 0.0842
Epoch 00024 | Loss 1.3236 | Test Acc 0.5080 | Time(s) 0.0842
Epoch 00025 | Loss 1.3056 | Test Acc 0.5180 | Time(s) 0.0843
Epoch 00026 | Loss 1.2881 | Test Acc 0.5240 | Time(s) 0.0843
Epoch 00027 | Loss 1.2713 | Test Acc 0.5310 | Time(s) 0.0844
Epoch 00028 | Loss 1.2550 | Test Acc 0.5400 | Time(s) 0.0843
Epoch 00029 | Loss 1.2392 | Test Acc 0.5570 | Time(s) 0.0844
Epoch 00030 | Loss 1.2238 | Test Acc 0.5670 | Time(s) 0.0844
Epoch 00031 | Loss 1.2089 | Test Acc 0.5800 | Time(s) 0.0843
Epoch 00032 | Loss 1.1944 | Test Acc 0.5860 | Time(s) 0.0843
Epoch 00033 | Loss 1.1803 | Test Acc 0.5960 | Time(s) 0.0843
Epoch 00034 | Loss 1.1666 | Test Acc 0.6000 | Time(s) 0.0843
Epoch 00035 | Loss 1.1532 | Test Acc 0.6070 | Time(s) 0.0843
Epoch 00036 | Loss 1.1401 | Test Acc 0.6160 | Time(s) 0.0843
Epoch 00037 | Loss 1.1273 | Test Acc 0.6220 | Time(s) 0.0843
Epoch 00038 | Loss 1.1147 | Test Acc 0.6240 | Time(s) 0.0843
Epoch 00039 | Loss 1.1023 | Test Acc 0.6310 | Time(s) 0.0843
Epoch 00040 | Loss 1.0901 | Test Acc 0.6340 | Time(s) 0.0843
Epoch 00041 | Loss 1.0782 | Test Acc 0.6400 | Time(s) 0.0843
Epoch 00042 | Loss 1.0664 | Test Acc 0.6410 | Time(s) 0.0843
Epoch 00043 | Loss 1.0548 | Test Acc 0.6460 | Time(s) 0.0842
Epoch 00044 | Loss 1.0434 | Test Acc 0.6470 | Time(s) 0.0842
Epoch 00045 | Loss 1.0322 | Test Acc 0.6520 | Time(s) 0.0842
Epoch 00046 | Loss 1.0211 | Test Acc 0.6600 | Time(s) 0.0842
Epoch 00047 | Loss 1.0101 | Test Acc 0.6600 | Time(s) 0.0841
Epoch 00048 | Loss 0.9993 | Test Acc 0.6650 | Time(s) 0.0841
Epoch 00049 | Loss 0.9886 | Test Acc 0.6670 | Time(s) 0.0841
GCN in one formula
Mathematically, the GCN model follows the formula:
$H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)})$
Here, $H^{(l)}$ denotes the $l$-th layer in the network, $\sigma$ is the non-linearity, and $W$ is the weight matrix for this layer. $D$ and $A$, as commonly seen, represent the degree matrix and adjacency matrix, respectively. The tilde is a renormalization trick in which we add a self-connection to each node of the graph, and build the corresponding degree and adjacency matrices. The input $H^{(0)}$ has shape $N \times D$, where $N$ is the number of nodes and $D$ is the number of input features. We can chain up multiple layers as such to produce a node-level representation output with shape $N \times F$, where $F$ is the dimension of the output node feature vector.
The equation can be efficiently implemented using sparse matrix multiplication kernels (such as in Kipf's pygcn code). The above DGL implementation in fact already uses this trick, thanks to its use of the built-in functions. To understand what is under the hood, please read our tutorial on PageRank.
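For reference, here is a minimal sketch of that formula as plain matrix operations (dense matrices for clarity, with $\sigma$ chosen as ReLU; a real implementation such as pygcn would store $\tilde{A}$ as a sparse tensor):
def gcn_layer_matmul(A, H, W):
    # A: (N, N) adjacency matrix, H: (N, D) features, W: (D, F) weights
    A_tilde = A + th.eye(A.shape[0])                   # add self-connections
    D_inv_sqrt = th.diag(A_tilde.sum(dim=1) ** -0.5)   # D̃^{-1/2}
    # sigma chosen as ReLU for this sketch
    return th.relu(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)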
Total running time of the script: (0 minutes 17.986 seconds)