motiwari / BanditPAM

BanditPAM C++ implementation and Python package
MIT License
647 stars, 38 forks

User report: algorithm choosing wrong arms (correctness) #252

Closed: motiwari closed this issue 1 year ago

motiwari commented 1 year ago

A report from a user (I've asked for their code):

I compared the objective of BanditPAM's BUILD-step solution with that of PAM's BUILD-step solution (on the MNIST dataset, train split, N = 60k, d = 784). For k = 100 medoids and L2 distance, PAM's BUILD objective is 6.0641 while BanditPAM's is 6.55 (around 8% worse). I have observed similar differences on various randomly generated datasets as well. For the MNIST dataset, I set build_confidence = 16. According to my understanding of BanditPAM's paper and code, this should correspond to delta = (N * exp(build_confidence))^{-1} = 1.8756e-12. Hence, according to Theorem 1's Remark A1 of the BanditPAM paper, the probability that all confidence intervals are true confidence intervals is at least 1 - 2N^2*delta = 0.9932. My understanding is that this probability should be high enough for me to observe the same BUILD objective for both BanditPAM and PAM for most random seeds. However, this does not happen for any seed.

Please note that my understanding of the theoretical results or the way I used build_confidence may be wrong. Hence, it would be great if you could resolve the above issue. Overall, based on the results shown in the BanditPAM paper, I expected the objectives to be the same in most of my runs (on various datasets with 1 - 2N^2*delta > 0.99, computed through build_confidence as explained above). However, it does not happen and the objectives are quite far apart (e.g., 8% difference).
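The arithmetic in the report can be checked with a few lines of Python. This is only a sketch: the mapping delta = (N * exp(build_confidence))^{-1} is the reporter's reading of the parameter, taken as an assumption here rather than as confirmed library behavior.

```python
import math

N = 60_000              # MNIST train split size, from the report
build_confidence = 16   # value the reporter passed to BanditPAM

# Assumption: the reporter's interpretation of build_confidence.
delta = 1.0 / (N * math.exp(build_confidence))

print(f"delta           = {delta:.4e}")
print(f"1 - 2*N^2*delta = {1 - 2 * N**2 * delta:.4f}")
print(f"1 - N^2*delta   = {1 - N**2 * delta:.4f}")
```

The first line reproduces delta = 1.8756e-12. The quoted 0.9932 matches the expression without the factor of 2; with the factor of 2 the value is about 0.9865. Either way the bound stays above 0.98, so the reporter's expectation of matching BUILD objectives for most seeds is unchanged.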

motiwari commented 1 year ago

Code to reproduce bug:

# code starts
from banditpam import KMedoids
import numpy as np
from scipy.spatial import distance_matrix
from sklearn_extra import cluster

rand_seed = 10
np.random.seed(rand_seed)

n = 20000
d = 10
X = np.random.rand(n,d)

# BanditPAM's BUILD run

k = 100
delta1 = 2*k/n # this value of delta1 should correspond to \delta = n^{-3} as in Theorem 1 of BanditPAM
useCache = False
maxIter = 0
buildConfidence = np.floor(np.log(2*n*k/delta1)).astype(np.int64)
# delta1 = 2*k/n makes buildConfidence ~ log(n^2) (subject to integer rounding) and overall \delta ~ n^{-3}

kmed = KMedoids(
    n_medoids=k,
    algorithm="BanditPAM",
    build_confidence=buildConfidence,
    use_cache=useCache,
    max_iter=maxIter,
)
kmed.fit(X, 'L2')

banditpam_build_medoids_idx = kmed.build_medoids
banditpam_build_medoids = X[banditpam_build_medoids_idx,:]

banditpam_medoids_ref_cost_distance_matrix = distance_matrix(banditpam_build_medoids,X)
banditpam_objective = np.sum(np.min(banditpam_medoids_ref_cost_distance_matrix,0))

# PAM's BUILD run, using the BUILD step of scikit-learn-extra's KMedoids
sklearn_kmed = cluster.KMedoids(n_clusters=k, metric='euclidean', method='pam', init='build', max_iter = 0).fit(X)

sklearn_build_medoids_idx = sklearn_kmed.medoid_indices_
sklearn_build_medoids = X[sklearn_build_medoids_idx,:]

sklearn_medoids_ref_cost_distance_matrix2 = distance_matrix(sklearn_build_medoids,X)
sklearn_objective2 = np.sum(np.min(sklearn_medoids_ref_cost_distance_matrix2,0))

print('BanditPAM BUILD objective: ', banditpam_objective, ' sklearn KMedoids BUILD objective: ', sklearn_objective2)
print('Out of ', k, ', common medoids selected by the two algorithms: ',
      len(np.intersect1d(sklearn_build_medoids_idx, banditpam_build_medoids_idx)))
#sklearn_cluster_centers = sklearn_kmed.cluster_centers_
#sklearn_medoids_ref_cost_distance_matrix1 = distance_matrix(sklearn_cluster_centers,X)
#sklearn_objective1 = np.sum(np.min(sklearn_medoids_ref_cost_distance_matrix1,0)) #gives same output as sklearn_objective2

# code ends

motiwari commented 1 year ago
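
For a cross-check that does not depend on either library, the exact greedy PAM BUILD step can be written directly in NumPy. This is a minimal sketch for small inputs only (it stores the full n x n distance matrix); the function name `pam_build` is hypothetical and is not part of BanditPAM or scikit-learn-extra.

```python
import numpy as np

def pam_build(X, k):
    """Exact greedy PAM BUILD. Returns (medoid indices, objective).

    O(k * n^2) time and O(n^2 * d) memory, so only suitable for small n.
    """
    n = X.shape[0]
    # Full pairwise Euclidean distance matrix: D[c, j] = ||X[c] - X[j]||.
    diffs = X[:, None, :] - X[None, :, :]
    D = np.sqrt((diffs ** 2).sum(axis=-1))

    medoids = []
    chosen = np.zeros(n, dtype=bool)
    # best[j] = distance from point j to its closest medoid so far.
    best = np.full(n, np.inf)

    for _ in range(k):
        # totals[c] = objective if candidate c were added as the next medoid.
        totals = np.minimum(D, best[None, :]).sum(axis=1)
        totals[chosen] = np.inf  # never pick the same point twice
        c = int(np.argmin(totals))
        medoids.append(c)
        chosen[c] = True
        best = np.minimum(best, D[c])

    return medoids, float(best.sum())

# Small illustrative run on synthetic data (sizes chosen arbitrarily).
rng = np.random.default_rng(10)
X_small = rng.random((200, 10))
medoid_idx, objective = pam_build(X_small, 5)
print("BUILD medoids:", medoid_idx, "objective:", objective)
```

On a dataset small enough for this to run, the returned objective gives a reference value against which the BUILD objective of either library can be compared.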

@lukeleeai can you take a look and see if you can reproduce this error? It's probably related to some of the correctness issues we're seeing with the loss in BanditPAM vs. BanditPAM++

lukeleeai commented 1 year ago

Thanks for sharing this. I will add it to my todo!

lukeleeai commented 1 year ago

The corrected code is proposed in the Pull Request!

motiwari commented 1 year ago

@lukeleeai it looks like the linked PR was closed with unmerged commits. Is that intentional?

The user is still reporting that this is an issue. Could you verify this issue is resolved in v4.0.2? Or do we need to wait until v4.0.3?

lukeleeai commented 1 year ago

I closed it because debug_loss had experiment results! Afterwards, I opened a new PR (fixed_loss) and it was integrated into 4.0.3!

motiwari commented 1 year ago

Oh right! So this is currently expected to still fail in v4.0.2 and will be fixed when @Adarsh321123 ships v4.0.3, correct? @lukeleeai

lukeleeai commented 1 year ago

Yes that is correct!
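
Since the fix is tied to a release, anyone re-running the reproduction script may want to confirm which version is installed first. The sketch below uses only the standard library; the helper name `has_build_fix` is hypothetical, and the distribution name `banditpam` is assumed to match the name used to install the package.

```python
import re
from importlib import metadata

FIXED_IN = (4, 0, 3)  # release named in this thread as containing the fix

def has_build_fix(version_string):
    """Return True if version_string is at least FIXED_IN."""
    # Take the leading integer of each dotted component, e.g. "4.0.3rc1" -> (4, 0, 3).
    parts = []
    for piece in version_string.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts) >= FIXED_IN

try:
    installed = metadata.version("banditpam")  # assumed distribution name
    print(installed, "includes fix:", has_build_fix(installed))
except metadata.PackageNotFoundError:
    print("banditpam is not installed in this environment")
```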
