函数定义
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Gathers values along an axis specified by dim.
对于一个3-D的张量,输出按照以下公式被指定为:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
函数参数
-
input (Tensor) – the source tensor
-
dim (int) – the axis along which to index
-
index (LongTensor) – the indices of elements to gather
-
sparse_grad (bool, optional) – If
True
, gradient w.r.t.input
will be a sparse tensor. - out (Tensor, optional) – the destination tensor
函数参数说明
- 参数input和参数index必须拥有相同数量的维度,并且要求index.size(d) <= input.size(d)对于所有的维度d != dim。
- out将会拥有和index一样的形状。
- 参数input和参数index不能彼此进行广播
例子
>>> t = torch.tensor([[1, 2], [3, 4]]) >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) tensor([[ 1, 1], [ 4, 3]])