Line 167 of `trainer/parsernet.py`:

```python
# current_argmax: [batch, 7702]; self.layer_neigh: [7702, 20]
current_argmax = torch.argmax(weights_from_net, axis=2)
idx = torch.stack([torch.index_select(self.layer_neigh, 1, current_argmax[i])[0] for i in range(self.bs)])
```
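Here is a self-contained toy reproduction of the issue. It uses small hypothetical sizes (`bs`, `num_vertices`, `k`) and random data in place of the real 7702 vertices and 20 neighbours, and it assumes `weights_from_net` has shape `[bs, num_vertices, k]`. The script compares three things: the current expression, the diagonal entries that are actually wanted, and a vectorised `torch.gather` variant. The `gather` version is only a suggested fix, not code from the repository.

```python
import torch

# Hypothetical toy sizes standing in for 7702 vertices and 20 neighbours.
torch.manual_seed(0)
bs, num_vertices, k = 2, 5, 3

# layer_neigh[v, j] = index of the j-th predefined neighbour of vertex v
layer_neigh = torch.randint(0, num_vertices, (num_vertices, k))
# Assumed layout: one weight per (sample, vertex, neighbour slot)
weights_from_net = torch.randn(bs, num_vertices, k)

current_argmax = torch.argmax(weights_from_net, dim=2)  # [bs, num_vertices]

# Current code: index_select builds a [V, V] matrix per sample, where
# entry [v, w] = layer_neigh[v, current_argmax[i, w]]. Taking [0] keeps only
# row 0, i.e. neighbours of vertex 0 are used for every vertex.
idx_current = torch.stack(
    [torch.index_select(layer_neigh, 1, current_argmax[i])[0] for i in range(bs)]
)

# The intended values sit on the diagonal of that [V, V] matrix:
# for vertex v, layer_neigh[v, current_argmax[i, v]].
idx_diag = torch.stack(
    [torch.index_select(layer_neigh, 1, current_argmax[i]).diagonal() for i in range(bs)]
)

# Suggested vectorised fix (hypothetical, not from the repo): gather along
# the neighbour dimension without building the [V, V] intermediate.
idx_gather = torch.gather(
    layer_neigh.unsqueeze(0).expand(bs, -1, -1),  # [bs, V, k]
    2,
    current_argmax.unsqueeze(-1),                 # [bs, V, 1]
).squeeze(-1)                                     # [bs, V]

print(idx_gather.shape)                    # torch.Size([2, 5])
print(torch.equal(idx_diag, idx_gather))   # True
```

The `gather` form returns one neighbour index per vertex directly. It avoids allocating a `[7702, 7702]` tensor for every sample in the batch, and it removes the Python loop over `self.bs`.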
We want the index of the vertex with the largest weight among each vertex's 20 predefined neighbours. However, `torch.index_select(self.layer_neigh, 1, current_argmax[i])[0]` only picks indices from the 20 neighbours of the first vertex, which makes no sense. In other words, `torch.index_select(self.layer_neigh, 1, current_argmax[i])` returns a tensor of size [7702, 7702], and the trailing `[0]` selects its first row `[0, :]`. The entries we actually need lie on the diagonal: [0, 0], [1, 1], [2, 2], ..., [7701, 7701]. This is an example from the torch 1.4 documentation for this API: