numpy中关于数组维度的理解——dim和axis

1. 问题描述

在学习pytorch中张量的拼接时,又遇到类似的问题了:

rand_tensor=torch.rand((3,3))
t1=torch.cat((rand_tensor,rand_tensor,rand_tensor),dim=1)
t2=torch.cat((rand_tensor,rand_tensor,rand_tensor),dim=0)
t3=torch.stack((rand_tensor,rand_tensor,rand_tensor))

输出的结果是:

>> rand_tensor
tensor([[0.2129, 0.1800, 0.5501],
        [0.8459, 0.7366, 0.6963],
        [0.0661, 0.9675, 0.1445]])
>> t1 (cat dim=1)
tensor([[0.2129, 0.1800, 0.5501, 0.2129, 0.1800, 0.5501, 0.2129, 0.1800, 0.5501],
        [0.8459, 0.7366, 0.6963, 0.8459, 0.7366, 0.6963, 0.8459, 0.7366, 0.6963],
        [0.0661, 0.9675, 0.1445, 0.0661, 0.9675, 0.1445, 0.0661, 0.9675, 0.1445]])
>> t2 (cat dim=0)
"""
其实整体是水平合并的,只是写成了竖着的形式,把[0.2129, 0.1800, 0.5501],这种当作一个整体
其实是水平的
"""
tensor([[0.2129, 0.1800, 0.5501],
        [0.8459, 0.7366, 0.6963],
        [0.0661, 0.9675, 0.1445],
        [0.2129, 0.1800, 0.5501],
        [0.8459, 0.7366, 0.6963],
        [0.0661, 0.9675, 0.1445],
        [0.2129, 0.1800, 0.5501],
        [0.8459, 0.7366, 0.6963],
        [0.0661, 0.9675, 0.1445]])
>> t3 (stack 默认dim=0)
tensor([[[0.2129, 0.1800, 0.5501],
         [0.8459, 0.7366, 0.6963],
         [0.0661, 0.9675, 0.1445]],

        [[0.2129, 0.1800, 0.5501],
         [0.8459, 0.7366, 0.6963],
         [0.0661, 0.9675, 0.1445]],

        [[0.2129, 0.1800, 0.5501],
         [0.8459, 0.7366, 0.6963],
         [0.0661, 0.9675, 0.1445]]])

记得以前看过一个讲的非常清晰的博客,但是没有记录,awsl!

浏览了很多博客,大部分都是讲解np.sum(a1, axis=0)中的axis这个参数,轴。关键就一句话:

axis=i,则numpy沿着第i个下标变化的方向进行操作
axis=正值,从最外层往里剥开计算
axis=负值,从最里层往外剥开计算

2. dim/axis的理解(以sum函数为例)

  • pytorch中的dim其实就是numpy中的axis,二者本质都是表示维度
  • NumPy中的维度Axis详解中——(有人将ndim属性叫维度,将axis叫轴,我还是习惯将axis称之为维度,axis=0称为第一个维度)

理解axis的直观动画解释,搬运自 理解numpy中的axis, python中的dim这篇文章。

>> b = torch.tensor([
    [
      [1, 2, 3],
      [4, 5, 6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ],
])
>> print(b.shape)
torch.Size([3, 2, 3])

当dim=0时,torch.sum(b, dim=0),其结果和合并方式如下(csdn不会插入gif图,请自行移步到原文!

上一篇:13-3 合并内容相同的连续单元格


下一篇:stm32 中断 和事件