使用 accessors简化在CUDA内核中访问张量数据的代码

访问元素的计算方法

假设我们有一个三维的张量 gates,它的维度是 batch_size x 3 x state_size,我们需要在CUDA内核中访问某个特定位置的元素 gates[n][row][column]。为了实现这一点,我们需要知道每个维度的跨度(stride),并使用一些简单的算术运算来计算出该元素的内存地址。

假设:

  • batch_size 是第一个维度的大小,跨度为 3 * state_size
  • row 是第二个维度的大小,跨度为 state_size
  • index 是第三个维度的大小,跨度为 1

要访问元素 gates[n][row][column],可以使用以下算术运算:

gates.data<scalar_t>()[n * 3 * state_size + row * state_size + column]

问题

这种方法虽然高效,但有几个缺点:

  1. 冗长:表达式很长,降低了代码的可读性。
  2. 依赖显式传递stride:需要在内核函数的参数中显式传递stride,这在处理多个大小不同的张量时会导致参数列表非常长。

使用Accessors来简化

为了简化这些操作,PyTorch 提供了 accessors,可以在CUDA内核中更方便地访问张量元素。下面是如何使用accessors的示例:

1. 定义CUDA内核函数

首先,我们定义一个使用accessors的CUDA内核函数:

#include <torch/extension.h>

template <typename scalar_t>
__global__ void example_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> gates,
    int state_size) {
  const int n = blockIdx.x;
  const int row = threadIdx.y;
  const int column = threadIdx.x;

  if (n < gates.size(0) && row < gates.size(1) && column < gates.size(2)) {
    scalar_t value = gates[n][row][column];
    gates[n][row][column] = value * 2; // 简单地将每个元素乘以2
  }
}

在这个内核函数中:

  • torch::PackedTensorAccessor32 是一个辅助类,用于在内核中访问张量数据。它接受三个参数:数据类型、张量维度和指针类型。
  • gates[n][row][column] 直接访问张量元素,代码可读性大大提高。
2. 定义C++前向函数并使用 AT_DISPATCH_ALL_TYPES 调度数据类型
#include <torch/extension.h>
#include <vector>

std::vector<torch::Tensor> example_forward(torch::Tensor input) {
  const auto size = input.size(0);
  auto output = torch::zeros_like(input);

  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;

  AT_DISPATCH_ALL_TYPES(input.scalar_type(), "example_forward_cuda", ([&] {
    example_kernel<scalar_t><<<blocks, threads>>>(
        input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
        input.size(2));
  }));

  return {output};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &example_forward, "Example forward");
}

在这个C++前向函数中:

  • input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>() 创建一个 PackedTensorAccessor32 对象,用于在CUDA内核中访问三维张量数据。
  • 使用 AT_DISPATCH_ALL_TYPES 调度数据类型。
3. 在Python中调用

最后,我们在Python中加载和调用这个扩展:

import torch
from torch.utils.cpp_extension import load

# JIT编译并加载C++扩展
example_cpp = load(name="example_cpp", sources=["example.cpp", "example_kernel.cu"], verbose=True)

# 创建输入张量
input = torch.randn(32, 3, 128, device='cuda', dtype=torch.float32)

# 调用前向传播函数
output = example_cpp.forward(input)

print(output)
上一篇:Oracle数据库 ASH视图详解


下一篇:C++——类和对象(中)