torch.gather(input, dim, index) # 可以沿着某一维将需要的元素都取出来 # 一般input和index的维度除了要操作的那一维,其他维都是相同的 # 举例: input = torch.LongTensor([[[1,2],[3,4],[5,6]],[[7,8],[9,10],[11,12]]]) # shape = [2, 3, 2] index = torch.LongTensor([[0,1,0], [0,0,1]]) # shape = [2, 3] input.gather(dim=2, index=index) # [[1, 4, 5], [7, 9, 12]]如何取出特定维的元素:
tensor = torch.Tensor([[1,2,3],[[4,5,6],[7,8,9]]) tensor[[[0,1],[0,1]],[[1,2],[0,2]]] # [[2,6],[1,6]] tensor[[[0,1],[0,1]],0]] # [[1,4],[1,4]] # 逗号分隔开的分别对应了不同的维度要如何往外取 # 因此逗号两遍可以是两个矩阵,但是必须要维度相同(或者可以广播)