visinf / n3net

Neural Nearest Neighbors Networks (NIPS*2018)

Indexes of nearest neighbors #13

Open: alex-kharlamov opened this issue 5 years ago

alex-kharlamov commented 5 years ago

The following code:

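```python
import numpy as np  # only needed for the commented-out permutation variant
import torch
# Assumption: N3AggregationBase comes from the n3net repository (n3net.py);
# adjust this import to wherever that file lives in your checkout.
from n3net import N3AggregationBase

N = 7
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# k = 5 neighbor slots; the temperature is not supplied as an extra input.
# Named `agg` instead of `nn` to avoid shadowing torch.nn; moved with
# .to(device) instead of .cuda() so the snippet also runs on CPU.
agg = N3AggregationBase(5, temp_opt={"external_temp": False}).to(device)
agg.nnn.log_temp_bias = -50  # decrease temperature -> NNN acts more like hard kNN

# x = torch.tensor(np.random.permutation(list(range(N))), dtype=torch.float, requires_grad=True)
x = torch.tensor(list(range(N)), dtype=torch.float, requires_grad=True)
n = torch.zeros_like(x).normal_() * 0.0001  # tiny noise to break distance ties
x = (x + n).reshape(1, N, 1).to(device)  # [batch, items, features]
xe = x
ye = xe
# Candidate indexes: every query may select any of the N items -> [1, N, N]
I = torch.arange(N, dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)

z = agg(x, xe, ye, I)
```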

produces an output `z` of shape `torch.Size([1, 7, 1, 5])`. If the input embeddings have more than one feature, the third dimension of the output changes accordingly.
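At a very low temperature each of the 5 slots in `z` should be close to a single row of `x`, so one hedged way to read indexes back out of `z` is to match every slot against the rows of `x` and keep the closest one. The sketch below is a post-processing heuristic in plain PyTorch, not part of the n3net API; it assumes `z` is laid out as [batch, queries, features, k] and `x` as [batch, items, features], and the helper name `indices_from_output` is hypothetical.

```python
import torch


def indices_from_output(z: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map each neighbor slot of z to the closest row of x.

    z: [b, m, f, k] aggregated output (assumed layout)
    x: [b, n, f] items that were aggregated
    returns: [b, m, k] long tensor of row indexes into x
    """
    b, m, f, k = z.shape
    # Put the k slots next to the query axis and flatten them: [b, m*k, f]
    slots = z.permute(0, 1, 3, 2).reshape(b, m * k, f)
    # Euclidean distance from every slot to every row of x: [b, m*k, n]
    dist = torch.cdist(slots, x)
    # Index of the closest row for each slot, reshaped back to [b, m, k]
    return dist.argmin(dim=-1).reshape(b, m, k)
```

The match is only meaningful when the slots are nearly one-hot selections; at higher temperatures each slot is a blend of several rows and the closest row is just an approximation.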

How can I get the global indexes of the nearest neighbors? For example, if XE has shape [1, 10, 5] and YE has shape [1, 3, 5], I want output indexes of shape [1, 3], just like the neighbor indexes that KNeighborsClassifier gives.
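If only the indexes are needed, they can also be computed directly from the embeddings with an ordinary hard kNN search, which is the zero-temperature behavior the snippet above tries to approximate. This swaps in plain PyTorch (`torch.cdist` plus `topk`) rather than the N3 block; it assumes Euclidean distance and embeddings laid out as [batch, items, features], and `knn_indices` is a hypothetical name, not an n3net function. With k = 1 and the last axis squeezed, 3 query embeddings searched among 10 database embeddings give indexes of shape [1, 3].

```python
import torch


def knn_indices(queries: torch.Tensor, database: torch.Tensor, k: int = 1) -> torch.Tensor:
    """Hypothetical helper: hard k-nearest-neighbor search.

    queries:  [b, m, e] query embeddings
    database: [b, n, e] database embeddings
    returns:  [b, m, k] long tensor of database indexes, nearest first
    """
    # Pairwise Euclidean distances: [b, m, n]
    dist = torch.cdist(queries, database)
    # k smallest distances per query, sorted in ascending order
    _, idx = dist.topk(k, dim=-1, largest=False)
    return idx


if __name__ == "__main__":
    database = torch.randn(1, 10, 5)
    queries = torch.randn(1, 3, 5)
    nearest = knn_indices(queries, database, k=1).squeeze(-1)
    print(nearest.shape)  # torch.Size([1, 3])
```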