Python使用ctypes模块调用C/C++

最近在做图卷积相关的实验,里面涉及到图采样,该过程可以抽象为:从一个包含n个节点,m条边的图中根据一定规则采样一个连通图。由于实验使用的是FB15k-237数据集,共包含14541个节点,272115条边,每次采样30000条边,采样一次需要8s,这对于深度学习实验来说是难以接受的,会导致GPU长时间空闲。因此我开始尝试使用C/C++优化代码,虽然最后优化效果不行,但是也是对python调用C代码的一次学习,因此在此纪录一下。

Python原代码

 def get_adj_and_degrees(num_nodes, triplets):
    """ Get adjacency list and degrees of the graph"""
    adj_list = [[] for _ in range(num_nodes)]
    for i, triplet in enumerate(triplets):
        adj_list[triplet[0]].append([i, triplet[2]])
        adj_list[triplet[2]].append([i, triplet[0]])

    degrees = np.array([len(a) for a in adj_list])
    adj_list = [np.array(a) for a in adj_list]
    return adj_list, degrees

这里以get_adj_and_degrees函数为例,我们使用C/C++优化该函数。该函数只是起演示作用,具体细节不重要。

C/C++实现代码

我们在sampler.hpp中对该函数进行优化,该文件的定义如下:

#ifndef SAMPLER_H
#define SAMPLER_H

#include <vector>
#include "utility.hpp"

using namespace std;

// global graph data
int num_node = 0;
int num_edge = 0;
vector<int> degrees; // shape=[N]
vector<vector<vector<int>>> adj_list; // shape=[N, variable_size, 2]


void build_graph(int* src, int* rel, int* dst, int num_node_m, int num_edge_m) {
    num_node = num_node_m;
    num_edge = num_edge_m;

    // resize the vectors
    degrees.resize(num_node);
    adj_list.resize(num_node);

    for (int i = 0; i < num_edge; i++) {
        int s = src[i];
        int r = rel[i];
        int d = dst[i];

        vector<int> p = {i, d};
        vector<int> q = {i, s};
        adj_list[s].push_back(p);
        adj_list[d].push_back(q);
    }

    for (int i = 0; i < num_node; i++) {
        degrees[i] = adj_list[i].size();
    }
}

#endif

这里C/C++函数把结果作为全局变量进行存储,是为了后一步使用。具体的函数细节也不在讲述,因为我们的重点是如何用python调用。

生成so库

ctypes只能调用C函数,因此我们需要把上述C++函数导出为C函数。因此我们在lib.cpp中做出如下定义:

#ifndef LIB_H
#define LIB_H

#include "sampler.hpp"

extern "C" {
    void build_graph_c(int* src, int* rel, int* dst, int num_node, int num_edge) {
        build_graph(src, rel, dst, num_node, num_edge);
    }
}

#endif

然后使用如下命令进行编译,为了优化代码,加上了O3march=native选项:

g++ lib.cpp -fPIC -shared -o libsampler.so -O3 -march=native

Python调用C/C++函数

上一篇:什么是smart原则,有什么好的例子


下一篇:L2-031 深入虎穴 (25分)