Pytorch如何取出特定维的元素

如何沿着某一维取出所有元素:
torch.gather(input, dim, index)
# 可以沿着某一维将需要的元素都取出来
# 一般input和index的维度除了要操作的那一维,其他维都是相同的
# 举例:
input = torch.LongTensor([[[1,2],[3,4],[5,6]],[[7,8],[9,10],[11,12]]])
# shape = [2, 3, 2]
index = torch.LongTensor([[0,1,0], [0,0,1]])
# shape = [2, 3]
input.gather(dim=2, index=index)
# [[1, 4, 5], [7, 9, 12]]
如何取出特定维的元素:
tensor = torch.Tensor([[1,2,3],[[4,5,6],[7,8,9]]) 
tensor[[[0,1],[0,1]],[[1,2],[0,2]]] # [[2,6],[1,6]] 
tensor[[[0,1],[0,1]],0]] # [[1,4],[1,4]] 
# 逗号分隔开的分别对应了不同的维度要如何往外取 
# 因此逗号两遍可以是两个矩阵,但是必须要维度相同(或者可以广播)

 

 
上一篇:30 Day Challenge Day 22 | Leetcode 85. Maximal Rectangle


下一篇:C++题目分享之重新排序递增递减类