访问元素的计算方法
假设我们有一个三维的张量 gates
,它的维度是 batch_size x 3 x state_size
,我们需要在CUDA内核中访问某个特定位置的元素 gates[n][row][column]
。为了实现这一点,我们需要知道每个维度的跨度(stride),并使用一些简单的算术运算来计算出该元素的内存地址。
假设:
-
batch_size
是第一个维度的大小,跨度为3 * state_size
-
row
是第二个维度的大小,跨度为state_size
-
index
是第三个维度的大小,跨度为1
要访问元素 gates[n][row][column]
,可以使用以下算术运算:
gates.data<scalar_t>()[n * 3 * state_size + row * state_size + column]
问题
这种方法虽然高效,但有几个缺点:
- 冗长:表达式很长,降低了代码的可读性。
- 依赖显式传递stride:需要在内核函数的参数中显式传递stride,这在处理多个大小不同的张量时会导致参数列表非常长。
使用Accessors来简化
为了简化这些操作,PyTorch 提供了 accessors
,可以在CUDA内核中更方便地访问张量元素。下面是如何使用accessors的示例:
1. 定义CUDA内核函数
首先,我们定义一个使用accessors的CUDA内核函数:
#include <torch/extension.h>
template <typename scalar_t>
__global__ void example_kernel(
torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> gates,
int state_size) {
const int n = blockIdx.x;
const int row = threadIdx.y;
const int column = threadIdx.x;
if (n < gates.size(0) && row < gates.size(1) && column < gates.size(2)) {
scalar_t value = gates[n][row][column];
gates[n][row][column] = value * 2; // 简单地将每个元素乘以2
}
}
在这个内核函数中:
-
torch::PackedTensorAccessor32
是一个辅助类,用于在内核中访问张量数据。它接受三个参数:数据类型、张量维度和指针类型。 -
gates[n][row][column]
直接访问张量元素,代码可读性大大提高。
2. 定义C++前向函数并使用 AT_DISPATCH_ALL_TYPES
调度数据类型
#include <torch/extension.h>
#include <vector>
std::vector<torch::Tensor> example_forward(torch::Tensor input) {
const auto size = input.size(0);
auto output = torch::zeros_like(input);
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
AT_DISPATCH_ALL_TYPES(input.scalar_type(), "example_forward_cuda", ([&] {
example_kernel<scalar_t><<<blocks, threads>>>(
input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
input.size(2));
}));
return {output};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &example_forward, "Example forward");
}
在这个C++前向函数中:
-
input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>()
创建一个PackedTensorAccessor32
对象,用于在CUDA内核中访问三维张量数据。 - 使用
AT_DISPATCH_ALL_TYPES
调度数据类型。
3. 在Python中调用
最后,我们在Python中加载和调用这个扩展:
import torch
from torch.utils.cpp_extension import load
# JIT编译并加载C++扩展
example_cpp = load(name="example_cpp", sources=["example.cpp", "example_kernel.cu"], verbose=True)
# 创建输入张量
input = torch.randn(32, 3, 128, device='cuda', dtype=torch.float32)
# 调用前向传播函数
output = example_cpp.forward(input)
print(output)