SO-Net中的index_max的功能及具体实现

一、背景

  在SO-Net的分类模型的Encoder部分,first_pointnet给出特征向量后,有一个求索引的操作:

		M = node.size()[2]
        with torch.cuda.device(self.first_pn_out.get_device()):
            gather_index = index_max.forward_cuda(self.first_pn_out.detach(),
                                                  min_idx.int(),
                                                  M).detach().long()

这里的index_max在./models/index_max_ext文件下,是pytorch的CUDA扩展,接口为C++接口。因为没有CUDA编程基础,只好找了相关的书籍,简单了解一下基本知识,看了相关的博客和帖子,逐渐明白了这个扩展程序的功能是什么。

二、具体实现

2.1 文件结构

  实现一个pytorch的CUDA扩展,并以C++为接口,至少需要三个部分,setup.py,.cu以及.cpp文件,具体到index_max上就是setup.py,index_max.cpp和index_max_cuda.cu文件,下面一一介绍这三个文件具体的作用。

2.2 setup.py

  这个文件是为了编译后面的文件用的。其具体代码如下:

import setuptools
import torch
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension


setup(name='index_max',  # 编译后的链接库名称
      ext_modules=[CUDAExtension('index_max', ['index_max.cpp', 'index_max_cuda.cu'])],  # 待编译文件以及编译函数
      cmdclass={'build_ext': BuildExtension})  # 执行编译命令设置

2.3 index_max.cpp

  在原来的文件中有些部分没有使用,就不介绍了,直接解释用到的部分。在.cpp文件中声明了.cu中用到的函数,实现函数具体功能。

#include <torch/extension.h>
#include <iostream>
#include <vector>
#include <thread>

torch::Tensor index_max_forward_cpu(const torch::Tensor data,const torch::Tensor index,const int K) 
{
	int B = data.size(0);
	int C = data.size(1);
	int N = data.size(2);
	
	torch::Tensor max_idx = torch::zeros({B, C, K}, torch::TensorOptions().dtype(torch::kInt32).requires_grad(false));
	torch::Tensor max_val = torch::ones({B, C, K}, torch::TensorOptions().dtype(torch::kFloat32)) * -1000.0;
	
	auto data_a = data.accessor<float, 3>();
    auto index_a = index.accessor<int, 2>();
    auto max_idx_a = max_idx.accessor<int, 3>();
    auto max_val_a = max_val.accessor<float, 3>();

	for (int b=0; b<B; ++b) 
	{
		for (int c=0; c<C; ++c) 
		{
			for (int n=0; n<N; ++n) 
			{
				int k = index_a[b][n];
				float data_point = data_a[b][c][n];
				if (data_point > max_val_a[b][c][k]) 
				{
					max_val_a[b][c][k] = data_point;
					max_idx_a[b][c][k] = n;
				}
			}
		}
	}
    return max_idx;
}

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor/variable")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor index_max_forward_cuda_wrapper(const torch::Tensor data,const torch::Tensor index,const int K){
	CHECK_INPUT(data);
	CHECK_INPUT(index);

    return index_max_forward_cuda(data, index, K);
}

//绑定
PYBIND11_MODULE(index_max, m)
{
    m.def("forward_cuda", &index_max_forward_cuda_wrapper, "CUDA code without shared memory");
}

  这个函数传入的参数为:

self.first_pn_out.detach()  # [B,C,kN]
min_idx.int():  # [B,kN]
M  # 节点数

  程序中,k=index_a[b][n],表示某个batch中某个点的某个近邻节点的索引,data_point是与之对应的那个节点的某一维度的特征值,如果该特征值大于已有的值,则更新该值,并记录对应的索引。这等于:将以每个节点为近邻节点的的所有点(每个节点对应的点数量不同)进行max_pooling,举个例子,假设原始点云编号从0—1024,节点编号从0—64。对于1号节点,有编号为0,123, 589, 1024共四个点以其为近邻节点。那么该函数的作用就是对这四个节点的特征向量做一次max_pooling,其结果作为0号节点的特征。只不过在程序中返回了最大值对应的点的索引。结合后面的代码可以看出,该函数的作用就是对于每个节点的对应的原始点的特征向量做一次max_pooling,其输出作为新的节点特征向量。

self.first_pn_out_masked_max = self.first_pn_out.gather(dim=2, index=gather_index * mask_row_max.unsqueeze(1).long())  # BxCxM

2.4 index_max_cuda.cu

  CUDA编程其目的是为程序的运行进行加速,提高效率,代码的主要部分就是.cpp文件中具体功能的加速版本。

#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>

__global__ void index_max_forward_cuda_kernel(const float* __restrict__ data,
		const int* __restrict__ index,
		int* __restrict__ max_idx,
		float* __restrict__ max_val,
		const int B, const int C, const int N, const int K){
	int b = threadIdx.x;  //线程id 0 ~ B-1
	int c = blockIdx.x;  //块id 0 ~ C-1

	for(int n=0;n<N;++n){
		int k = index[b*N+n];
		float data_point = data[b*C*N+c*N+n];
		if (data_point > max_val[b*C*K+c*K+k]){
			max_val[b*C*K+c*K+k] = data_point;
			max_idx[b*C*K+c*K+k] = n;
		}
	}
}

torch::Tensor index_max_forward_cuda(const torch::Tensor data, const torch::Tensor index, const int K){
	int B = data.size(0);  // batch_size
	int C = data.size(1);  // feature channels
	int N = data.size(2);  // number of points

	auto device_idx = data.device().index();
        auto max_idx = torch::zeros({B, C, K}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, device_idx));  // index
        auto max_val = torch::ones({B, C, K}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, device_idx)) * -1000.0; // value

	index_max_forward_cuda_kernel<<<C, B>>>(data.data<float>(),
			index.data<int>(),
			max_idx.data<int>(),
			max_val.data<float>(),
			B, C, N, K);  // 启动C个块,每个块包含B个线程,参数为data,index,max_index和max_value,batch_size,channels,number of points, number of neighbors

	return max_idx;
}

三、总结

  由于每个节点对应的原始点的个数不同,采用串行计算的效率不如采用并行计算的效率,因此设计了pytorch的CUDA扩展程序。调用顺序为:

./models/networks:                          index_max.forward_cuda()
./models/index_max_ext/index_max.cpp        index_max_forward_cuda_wrapper()
./models/index_max_ext/index_max.cpp        index_max_forward_cuda()
./models/index_max_ext/index_max_cuda.cu    index_max_forward_cuda()
./models/index_max_ext/index_max_cuda.cu    index_max_forward_cuda_kernel<<<C,B>>>(data.data<float>(),index.data<int>(),max_idx.data<int>(),max_val.data<float>(),B, C, N, K);

最后的函数是有__global__标志的,表明是在设备而非主机上执行的。

上一篇:Java语言实现ALGO-1 区间k大数查询 排序 查找 (算法训练)


下一篇:el表达式获取cookie