CHH3213 / Note-Ubuntu_CHH3213

ubuntu系统CHH3213上做的笔记
0 stars 0 forks source link

torch.gather() #69

Open CHH3213 opened 2 years ago

CHH3213 commented 2 years ago

torch.gather(input, dim, index, out=None) → Tensor

沿给定轴dim,将输入索引张量index指定位置的值进行聚合。

a = torch.Tensor([[1,2],[3,4]]) b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]])),其中, dim=1,表示的是在第二维度上操作。 index = torch.LongTensor([[0,0],[1,0]]),[0,0]就是第一行对应元素的下标,下标0对应的元素是1,所以第一行是[1,1]; [1,0]就是第二行对应元素的下标,下标1对应的元素,下标对应的元素是,所以返回的的是[4,3]。

CHH3213 commented 2 years ago

参考https://blog.csdn.net/Lucky_Rocks/article/details/79676095