torch.gather()
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
参数解释:
-
input (Tensor) – the source tensor
-
dim (int) – the axis along which to index
-
index (LongTensor) – the indices of elements to gather
-
sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.
-
out (Tensor, optional) – the destination tensor
示例1:
t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])
解释:
gather的意思是聚集和取,即从input
这个张量中取元素,而index
则对应所取元素的下标。如果dim=0,那么index
中的数值表示行坐标,如果dim=1,那么index
中的数值表示列坐标。另外,index的shape和output的shape应该要一致。
以上述示例来说就是:index的第一行对应输出的第一行,其元素[0,0]就是从t中的第一行的下标为0的位置取其元素
示例2:
t = torch.tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
torch.gather(t, 0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]))
tensor([[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
torch.scatter()
torch.scatter_(input, dim, index, src) → Tensor
参数解释:
-
dim (int) – the axis along which to index
-
index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
-
src (Tensor or float) – the source element(s) to scatter.
-
reduce (str, optional) – reduction operation to apply, can be either ‘add’ or ‘multiply’.
示例1:
x = torch.rand(2, 5)
x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
解释:
scatter
可以理解为gather
的反操作,即用src
中的元素去替换input
中的元素,而index
中的数值则对应input
元素的下标。如果dim=0,那么index
中的数值表示横坐标,如果dim=1,那么index
中的数值表示纵坐标。另外,output的shape和input的shape是一致的。
src = torch.arange(1, 11).reshape((2, 5))
src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
------------------------------------------------
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
------------------------------------------------
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
------------------------------------------------
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
[2.0000, 2.0000, 2.0000, 2.4600]])
------------------------------------------------
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 3.2300]])
参考链接:
https://zhuanlan.zhihu.com/p/187401278
https://www.cnblogs.com/dogecheng/p/11938009.html
https://wmathor.com/index.php/archives/1457/