visinf / n3net

Neural Nearest Neighbors Networks (NIPS*2018)

Aggregate selects same element multiple times #9

Open LemonPi opened 5 years ago

LemonPi commented 5 years ago

I'm trying to evaluate NNN against conventional kNN on a simple test case: finding the 5 nearest neighbours in a permutation of indices, which makes the result easy to verify by eye. The problem is that the aggregate output contains the same value for all 5 neighbours.

Problem setup:

import torch
import non_local
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

N = 50
nn = non_local.N3AggregationBase(5, temp_opt={"external_temp": False}).to(device)  # keep the module on the same device as the inputs

x = torch.tensor(np.random.permutation(list(range(N))), dtype=torch.float, requires_grad=True)
x = x.reshape(1, N, 1).to(device)
xe = x
ye = xe
I = torch.tensor(list(range(N)), dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)

z = nn(x, xe, ye, I)

The aggregate output z is:

tensor([[[[10.0001, 10.0001, 10.0001, 10.0001, 10.0001]],
         [[42.0001, 42.0001, 42.0001, 42.0001, 42.0001]],
         [[22.0001, 22.0001, 22.0001, 22.0001, 22.0001]],
...

Is this expected, and am I interpreting the result wrong? If so, what is the aggregate output z supposed to represent?
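For comparison, a conventional (hard) kNN baseline on the same kind of input can be computed by brute force. This NumPy sketch is not part of n3net; `hard_knn` is a hypothetical helper, distances are absolute differences of the 1-D values, and ties are broken by index order through a stable sort.

```python
import numpy as np

def hard_knn(x, k):
    """Brute-force k-NN on 1-D values; each point is its own nearest neighbour."""
    d = np.abs(x[:, None] - x[None, :])                 # (N, N) pairwise distances
    idx = np.argsort(d, axis=1, kind="stable")[:, :k]   # k closest candidates per query
    return x[idx]                                       # neighbour values, nearest first

rng = np.random.default_rng(0)
x = rng.permutation(50).astype(float)                   # same kind of input as the setup above
nbrs = hard_knn(x, 5)
print(nbrs[:3])
```

For the query whose value is 10, hard kNN returns the five distinct values 8 to 12 rather than five copies of 10.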

tobiasploetz commented 5 years ago

Hi @LemonPi ,

I think this is a problem of symmetry. In a nutshell, our continuous relaxation of hard kNN selection is not good at breaking ties. If you take the index 10, then the indices 9/11, 8/12, 7/13, ... are each pairwise equally distant from 10 and hence contribute with equal weight to the neighbor selection. The logits are also updated equally (Eq. 9), so in the next round of neighbor selection they again have equal weights.
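A minimal NumPy sketch of the symmetry argument for a single query at value 10 among candidates 0..49. It assumes squared distances, a temperature t = 0.1 applied as logits = -d / t, and an update of the form logits += log(1 - w) after each selection step. That update is my reading of Eq. 9, not the repository's code, and `soft_knn_weights` is a hypothetical name. Because 9 and 11 (and 8 and 12, and so on) always receive equal weight, the expected value of every soft neighbour stays at 10.

```python
import numpy as np

def soft_knn_weights(d, k, t):
    """Sketch of a continuous k-NN relaxation for one query.

    d: (N,) distances to N candidates. Returns (k, N); row j weights the j-th soft neighbour.
    Assumed update (my reading of Eq. 9): logits += log(1 - w) after each step.
    """
    logits = -d / t
    rows = []
    for _ in range(k):
        s = logits - logits.max()                 # shift for a numerically stable softmax
        w = np.exp(s) / np.exp(s).sum()
        rows.append(w)
        # down-weight mass already selected; clip so log1p never sees -1 exactly
        logits = logits + np.log1p(-np.minimum(w, 1.0 - 1e-12))
    return np.stack(rows)

values = np.arange(50, dtype=np.float64)          # candidates 0..49; the query is 10
d = (values - 10.0) ** 2                          # squared distances (assumption)
W = soft_knn_weights(d, k=5, t=0.1)
z = W @ values                                    # expected value of each soft neighbour
print(np.round(z, 4))
print(np.round(W[1, 8:13], 3))                    # second step: weights on candidates 8..12
```

Every entry of z comes out at 10, which is consistent with the repeated values in each row of the output above: the mass on 9 and 11 is always split evenly, so their average is the query itself.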

I think there is also some numerical instability in how Eq. 9 is currently implemented ... :)

LemonPi commented 5 years ago

I see. This seems to be a relatively big problem, since even adding noise didn't break the ties (I tried uniform random noise with magnitudes up to 1).

tobiasploetz commented 5 years ago

Hi @LemonPi ,

there was a numerical stability issue when computing log(1 - exp(x)). This should be fixed with the latest update.
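For reference, a common way to compute log(1 - exp(x)) stably for x <= 0 is the two-branch trick: use log(-expm1(x)) when x is close to 0 and log1p(-exp(x)) otherwise, switching at x = -log 2. This is a NumPy sketch, not the fix that landed in n3net; `torch.expm1`, `torch.log1p` and `torch.where` would be the PyTorch counterparts.

```python
import numpy as np

def log1mexp(x):
    """Numerically stable log(1 - exp(x)) for x <= 0 (two-branch trick)."""
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(divide="ignore"):            # np.where evaluates both branches
        return np.where(x > -np.log(2.0),
                        np.log(-np.expm1(x)),     # accurate when exp(x) is close to 1
                        np.log1p(-np.exp(x)))     # accurate when exp(x) is small

def log1mexp_naive(x):
    with np.errstate(divide="ignore"):
        return np.log(1.0 - np.exp(x))

for x in [-1e-20, -0.5, -50.0]:
    print(x, log1mexp_naive(x), log1mexp(x))
```

Near 0 the naive form returns -inf because 1 - exp(x) rounds to 0, and for very negative x it returns exactly 0 because exp(x) is lost next to 1; the two-branch version keeps both ends finite and accurate.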

I modified your example slightly (no permutation, lower temperature, and a small amount of noise on the inputs):

import torch
import models.non_local as non_local
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

N = 50
nn = non_local.N3AggregationBase(5, temp_opt={"external_temp": False})
nn.to(device)  # same device as the inputs; nn.cuda() would fail on CPU-only machines
nn.nnn.log_temp_bias = -50 # decrease temperature -> NNN acts more like hard kNN

# x = torch.tensor(np.random.permutation(list(range(N))), dtype=torch.float, requires_grad=True)
x = torch.tensor(list(range(N)), dtype=torch.float, requires_grad=True)
n = torch.zeros_like(x).normal_() * 0.0001
x = x+n
x = x.reshape(1, N, 1).to(device)
xe = x
ye = xe
I = torch.tensor(list(range(N)), dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)

z = nn(x, xe, ye, I)

for i in range(N):
    print("\t".join(["{:.2f}"]*5).format(*(z[0,i,0,:].tolist())))

This gives the following output:

0.00    1.00    2.00    3.00    4.00
1.00    2.00    0.00    3.00    4.00
2.00    1.00    3.00    0.00    4.00
3.00    4.00    2.00    5.00    1.00
4.00    3.00    5.00    2.00    6.00
5.00    6.00    4.00    7.00    3.00
6.00    7.00    5.00    8.00    4.00
7.00    6.00    8.00    5.00    9.00
8.00    7.00    9.00    6.00    10.00
9.00    10.00   8.00    11.00   7.00
10.00   11.00   9.00    12.00   8.00
11.00   10.00   12.00   9.00    13.00
12.00   13.00   11.00   14.00   10.00
13.00   12.00   14.00   15.00   11.00
14.00   15.00   13.00   12.00   16.00
15.00   14.00   16.00   13.00   17.00
16.00   17.00   15.00   18.00   14.00
17.00   16.00   18.00   15.00   19.00
18.00   19.00   17.00   20.00   16.00
19.00   20.00   18.00   21.00   17.00
20.00   19.00   21.00   18.00   22.00
21.00   22.00   20.00   19.00   23.00
22.00   21.00   23.00   20.00   24.00
23.00   24.00   22.00   25.00   21.00
24.00   25.00   23.00   26.00   22.00
25.00   24.00   26.00   23.00   27.00
26.00   25.00   27.00   24.00   28.00
27.00   28.00   26.00   29.00   25.00
28.00   27.00   29.00   26.00   30.00
29.00   30.00   28.00   31.00   27.00
30.00   31.00   29.00   32.00   28.00
31.00   30.00   32.00   29.00   33.00
32.00   31.00   33.00   34.00   30.00
33.00   34.00   32.00   35.00   31.00
34.00   33.00   35.00   32.00   36.00
35.00   36.00   34.00   33.00   37.00
36.00   37.00   35.00   38.00   34.00
37.00   38.00   36.00   39.00   35.00
38.00   37.00   39.00   36.00   40.00
39.00   40.00   38.00   41.00   37.00
40.00   41.00   39.00   42.00   38.00
41.00   40.00   42.00   39.00   43.00
42.00   43.00   41.00   40.00   44.00
43.00   42.00   44.00   41.00   45.00
44.00   45.00   43.00   46.00   42.00
45.00   44.00   46.00   43.00   47.00
46.00   47.00   45.00   92.00   49.00
47.00   94.00   49.00   45.00   44.00
48.00   49.00   47.00   46.00   45.00
49.00   48.00   47.00   46.00   45.00

I hope this solves your problem.

LemonPi commented 5 years ago

I see. I tried this again, and the critical change is decreasing the temperature. However, there is a new issue: some hallucinated values appear, such as:

42.00003 , 84.00015 , 43.999893, 39.999966, 39.000072

Where did the 84 come from?!?

This seems to be a phenomenon that occurs when the temperature is too low.

From what I understand from the paper, lowering the temperature approximates hard kNN more closely and produces sharper distributions. What practical issues does this cause?

tobiasploetz commented 5 years ago

I think this is related to PyTorch's implementation of log_softmax, which seemingly does not work correctly if the maximal value of the argument has a large absolute value and appears multiple times:

import numpy as np
import torch
import torch.nn.functional as F

F.log_softmax(torch.from_numpy(np.asarray([-1e2, -1e2], dtype=float)).float(), dim=0).exp()
# tensor([0.5000, 0.5000])
F.log_softmax(torch.from_numpy(np.asarray([-1e20, -1e20], dtype=float)).float(), dim=0).exp()
# tensor([1., 1.])
F.log_softmax(torch.from_numpy(np.asarray([1e20, 1e20], dtype=float)).float(), dim=0).exp()
# tensor([1., 1.])

This causes the weights of the weighted averages to sum to something greater than one.
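One mechanism consistent with these outputs (an assumption, not checked against PyTorch's source) is float32 absorption. If log_softmax is computed as x - (m + log Σ exp(x - m)) with m = max(x), then for |m| around 1e20 the small log-sum term is lost when it is added to m, and every tied maximum ends up with log-probability 0, i.e. weight 1. Subtracting m first avoids this. A NumPy sketch contrasting the two orderings, which are mathematically identical:

```python
import numpy as np

def log_softmax_naive(x):
    # x - (m + log(sum(exp(x - m)))): a huge m swallows the small log-sum term
    m = x.max()
    lse = m + np.log(np.exp(x - m).sum())
    return x - lse

def log_softmax_stable(x):
    # (x - m) - log(sum(exp(x - m))): subtract the max first, then the small term
    m = x.max()
    shifted = x - m
    return shifted - np.log(np.exp(shifted).sum())

for v in [-1e2, -1e20, 1e20]:
    x = np.array([v, v], dtype=np.float32)
    print(v, np.exp(log_softmax_naive(x)), np.exp(log_softmax_stable(x)))
```

In float32 the naive ordering gives weights of about 0.5 for -1e2 but exactly 1 for ±1e20, the same pattern as the tensors above, while the stable ordering gives 0.5 in all three cases.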

From a practical point of view, this should be of minor relevance if you want to train the N3 block within your network, since the gradient signal vanishes anyway as the temperature decreases (in the limit t -> 0, or log t -> -inf, N3 selection is just kNN selection and hence has zero gradients everywhere). Hence your network will probably never reach the above situation.
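To make the vanishing-gradient argument concrete for a single query with distinct distances, here is a NumPy sketch of the analytic Jacobian of w = softmax(-d / t) with respect to the distances d, namely -(diag(w) - w w^T) / t. The three distances are made-up values. As t shrinks, w approaches a one-hot vector exponentially fast, which outpaces the 1/t factor.

```python
import numpy as np

def softmax_jacobian_wrt_distances(d, t):
    """Analytic Jacobian dw/dd of w = softmax(-d / t) for one query."""
    z = -d / t
    w = np.exp(z - z.max())
    w /= w.sum()
    # d softmax / d z = diag(w) - w w^T, and dz/dd = -1/t
    return -(np.diag(w) - np.outer(w, w)) / t

d = np.array([0.0, 1.0, 2.0])          # hypothetical distances to three candidates
grads = []
for t in [1.0, 0.1, 0.01]:
    J = softmax_jacobian_wrt_distances(d, t)
    grads.append(np.abs(J).max())
    print(t, grads[-1])
```

The largest Jacobian entry drops from about 0.22 at t = 1 to roughly 1e-42 at t = 0.01, so a network trained end to end has little incentive to push the temperature into the regime where the numerical problems above appear.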

LemonPi commented 5 years ago

Thanks, will keep in mind!

LemonPi commented 5 years ago

The spurious values seem to come from IndexedMatmul2Efficient when aggregating the output (producing z from W) rather than from log_softmax. For example, querying the k = 3 nearest neighbours with N = 30 gives the following W:

array([[0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.36787945, 0.36787945],
       [1.        , 0.        , 0.        ],
       [0.        , 0.36787945, 0.36787945],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ]], dtype=float32)

This seems correct: the first distribution is concentrated on 22, and the other two on 21 and 23. However, the output z is

array([21.999949, 16.18666 , 16.18666 ], dtype=float32)

This doesn't correspond to the indices in W. I is just arange(0, N) for each query point.

Everything in IndexedMatmul2Efficient looks fine until

            z_interm = torch.cat([torch.matmul(y_full[:,i_k:i_k+1,:,:], x_interm) for i_k in range(k)], 1)

This results in:

z_interm[0,:,22]
Out[24]: 
tensor([[21.9999],
        [16.1867],
        [16.1867]], device='cuda:0')

Update: maybe the problem is in calculating W in the first place. The columns represent probability distributions, so each should sum to 1, but here they do not.
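The printed numbers are consistent with this. Assuming the aggregation computes each output as the weighted sum Σ_i W[i, j] · y[i] with y = arange(0, N) (my reading, not checked against IndexedMatmul2Efficient), the second and third columns give 0.36787945 · (21 + 23) ≈ 16.1867, because each of those columns sums to 2 · e^-1 ≈ 0.7358 instead of 1. A NumPy sketch that rebuilds W from the nonzero entries printed above:

```python
import numpy as np

N, k = 30, 3
y = np.arange(N, dtype=np.float64)      # candidate values; I is arange(0, N)
W = np.zeros((N, k))
W[22, 0] = 1.0                          # first soft neighbour: all mass on 22
W[[21, 23], 1] = np.exp(-1.0)           # second: 0.36787945 on 21 and on 23
W[[21, 23], 2] = np.exp(-1.0)           # third: same as the second

z = W.T @ y                             # weighted sum per column
print(z)                                # ≈ [22.0, 16.1867, 16.1867]
print(W.sum(axis=0))                    # ≈ [1.0, 0.7358, 0.7358]

Wn = W / W.sum(axis=0, keepdims=True)   # renormalise each column to sum to 1
print(Wn.T @ y)                         # ≈ [22.0, 22.0, 22.0]
```

After renormalising, the second and third outputs become 0.5 · 21 + 0.5 · 23 = 22: the spurious 16.19 disappears, but the expected value lands on the centre again.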

LemonPi commented 5 years ago

Normalizing each output distribution removes the spurious values, but the center value is still selected multiple times because the aggregation takes an expected value. The fundamental cause seems to be that the k distributions are ordered, so this method does not really relax kNN: in kNN we don't care about the order of the neighbours. A more direct relaxation would give one distribution per query point instead of k distributions.

The normalization can be done by adding the following at the end of NeuralNearestNeighbors.forward:

        # normalize so output is a distribution
        for bb in range(b):
            for mm in range(m):
                W[bb, mm] /= torch.sum(W[bb, mm], dim=0)
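The double loop can also be written as a single out-of-place division. A NumPy sketch, assuming W has shape (b, m, n, k) with candidates on axis 2, which is what summing over dim 0 of W[bb, mm] implies; the sizes below are arbitrary. In PyTorch the analogue would be `W = W / W.sum(dim=2, keepdim=True)`. Being out-of-place, it also avoids modifying W in place, which autograd can reject if W is needed in the backward pass.

```python
import numpy as np

rng = np.random.default_rng(0)
b, m, n, k = 2, 4, 30, 3                     # hypothetical sizes: batch, queries, candidates, neighbours
W = rng.random((b, m, n, k))                 # stand-in for unnormalised weights

# Loop version, mirroring the snippet above
W_loop = W.copy()
for bb in range(b):
    for mm in range(m):
        W_loop[bb, mm] /= W_loop[bb, mm].sum(axis=0)

# Vectorised, out-of-place version: normalise over the candidate axis
W_vec = W / W.sum(axis=2, keepdims=True)

print(np.allclose(W_loop, W_vec))            # True
print(np.allclose(W_vec.sum(axis=2), 1.0))   # True
```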

Example of the problem before normalization, with N = 3000 data points and k = 3 [figures: t_too_low_3000, t_too_low_3000b; average error relative to the kNN neighbourhoods, 0 is exact]

The same problem after normalization [figures: fix_t_too_low_3000, fix_t_too_low_3000b]