lelan-li / SSAH

Self-Supervised Adversarial Hashing Networks for Cross-Modal Retrieval(CVPR2018)
http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_Self-Supervised_Adversarial_Hashing_CVPR_2018_paper.pdf
162 stars 51 forks source link

How to plot Precision-Recall curves on cross-modal hashing retrieval task? #5

Open Shen-Qiu opened 5 years ago

Shen-Qiu commented 5 years ago

P-R_curve I have plotted two Precision-Recall curves on the result of Flickr-25K. Some settings as followed: bit=16, using vgg19 features. The codes and result are provided in SHARING.

These curves look a bit strange because some of their starting points don't show a downward trend as those curves shown in DCMH and SSAH.

Is there a problem with my code?

Shen-Qiu commented 5 years ago

# -*- coding: utf-8 -*-
"""
    Plot precision-recall curve on the result of MIR-FLICKR-25K
"""
import numpy as np
import matplotlib.pyplot as plt 

def calc_hammingDist(B1, B2):
    q = B2.shape[1]
    disH = 0.5 * (q - np.dot(B1, B2.transpose()))
    return disH

def calc_similarity(label_1, label_2):
    return (np.dot(label_1, label_2.transpose()) > 0).astype(np.float32)

def calc_map(qB, rB, query_L, retrieval_L):
    # qB: {-1,+1}^{mxq}
    # rB: {-1,+1}^{nxq}
    # query_L: {0,1}^{mxl}
    # retrieval_L: {0,1}^{nxl}
    num_query = query_L.shape[0]
    map = 0
    for iter in range(num_query):
        gnd = (np.dot(query_L[iter, :], retrieval_L.transpose()) > 0).astype(np.float32)
        tsum = int(np.sum(gnd))
        if tsum == 0:
            continue
        hamm = calc_hammingDist(qB[iter, :], rB)
        ind = np.argsort(hamm)
        gnd = gnd[ind]
        count = np.linspace(1, tsum, tsum)

        tindex = np.asarray(np.where(gnd == 1)) + 1.0
        map = map + np.mean(count / (tindex))
    map = map / num_query
    return map

def cal_Precision_Recall_Curve(qB, rB, query_L, retrieval_L):
    S = calc_similarity(query_L, retrieval_L)
    dist = calc_hammingDist(qB, rB)
    num = qB.shape[0] # the number of input instances

    precision = np.zeros((num, bits + 1))
    recall = np.zeros((num, bits + 1))
    for i in range(num):
        relevant = set(np.where(S[i, :] == 1)[0])
        retrieved = set()
        for bit in range(bits + 1):
            retrieved = set(np.where(dist[i, :] == bit)[0]) | retrieved
            ret_rel = len(retrieved & relevant)
            #print('bit : {0}, Precision: {1:.4f}, Recall: {2:.4f}'.format(bit, 
            #      ret_rel / len(retrieved), ret_rel / len(relevant)))
            recall[i, bit] = ret_rel / len(relevant)
            if len(retrieved) == 0:
                continue
            precision[i, bit] = ret_rel / len(retrieved)

    return recall.mean(axis=0), precision.mean(axis=0)

result = np.load('./result_16bits_VGG19.npz')
#qBX = result['qBX'][0:1, :] # image query, just for one instance
qBX = result['qBX'] # image query
qBY = result['qBY'] # text query 
rBX = result['rBX'] # image retrieval 
rBY = result['rBY'] # text retrieval 
#query_L = result['query_L'][0:1, :] # query label, just for one instance
query_L = result['query_L'] # query label
retrieval_L = result['retrieval_L'] # retrieval label

mapi2t = result['mapi2t']
mapt2i = result['mapt2i']
print('mapi2t: {0:.4f}'.format(mapi2t))
print('mapt2i: {0:.4f}'.format(mapt2i))

bits = result['bit']
#calc_map(qBX, rBY, query_L, retrieval_L)
recall, precision = cal_Precision_Recall_Curve(qBX, rBY, query_L, retrieval_L)
fig = plt.figure(1)
ax = fig.add_subplot(121)
ax.scatter(recall, precision)
ax.plot(recall, precision)
ax.set(xlim = [0, 1], ylim = [0.5, 1])
plt.title(r'Image->Text')
plt.xlabel('Recall')
plt.ylabel('Precision')
#plt.plot()

# Text -> Image
recall, precision = cal_Precision_Recall_Curve(qBY, rBX, query_L, retrieval_L)
ax = fig.add_subplot(122)
ax.scatter(recall, precision)
ax.plot(recall, precision)
ax.set(xlim = [0, 1], ylim = [0.5, 1])
plt.title(r'Text->Image')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.plot()
anan1030 commented 1 year ago

Hi, you need to exclude the zeros in precision when doing precision.mean(axis=0), you can use np.average() for weighted average.