-
reshape
:用于改变张量形状,可以处理非连续张量。 -
view
:用于改变张量形状,但只能用于连续张量。对于非连续张量,需要先调用.contiguous()
方法。 -
transpose
:用于交换两个维度。 -
permute
:用于根据指定顺序重新排列所有维度。 -
上述4个函数均共享内存,因此对上述函数的返回值进行修改原数据会受到影响,我们可以用下述代码来验证
代码:
import torch data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32) data1 = torch.tensor([9, 10, 11, 12, 13, 14, 15, 16], dtype=torch.float32) data2 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) data3 = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32) view = data.view((2, 4)) reshape = data1.reshape((2, 4)) transpose = data2.transpose(1, 0) permute = data3.permute(1, 0) print('修改前') print(data) print(data1) print(transpose) print(permute) view[0][0] = 100 reshape[1][1] = 100 transpose[0][0] = 100 permute[1][1] = 100 print('修改后') print(data) print(data1) print(data2) print(data3)
输出:
修改前 tensor([1., 2., 3., 4., 5., 6., 7., 8.]) tensor([ 9., 10., 11., 12., 13., 14., 15., 16.]) tensor([[1., 3.], [2., 4.]]) tensor([[5., 7.], [6., 8.]]) 修改后 tensor([100., 2., 3., 4., 5., 6., 7., 8.]) tensor([ 9., 10., 11., 12., 13., 100., 15., 16.]) tensor([[100., 2.], [ 3., 4.]]) tensor([[ 5., 6.], [ 7., 100.]])