Pytorch中torch.Tensor.scatter_用法

首先看一下这个函数的接口,需要三个输入:1)维度dim 2)索引数组index 3)原数组src,为了方便理解,我们后文把src换成input表示。最终的输出是新的output数组。

下面依次介绍:

1)维度dim:整数,可以是0,1,2,3…

2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置

3)原数组input:也是一个tensor,其中的数据类型任意

先说一下这个函数是干嘛的,在我看来,这个scatter函数就是把input数组中的数据进行重新分配。index中表示了要把原数组中的数据分配到output数组中的位置,如果未指定,则填充0。

比如说下面这段代码:

import torch
 
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)

运行结果如下:

tensor([[-0.2558, -1.8930, -0.7831, 0.6100],
[ 0.3246, 2.1289, 0.5887, 1.5588]])
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000],
[ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])

下面,我详细说一下为什么会是这样的结果。

前面说了,scatter是input数组,根据index数组,对input数组中的数据进行重新分配,我们看一下分配过程是怎样的。

input:

tensor([[-0.2558, -1.8930, -0.7831, 0.6100],
[ 0.3246, 2.1289, 0.5887, 1.5588]])

index:

index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])

output:

tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000],
[ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])

首先,对input[0][0]进行重分配。符号 -> 代表赋值。由于scatter方法的第一维dim=1,所以input数组中的数据只是在第1维上进行重新分配,第0维不变。以二维数组举例,第一行的数据重新分配后一定在还是第一行,不能跑到第二行。

input[0][0] -> output[0][index[0][0]] = output[0][3]

数据位置发生的变化都是在第1维上,第0维不变。

input[0][1] -> output[0][index[0][1]] = output[0][1]

input[0][2] -> output[0][index[0][2]] = output[0][2]

input[0][3] -> output[0][index[0][3]] = output[0][0]

Pytorch中torch.Tensor.scatter_用法
需要注意的是,

为了方便理解,我们是按照input中数据的顺序索引的,但是在pytorch中,是根据从index[0][0]到index[0][3]这样的顺序去索引的,索引的input位置和output的位置必须要存在,否则会提示错误。但是,不一定所有的input数据都会分到output中,output也不是所有位置都有对应的input,当output中没有对应的input时,自动填充0。

一般scatter用于生成onehot向量,如下所示:

index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

输出结果是:

tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.]])

如果input是一个数字的话,代表这用于分配到output的数字是多少。

import torch

tensorB = torch.tensor([[2.5880, 2.1556, -31.0650, -13.5238, 11.0284],
                        [-0.2982, 10.8633, -22.4874, -9.2778, -1.1321]])

tensorA = torch.tensor([[0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0., 0.]])

index = torch.tensor([
    [0, 1, 2, 0, 0],
    [2, 0, 0, 1, 2]
])

tensorC = tensorA.scatter_(0, index, tensorB)  # dim=0: 按列填充

print('tensorC = ', tensorC)

# tensorC =  tensor([[  2.5880,  10.8633, -22.4874, -13.5238,  11.0284,   0.0000],
#                    [  0.0000,   2.1556,   0.0000,  -9.2778,   0.0000,   0.0000],
#                    [ -0.2982,   0.0000, -31.0650,   0.0000,  -1.1321,   0.0000]])

tensorD = tensorA.scatter_(1, index, tensorB)  # dim=1: 按行填充

print('tensorD = ', tensorD)

# tensorD =  tensor([[ 11.0284,   2.1556, -31.0650, -13.5238,  11.0284,   0.0000],
#                    [-22.4874,  -9.2778,  -1.1321,  -9.2778,   0.0000,   0.0000],
#                    [ -0.2982,   0.0000, -31.0650,   0.0000,  -1.1321,   0.0000]])



参考资料:
官方TORCH.TENSOR.SCATTER_
pytorch中torch.Tensor.scatter用法
one hot编码:torch.Tensor.scatter_()函数用法详解

上一篇:pytorch学习笔记


下一篇:Pytorch创建Tensor的几种方式详解(转载)