eamid / trimap

TriMap: Large-scale Dimensionality Reduction Using Triplets
Apache License 2.0

Error in the implementation of the function rejection_sample #13

Closed: UnkownNames closed this issue 3 years ago.

UnkownNames commented 3 years ago

In trimap_.py, there is a function:

```python
import numpy as np


def rejection_sample(n_samples, max_int, rejects):
    """
    Samples "n_samples" integers from a given interval [0, max_int) while
    rejecting the values that are in the "rejects".

    """
    result = np.empty(n_samples, dtype=np.int32)
    for i in range(n_samples):
        reject_sample = True
        while reject_sample:
            j = np.random.randint(max_int)  # uniform over [0, max_int)
            # Note: this `break` only exits this loop, so values already
            # stored in `result` are not actually rejected.
            for k in range(i):
                if j == result[k]:
                    break
            for k in range(rejects.shape[0]):
                if j == rejects[k]:
                    break
            else:
                # j is not in `rejects`: accept it
                reject_sample = False
        result[i] = j
    return result
```
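In the quoted rejection_sample, the break in the first inner loop only exits that loop, so a value that was already drawn can be drawn again. The following is a minimal sketch, not the trimap code, of a sampler that rejects both the values in rejects and earlier draws. The name rejection_sample_unique and the rng parameter are hypothetical, and this plain Python/NumPy version has not been tested under numba.

```python
import numpy as np


def rejection_sample_unique(n_samples, max_int, rejects, rng=None):
    """Sample n_samples distinct integers from [0, max_int) not in rejects.

    Hypothetical helper for illustration; not part of trimap.
    """
    rng = np.random.default_rng() if rng is None else rng
    rejected = {int(r) for r in rejects}
    # Count admissible values up front so the loop below cannot spin forever.
    n_blocked = sum(1 for r in rejected if 0 <= r < max_int)
    if n_samples > max_int - n_blocked:
        raise ValueError("not enough admissible integers to sample from")
    result = np.empty(n_samples, dtype=np.int32)
    for i in range(n_samples):
        while True:
            j = int(rng.integers(max_int))  # uniform over [0, max_int)
            if j not in rejected:
                break
        result[i] = j
        rejected.add(j)  # forbid drawing j again, so the output has no duplicates
    return result


samples = rejection_sample_unique(5, 10, np.array([0, 1, 2]), rng=np.random.default_rng(0))
print(samples)
```

Using a set makes each membership test O(1) instead of the two linear scans in the original; the up-front count replaces the risk of an endless while loop when too few values remain.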

And another function:

```python
import numba
import numpy as np


def sample_knn_triplets(P, nbrs, n_inliers, n_outliers):
    """
    Sample nearest neighbors triplets based on the similarity values given in P

    Input
    ------

    nbrs: Nearest neighbors indices for each point. The similarity values
        are given in matrix P. Row i corresponds to the i-th point.

    P: Matrix of pairwise similarities between each point and its neighbors
        given in matrix nbrs

    n_inliers: Number of inlier points

    n_outliers: Number of outlier points

    Output
    ------

    triplets: Sampled triplets
    """
    n, n_neighbors = nbrs.shape
    triplets = np.empty((n * n_inliers * n_outliers, 3), dtype=np.int32)
    for i in numba.prange(n):
        sort_indices = np.argsort(-P[i])
        for j in numba.prange(n_inliers):
            sim = nbrs[i][sort_indices[j + 1]]
            # rejection_sample is the function quoted above
            samples = rejection_sample(n_outliers, n, sort_indices[: j + 2])
            for k in numba.prange(n_outliers):
                index = i * n_inliers * n_outliers + j * n_outliers + k
                out = samples[k]
                triplets[index][0] = i
                triplets[index][1] = sim
                triplets[index][2] = out
                # if sim==out :
                #     print("sim==out")
    return triplets
```
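The line index = i * n_inliers * n_outliers + j * n_outliers + k treats triplets as a flattened (n, n_inliers, n_outliers) grid in row-major order. The small check below, using arbitrary illustrative sizes, confirms that this arithmetic matches NumPy's C-order np.ravel_multi_index and that every row of triplets is written exactly once, so different (i, j, k) iterations never write to the same row.

```python
import numpy as np

# Small, arbitrary sizes chosen only for illustration.
n, n_inliers, n_outliers = 4, 3, 2


def flat_index(i, j, k):
    # Same arithmetic as the `index = ...` line in sample_knn_triplets.
    return i * n_inliers * n_outliers + j * n_outliers + k


seen = []
for i in range(n):
    for j in range(n_inliers):
        for k in range(n_outliers):
            idx = flat_index(i, j, k)
            # Row-major (C-order) flattening of an (n, n_inliers, n_outliers) grid.
            assert idx == np.ravel_multi_index((i, j, k), (n, n_inliers, n_outliers))
            seen.append(idx)

# Each row of the (n * n_inliers * n_outliers, 3) triplets array is hit exactly once.
assert sorted(seen) == list(range(n * n_inliers * n_outliers))
print("layout ok,", len(seen), "rows")
```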

sort_indices is always a permutation of range(0, 150) (I set n_inliers = 100): its entries are column positions within row i of nbrs and P, not point indices. The original code only guarantees that out is not one of the positions in sort_indices[: j + 2], but those positions are not the actual indices of the neighbors, whereas sim = nbrs[i][sort_indices[j + 1]] is a real point index. As a result, I found that sim and out are sometimes equal. In my opinion, sample_knn_triplets should be implemented as follows:

```python
import numba
import numpy as np


def sample_knn_triplets(P, nbrs, n_inliers, n_outliers):
    """
    Sample nearest neighbors triplets based on the similarity values given in P

    Input
    ------

    nbrs: Nearest neighbors indices for each point. The similarity values
        are given in matrix P. Row i corresponds to the i-th point.

    P: Matrix of pairwise similarities between each point and its neighbors
        given in matrix nbrs

    n_inliers: Number of inlier points

    n_outliers: Number of outlier points

    Output
    ------

    triplets: Sampled triplets
    """
    n, n_neighbors = nbrs.shape
    triplets = np.empty((n * n_inliers * n_outliers, 3), dtype=np.int32)
    for i in numba.prange(n):
        sort_indices = np.argsort(-P[i])
        for j in numba.prange(n_inliers):
            sim = nbrs[i][sort_indices[j + 1]]
            # I changed the next line compared with the original code:
            # reject the neighbors' point indices, not their column positions.
            samples = rejection_sample(n_outliers, n, nbrs[i][sort_indices[: j + 2]])
            for k in numba.prange(n_outliers):
                index = i * n_inliers * n_outliers + j * n_outliers + k
                out = samples[k]
                triplets[index][0] = i
                triplets[index][1] = sim
                triplets[index][2] = out
                # if sim==out :
                #     print("sim==out")
    return triplets
```
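A toy example makes the difference between the two exclusion sets concrete. The values below are invented for illustration, and it is assumed (as the j + 1 offset in the code suggests) that the most similar entry in each row of nbrs is the point itself. With the original exclusion, the inlier sim and the point i are both still eligible to be drawn as outliers; with the changed line, both are excluded.

```python
import numpy as np

# Toy neighborhood for a single point i = 7 (values invented for illustration).
i = 7
nbrs_i = np.array([7, 3, 9, 5])       # neighbor point ids; assumed: column 0 is the point itself
P_i = np.array([1.0, 0.9, 0.5, 0.2])  # similarities to those neighbors

sort_indices = np.argsort(-P_i)       # column positions, most similar first
j = 0                                 # first inlier
sim = int(nbrs_i[sort_indices[j + 1]])  # point id of the inlier

# Original: excludes column positions, which are not point ids.
excluded_original = set(sort_indices[: j + 2].tolist())
# Changed line: excludes the actual point ids of i and the inlier.
excluded_fixed = set(nbrs_i[sort_indices[: j + 2]].tolist())

print("sim =", sim)                                      # sim = 3
print("original excludes", sorted(excluded_original))   # original excludes [0, 1]
print("fixed excludes", sorted(excluded_fixed))         # fixed excludes [3, 7]
```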

eamid commented 3 years ago

Great catch! You are right, nbrs[i][sort_indices[: j+2]] should be excluded from the samples. Thank you for finding this! Can you please send a pull request for this change?

eamid commented 3 years ago

@YuXiaokang I made the one-line change; I hope that's OK. Thanks again for finding this bug!