深入理解 PyTorch 中的torch.stack函数:中英双语-中文版

深入理解 PyTorch 中的 torch.stack 函数

在使用 PyTorch 进行深度学习开发时,经常需要对张量进行操作和组合。torch.stack 是一个非常常用且重要的函数,它可以将一组张量沿着新的维度拼接成一个新的张量。本文将深入介绍 torch.stack 的用法,包括其功能、参数、注意事项和实际案例。


什么是 torch.stack

torch.stack 的主要作用是沿着新的维度将多个张量堆叠。与 torch.cat 不同,torch.cat 是在已有维度上进行拼接,而 torch.stack创建一个新的维度,将一组张量按照指定的位置堆叠起来。

官方文档定义如下:

torch.stack(tensors, dim=0) → Tensor
参数解释:
  • tensors:要堆叠的张量序列(可以是列表或元组),所有张量的形状必须相同。
  • dim:新维度的索引(位置),默认值为 0
返回值:

返回一个新的张量,包含输入张量序列,堆叠后的维度会比原张量多 1。


使用示例

基本用法
import torch

# 创建三个相同形状的张量
t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# 沿着新的维度堆叠
result = torch.stack([t1, t2, t3], dim=0)
print(result)

输出:

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

上述代码中,dim=0 表示在最外层新增一个维度(行堆叠)。结果张量的形状是 [3, 3],表示 3 行 3 列。

改变维度位置

如果我们将 dim 设置为 1:

result = torch.stack([t1, t2, t3], dim=1)
print(result)

输出:

tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])

此时,dim=1 表示在列的维度堆叠,结果形状是 [3, 3],但数据的排列方式发生了变化。


torch.cat 的对比

torch.cattorch.stack 都可以用于张量的组合,但它们的功能和结果有显著区别。

区别 1:是否创建新维度
  • torch.cat:仅在现有维度上拼接,不会创建新的维度。
  • torch.stack:会创建一个新的维度。

例如:

# 使用 torch.cat
result_cat = torch.cat([t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)], dim=0)
print(result_cat)

# 使用 torch.stack
result_stack = torch.stack([t1, t2, t3], dim=0)
print(result_stack)

输出:
两者的结果相同,但 torch.cat 需要手动添加维度(通过 unsqueeze),而 torch.stack 会自动处理这一点。

区别 2:维度要求
  • torch.cat 的输入张量在拼接维度以外的维度上必须完全一致。
  • torch.stack 的输入张量要求形状完全一致

进阶案例

1. 批量生成新的张量

假设我们有多个二维张量,想要将它们堆叠成一个三维张量:

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.tensor([[9, 10], [11, 12]])

result = torch.stack([t1, t2, t3], dim=0)
print(result)

输出:

tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]]])

此时,dim=0 表示在最外层添加一个维度,结果形状为 [3, 2, 2]

2. 张量拆分后再堆叠
# 创建一个三维张量
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

# 沿着 dim=0 拆分
split_tensors = torch.unbind(x, dim=0)

# 再次堆叠
result = torch.stack(split_tensors, dim=1)
print(result)

输出:

tensor([[[ 1,  5,  9],
         [ 3,  7, 11]],

        [[ 2,  6, 10],
         [ 4,  8, 12]]])

通过 torch.unbind 将张量拆分后,我们可以用 torch.stack 重新组织张量的结构。


注意事项

  1. 输入张量的形状必须完全一致
    如果输入张量的形状不同,将会报错:

    t1 = torch.tensor([1, 2])
    t2 = torch.tensor([3, 4, 5])
    torch.stack([t1, t2])  # 会报错
    
  2. dim 的取值范围

    • dim 的取值范围是 [-(d+1), d],其中 d 是输入张量的维度。
    • 如果设置超出范围的值,将会报错。
  3. 效率问题

    • torch.stack 本质上是向输入张量添加一个维度,然后调用 torch.cat 实现堆叠。因此,在大规模操作中,合理利用 torch.cat 和维度操作可能更高效。

总结

torch.stack 是一个强大而灵活的函数,在处理张量时提供了简洁的接口。通过为张量增加新的维度,它使得许多复杂的张量操作变得更加直观。无论是构建批量数据、改变张量维度,还是处理高级的张量变换,torch.stack 都是一个不可或缺的工具。

在实际使用中,了解 torch.stack 的参数含义以及它与 torch.cat 的区别,可以帮助我们写出更加高效和简洁的代码。希望通过这篇博客,大家能够全面掌握 torch.stack 的使用!

上一篇:批量计算(Batch Processing)


下一篇:Python字符串常用操作-一、字符串的切片