一、背景
在SO-Net的分类模型的Encoder部分,first_pointnet给出特征向量后,有一个求索引的操作:
M = node.size()[2]
with torch.cuda.device(self.first_pn_out.get_device()):
gather_index = index_max.forward_cuda(self.first_pn_out.detach(),
min_idx.int(),
M).detach().long()
这里的index_max在./models/index_max_ext
文件下,是pytorch的CUDA扩展,接口为C++接口。因为没有CUDA编程基础,只好找了相关的书籍,简单了解一下基本知识,看了相关的博客和帖子,逐渐明白了这个扩展程序的功能是什么。
二、具体实现
2.1 文件结构
实现一个pytorch的CUDA扩展,并以C++为接口,至少需要三个部分,setup.py,.cu以及.cpp文件,具体到index_max上就是setup.py,index_max.cpp和index_max_cuda.cu文件,下面一一介绍这三个文件具体的作用。
2.2 setup.py
这个文件是为了编译后面的文件用的。其具体代码如下:
import setuptools
import torch
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
setup(name='index_max', # 编译后的链接库名称
ext_modules=[CUDAExtension('index_max', ['index_max.cpp', 'index_max_cuda.cu'])], # 待编译文件以及编译函数
cmdclass={'build_ext': BuildExtension}) # 执行编译命令设置
2.3 index_max.cpp
在原来的文件中有些部分没有使用,就不介绍了,直接解释用到的部分。在.cpp文件中声明了.cu中用到的函数,实现函数具体功能。
#include <torch/extension.h>
#include <iostream>
#include <vector>
#include <thread>
torch::Tensor index_max_forward_cpu(const torch::Tensor data,const torch::Tensor index,const int K)
{
int B = data.size(0);
int C = data.size(1);
int N = data.size(2);
torch::Tensor max_idx = torch::zeros({B, C, K}, torch::TensorOptions().dtype(torch::kInt32).requires_grad(false));
torch::Tensor max_val = torch::ones({B, C, K}, torch::TensorOptions().dtype(torch::kFloat32)) * -1000.0;
auto data_a = data.accessor<float, 3>();
auto index_a = index.accessor<int, 2>();
auto max_idx_a = max_idx.accessor<int, 3>();
auto max_val_a = max_val.accessor<float, 3>();
for (int b=0; b<B; ++b)
{
for (int c=0; c<C; ++c)
{
for (int n=0; n<N; ++n)
{
int k = index_a[b][n];
float data_point = data_a[b][c][n];
if (data_point > max_val_a[b][c][k])
{
max_val_a[b][c][k] = data_point;
max_idx_a[b][c][k] = n;
}
}
}
}
return max_idx;
}
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor/variable")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor index_max_forward_cuda_wrapper(const torch::Tensor data,const torch::Tensor index,const int K){
CHECK_INPUT(data);
CHECK_INPUT(index);
return index_max_forward_cuda(data, index, K);
}
//绑定
PYBIND11_MODULE(index_max, m)
{
m.def("forward_cuda", &index_max_forward_cuda_wrapper, "CUDA code without shared memory");
}
这个函数传入的参数为:
self.first_pn_out.detach() # [B,C,kN]
min_idx.int(): # [B,kN]
M # 节点数
程序中,k=index_a[b][n]
,表示某个batch中某个点的某个近邻节点的索引,data_point是与之对应的那个节点的某一维度的特征值,如果该特征值大于已有的值,则更新该值,并记录对应的索引。这等于:将以每个节点为近邻节点的的所有点(每个节点对应的点数量不同)进行max_pooling,举个例子,假设原始点云编号从0—1024,节点编号从0—64。对于1号节点,有编号为0,123, 589, 1024共四个点以其为近邻节点。那么该函数的作用就是对这四个节点的特征向量做一次max_pooling,其结果作为0号节点的特征。只不过在程序中返回了最大值对应的点的索引。结合后面的代码可以看出,该函数的作用就是对于每个节点的对应的原始点的特征向量做一次max_pooling,其输出作为新的节点特征向量。
self.first_pn_out_masked_max = self.first_pn_out.gather(dim=2, index=gather_index * mask_row_max.unsqueeze(1).long()) # BxCxM
2.4 index_max_cuda.cu
CUDA编程其目的是为程序的运行进行加速,提高效率,代码的主要部分就是.cpp文件中具体功能的加速版本。
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
__global__ void index_max_forward_cuda_kernel(const float* __restrict__ data,
const int* __restrict__ index,
int* __restrict__ max_idx,
float* __restrict__ max_val,
const int B, const int C, const int N, const int K){
int b = threadIdx.x; //线程id 0 ~ B-1
int c = blockIdx.x; //块id 0 ~ C-1
for(int n=0;n<N;++n){
int k = index[b*N+n];
float data_point = data[b*C*N+c*N+n];
if (data_point > max_val[b*C*K+c*K+k]){
max_val[b*C*K+c*K+k] = data_point;
max_idx[b*C*K+c*K+k] = n;
}
}
}
torch::Tensor index_max_forward_cuda(const torch::Tensor data, const torch::Tensor index, const int K){
int B = data.size(0); // batch_size
int C = data.size(1); // feature channels
int N = data.size(2); // number of points
auto device_idx = data.device().index();
auto max_idx = torch::zeros({B, C, K}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, device_idx)); // index
auto max_val = torch::ones({B, C, K}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, device_idx)) * -1000.0; // value
index_max_forward_cuda_kernel<<<C, B>>>(data.data<float>(),
index.data<int>(),
max_idx.data<int>(),
max_val.data<float>(),
B, C, N, K); // 启动C个块,每个块包含B个线程,参数为data,index,max_index和max_value,batch_size,channels,number of points, number of neighbors
return max_idx;
}
三、总结
由于每个节点对应的原始点的个数不同,采用串行计算的效率不如采用并行计算的效率,因此设计了pytorch的CUDA扩展程序。调用顺序为:
./models/networks: index_max.forward_cuda()
./models/index_max_ext/index_max.cpp index_max_forward_cuda_wrapper()
./models/index_max_ext/index_max.cpp index_max_forward_cuda()
./models/index_max_ext/index_max_cuda.cu index_max_forward_cuda()
./models/index_max_ext/index_max_cuda.cu index_max_forward_cuda_kernel<<<C,B>>>(data.data<float>(),index.data<int>(),max_idx.data<int>(),max_val.data<float>(),B, C, N, K);
最后的函数是有__global__标志的,表明是在设备而非主机上执行的。