gusye1234 / LightGCN-PyTorch

The PyTorch implementation of LightGCN

recall and precision #38

Open zbt78 opened 1 year ago


```python
import numpy as np


def RecallPrecision_ATk(test_data, r, k):
    """
    test_data should be a list, because users may have different numbers of positive items.
    r (pred_data): hit indicators, shape (test_batch, k). NOTE: must be pre-sorted by predicted score.
    k: top-k
    """
    right_pred = r[:, :k].sum(1)  # number of hits in each user's top-k
    precis_n = k
    recall_n = np.array([len(test_data[i]) for i in range(len(test_data))])  # relevant items per user
    recall = np.sum(right_pred / recall_n)  # summed (not averaged) over the batch
    precis = np.sum(right_pred) / precis_n  # likewise summed over the batch
    return {'recall': recall, 'precision': precis}
```

With this implementation, the returned recall can be greater than 1. Is this a mistake?
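One possible explanation is that the function returns a *sum* over the batch rather than a mean. Each user's Recall@k = hits / |relevant items| lies in [0, 1], so the sum over a batch of n users can be as large as n. Below is a minimal sketch, assuming (not confirmed here) that the caller adds up the per-batch values and divides by the total number of test users. `recall_precision_at_k` and the toy data are hypothetical.

```python
import numpy as np


def recall_precision_at_k(test_data, r, k):
    """Hypothetical re-implementation mirroring RecallPrecision_ATk.

    test_data: list of ground-truth item lists, one per user (lengths may differ).
    r: 0/1 array, shape (n_users, >= k); r[u, j] == 1 if user u's j-th ranked item is relevant.
    Returns the SUM over users of Recall@k and Precision@k, not the mean.
    """
    hits = r[:, :k].sum(axis=1)                          # relevant items in each user's top-k
    n_relevant = np.array([len(items) for items in test_data])
    recall_sum = np.sum(hits / n_relevant)               # each term lies in [0, 1]
    precision_sum = np.sum(hits) / k                     # equals the sum of hits[u] / k
    return {"recall": recall_sum, "precision": precision_sum}


# Toy data: 4 users, top-2 ranked lists (1 = hit).
test_data = [[10, 11], [20], [30, 31, 32, 33], [40, 41, 42, 43]]
r = np.array([
    [1, 1],  # user 0: 2 of 2 relevant items found -> recall 1.0
    [1, 0],  # user 1: 1 of 1                      -> recall 1.0
    [1, 0],  # user 2: 1 of 4                      -> recall 0.25
    [0, 0],  # user 3: 0 of 4                      -> recall 0.0
])

whole = recall_precision_at_k(test_data, r, k=2)
print(whole["recall"])  # 2.25: a sum over 4 users, so it exceeds 1

# Assumed caller behaviour: accumulate the sums batch by batch,
# then divide once by the total number of users.
totals = {"recall": 0.0, "precision": 0.0}
for start in range(0, len(test_data), 2):            # batches of 2 users
    out = recall_precision_at_k(test_data[start:start + 2], r[start:start + 2], k=2)
    for name in totals:
        totals[name] += out[name]

n_users = len(test_data)
mean_recall = totals["recall"] / n_users             # 0.5625
mean_precision = totals["precision"] / n_users       # 0.5
print(mean_recall, mean_precision)
```

If that assumption holds, returning sums keeps per-batch results additive even when the last batch is smaller, and the division by the user count happens only once. A recall above 1 straight out of `RecallPrecision_ATk` would then be expected, not a bug.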