神奇的torch.gather()

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor 

Gathers values along an axis specified by dim.

例如 原本一个tensor a是:

a[0][0] a[0][1]
a[1][0] a[1][1]


index tensor是:

 

 

j k
m n

现在,b=torch.gather(input=a, dim=0, index=index)

因此,将第0维的数据替换成index的数据,则b是:

a[j][0] a[k][1]
a[m][0] a[n][1]

如果,b=torch.gather(input=a, dim=1, index=index)

那么,b将会是

a[0][j] a[0][k]
a[1][m] a[1][n]

总之,dim是多少,就将那一维所查询的位置换成index里面对应位置上的数

上一篇:R语言长宽数据转换函数tidyr包


下一篇:内置对象——String