1. 问题描述
在学习pytorch中张量的拼接时,又遇到类似的问题了:
rand_tensor=torch.rand((3,3))
t1=torch.cat((rand_tensor,rand_tensor,rand_tensor),dim=1)
t2=torch.cat((rand_tensor,rand_tensor,rand_tensor),dim=0)
t3=torch.stack((rand_tensor,rand_tensor,rand_tensor))
输出的结果是:
>> rand_tensor
tensor([[0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445]])
>> t1 (cat dim=1)
tensor([[0.2129, 0.1800, 0.5501, 0.2129, 0.1800, 0.5501, 0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963, 0.8459, 0.7366, 0.6963, 0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445, 0.0661, 0.9675, 0.1445, 0.0661, 0.9675, 0.1445]])
>> t2 (cat dim=0)
"""
其实整体是水平合并的,只是写成了竖着的形式,把[0.2129, 0.1800, 0.5501],这种当作一个整体
其实是水平的
"""
tensor([[0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445],
[0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445],
[0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445]])
>> t3 (stack 默认dim=0)
tensor([[[0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445]],
[[0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445]],
[[0.2129, 0.1800, 0.5501],
[0.8459, 0.7366, 0.6963],
[0.0661, 0.9675, 0.1445]]])
记得以前看过一个讲的非常清晰的博客,但是没有记录,awsl!
浏览了很多博客,大部分都是讲解np.sum(a1, axis=0)
中的axis
这个参数,轴。关键就一句话:
axis=i,则numpy沿着第i个下标变化的方向进行操作
axis=正值,从最外层往里剥开计算
axis=负值,从最里层往外剥开计算
2. dim/axis的理解(以sum函数为例)
-
pytorch中的dim
其实就是numpy中的axis
,二者本质都是表示维度
, - 在 NumPy中的维度Axis详解中——(有人将ndim属性叫维度,将axis叫轴,我还是习惯将axis称之为维度,axis=0称为第一个维度)
理解axis的直观动画解释,搬运自 理解numpy中的axis, python中的dim这篇文章。
>> b = torch.tensor([
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
],
])
>> print(b.shape)
torch.Size([3, 2, 3])
当dim=0时,torch.sum(b, dim=0)
,其结果和合并方式如下(csdn不会插入gif图,请自行移步到原文!