a = torch.Tensor([[1,2],[3,4]])
print(a)
a = torch.gather(a,1,index=torch.tensor([[0,0],[1,0]]))
print(a)
a = torch.gather(a,0,index=torch.tensor([[0,0],[1,0]]))
print(a)
a = torch.Tensor([[1,2],[3,4]]) b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]])),其中, dim=1,表示的是在第二维度上操作。
index = torch.LongTensor([[0,0],[1,0]]),[0,0]就是第一行对应元素的下标,下标0对应的元素是1,所以第一行是[1,1]; [1,0]就是第二行对应元素的下标,下标1对应的元素,下标对应的元素是,所以返回的的是[4,3]。
torch.gather(input, dim, index, out=None) → Tensor
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
示例
结果为:
torch.gather(input, dim, index, out=None)
中的dim表示的就是第几维度,在这个二维例子中,如果dim=0,那么它表示的就是你接下来的操作是对于第一维度进行的,也就是行;如果dim=1,那么它表示的就是你接下来的操作是对于第二维度进行的,也就是列。index的大小和input的大小是一样的,他表示的是你所选择的维度上的操作。a = torch.Tensor([[1,2],[3,4]]) b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]]))
,其中, dim=1,表示的是在第二维度上操作。index = torch.LongTensor([[0,0],[1,0]])
,[0,0]就是第一行对应元素的下标,下标0对应的元素是1,所以第一行是[1,1]; [1,0]就是第二行对应元素的下标,下标1对应的元素,下标对应的元素是,所以返回的的是[4,3]。