thuml / HashNet

Code release for "HashNet: Deep Learning to Hash by Continuation" (ICCV 2017)
MIT License

test.py for CIFAR #42

Closed rio26 closed 4 years ago

rio26 commented 4 years ago

I have tried to adapt the code to the CIFAR dataset, but the result I get seems highly unreasonable (the mAP value was greater than 7).

The original code you provided was based on the NUS-WIDE dataset, which is a multi-labelled benchmark. I believe I made some mistakes while editing it, but I really can't find the bug. I have been stuck on this for days...

Could you share your version of "def mean_average_precision(...)"? Or could you point out which parts of the following code are wrong?

import numpy as np

def mean_average_precision(params, R):
    database_code = params['database_code']
    validation_code = params['test_code']
    database_labels = params['database_labels']
    validation_labels = params['test_labels']
    query_num = validation_code.shape[0]

    sim = np.dot(database_code, validation_code.T)
    ids = np.argsort(-sim, axis=0)
    APx = []

    for i in range(query_num):
        label = validation_labels[i, :]                                        # I changed this line
        if label == 0:
            label = -1                                                         # I changed this line
        idx = ids[:, i]
        imatch = np.sum(database_labels[idx[0:R], :] == label, axis=1) > 0     # I changed this line
        relevant_num = np.sum(imatch)
        Lx = np.cumsum(imatch)
        Px = Lx.astype(float) / np.arange(1, R+1, 1)
        if relevant_num != 0:
            APx.append(np.sum(Px * imatch) / relevant_num)

    return np.mean(np.array(APx))
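
For comparison, here is a minimal sketch of how the relevance check could look when each CIFAR image carries a single integer class id rather than a multi-hot label vector. The `params` keys follow the snippet above; the assumption that `database_labels` / `test_labels` hold integer class indices of shape (N,) is an illustration, not code from the repository.

import numpy as np

def mean_average_precision_single_label(params, R):
    # Assumes binary hash codes in {-1, +1} and integer class ids
    # (shape (N,)) for database_labels / test_labels — an assumption,
    # not the repository's own data layout.
    database_code = params['database_code']
    validation_code = params['test_code']
    database_labels = np.asarray(params['database_labels']).reshape(-1)
    validation_labels = np.asarray(params['test_labels']).reshape(-1)
    query_num = validation_code.shape[0]

    sim = np.dot(database_code, validation_code.T)      # inner-product similarity
    ids = np.argsort(-sim, axis=0)                       # ranking per query (descending)
    APx = []

    for i in range(query_num):
        label = validation_labels[i]
        idx = ids[:, i]
        # A retrieved item is relevant iff its class id equals the query's.
        imatch = (database_labels[idx[0:R]] == label).astype(float)
        relevant_num = np.sum(imatch)
        Lx = np.cumsum(imatch)
        Px = Lx / np.arange(1, R + 1, 1)                 # precision at each rank
        if relevant_num != 0:
            APx.append(np.sum(Px * imatch) / relevant_num)

    return np.mean(np.array(APx))

With integer class ids, the label[label == 0] = -1 trick from the multi-label NUS-WIDE code is unnecessary, since equality of class ids already defines relevance.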
rio26 commented 4 years ago

Problem fixed. I was using the PyTorch built-in CIFAR dataset; I downloaded another copy of the data and read it in myself. Now I can run it, although I only get mAP = 0.29. I'll try different backbone models later on.
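
In case others hit the same issue with the torchvision loader: one possible workaround (a sketch only; the OneHotCIFAR10 wrapper is hypothetical, not the dataset actually used above) is to expand the integer targets to 10-dim one-hot vectors so the original multi-label evaluation code can be reused unchanged.

import numpy as np
from torchvision import datasets

class OneHotCIFAR10(datasets.CIFAR10):
    """Hypothetical wrapper: returns CIFAR-10 labels as 10-dim one-hot vectors."""
    def __getitem__(self, index):
        image, target = super().__getitem__(index)       # target is an int in 0..9
        onehot = np.eye(10, dtype=np.float32)[target]    # e.g. 3 -> [0,0,0,1,0,0,0,0,0,0]
        return image, onehot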