mahmoodlab / PANTHER

Morphological Prototyping for Unsupervised Slide Representation Learning in Computational Pathology - CVPR 2024
Other
95 stars 11 forks

FAISS K-Means Clustering Quality #12

Closed: bryanwong17 closed this issue 2 months ago

bryanwong17 commented 2 months ago

Thank you for the great work! I was wondering whether you have checked the quality of the cluster assignments produced by FAISS K-Means. Following an approach similar to your code, I trained FAISS K-Means on the extracted features of the training data of the C16 dataset and clustered them into 16 prototypes (with other hyperparameters similar to yours). I then used the trained K-Means model to assign the patches of each WSI (test data) to clusters. However, the clustering results looked poor on the test data and even on the training data (e.g., prototype 1 in WSI A looks very different from prototype 1 in WSI B), even though I used a strong feature extractor such as CONCH.
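One way to put a number on the cross-slide inconsistency described above is sketched below, assuming both slides were labeled against the same trained centroids: for each prototype, compare the mean feature of the patches assigned to it in WSI A with the mean in WSI B. The function `prototype_consistency` and the cosine-similarity criterion are hypothetical choices for illustration, not part of PANTHER, and the toy arrays stand in for real patch features.

```python
import numpy as np


def prototype_consistency(feats_a, labels_a, feats_b, labels_b, num_proto):
    """Cosine similarity between per-prototype mean features of two slides.

    Hypothetical diagnostic (not part of PANTHER). Entries are NaN for
    prototypes that have no patches in one of the two slides.
    """
    sims = np.full(num_proto, np.nan)
    for k in range(num_proto):
        mask_a, mask_b = labels_a == k, labels_b == k
        if not mask_a.any() or not mask_b.any():
            continue  # prototype k is absent from one of the slides
        mean_a = feats_a[mask_a].mean(axis=0)
        mean_b = feats_b[mask_b].mean(axis=0)
        denom = np.linalg.norm(mean_a) * np.linalg.norm(mean_b)
        sims[k] = (mean_a @ mean_b) / denom if denom > 0 else np.nan
    return sims


# Toy usage: prototype-0 patches point the same way in both slides,
# prototype-1 patches point in orthogonal directions, prototype 2 is unused.
feats_a = np.array([[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]])
labels_a = np.array([0, 0, 1])
feats_b = np.array([[3.0, 0.0], [1.0, 0.0]])
labels_b = np.array([0, 1])
print(prototype_consistency(feats_a, labels_a, feats_b, labels_b, num_proto=3))
```

Values near 1 suggest a prototype captures similar features across slides; low values flag prototypes whose shared label may not correspond to the same morphology.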

andrewsong90 commented 2 months ago

Hi @bryanwong17,

Throughout development, I checked the quality of the prototypes, and they seemed to make sense to me.

Right now, I can think of two possible reasons:

  1. For C16, where the biological entity of interest is quite small (lymph node metastases, which may occupy only part of an image patch), clustering and prototyping might not be as effective as in other cancers.
  2. In my personal experience (I might be wrong on this), simple K-means from sklearn seemed to be "better" than FAISS k-means; the initial development of PANTHER also used sklearn K-means. Maybe this is worth a try?
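For reference, fitting prototypes with scikit-learn's `KMeans` and reusing the fitted model to label every WSI might look like the minimal sketch below. The random blobs are placeholders for real patch features (e.g. CONCH embeddings), and `init="k-means++"`, `n_init=10`, and `random_state=0` are assumed settings, not PANTHER's configuration.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
num_proto, dim = 16, 32

# Placeholder training pool: patch features pooled from the training WSIs.
centers = rng.normal(scale=10.0, size=(num_proto, dim))
train_feats = np.concatenate(
    [c + rng.normal(size=(200, dim)) for c in centers]
).astype(np.float32)

# Fit the prototypes once on the pooled training features.
kmeans = KMeans(n_clusters=num_proto, init="k-means++", n_init=10, random_state=0)
kmeans.fit(train_feats)

# Inference on one test WSI: every patch is labeled with the same fitted
# centroids, so label k refers to the same prototype in every slide.
wsi_feats = (centers[3] + rng.normal(size=(50, dim))).astype(np.float32)
wsi_labels = kmeans.predict(wsi_feats)
print(np.bincount(wsi_labels, minlength=num_proto))  # patches per prototype
```

Comparing this against a FAISS run on the same pooled features is one way to check whether the library choice matters for a given dataset.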

Best,

bryanwong17 commented 2 months ago

Hi @andrewsong90,

Thank you for your insights. I was wondering whether you have uploaded the code for inference or cluster assignment. As far as I know, only the code for training K-means (using both sklearn and FAISS K-means) and saving the weights has been uploaded, but not the code for inference or cluster assignments.

If I may, I have a few follow-up questions:

  1. Is it necessary to normalize the training patch features before training K-Means, and should we also normalize the test patch features before performing K-Means inference?
  2. Is using PCA necessary?
  3. Is it necessary to sample the same number of patches per WSI, or can I simply use all the patch features from the x training WSIs?
  4. Is your num_proto_patches 100,000 or 1,000,000? The README says 1,000,000, while clustering.sh uses 100,000.
  5. Do you train K-Means to learn prototypes on several datasets jointly, or separately for each dataset?
  6. Could you confirm if the following approaches for assigning cluster labels during inference are correct:

For Sklearn:

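import numpy as np
from scipy.spatial.distance import cdist

# patch_features: (num_patches, dim) array for one test WSI; kmeans: fitted sklearn KMeans
cluster_labels = kmeans.predict(patch_features)  # nearest centroid under Euclidean distance
centroids = kmeans.cluster_centers_
distances = cdist(patch_features, centroids, metric=args.distance_metric)  # 'cosine' or 'euclidean'
# Distance of each patch to its assigned centroid (same as the row minimum for 'euclidean')
distances = distances[np.arange(len(cluster_labels)), cluster_labels]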

For FAISS:

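# FAISS expects a contiguous float32 NumPy array; an IndexFlatL2 index returns squared L2 distances
query = np.ascontiguousarray(patch_features.numpy(), dtype=np.float32)  # patch features in one test WSI
distances, assignments = index.search(query, 1)
distances, cluster_labels = distances.ravel(), assignments.ravel()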

Then, in the inference loop:

representatives = {}
for clusternum in range(args.num_proto):
    cluster_indices = np.where(cluster_labels == clusternum)[0]
    if cluster_indices.size == 0:
        continue  # this prototype has no patches in the current WSI
    # index (into patch_features) of the patch closest to its assigned prototype
    representatives[clusternum] = cluster_indices[np.argmin(distances[cluster_indices])]

Thank you very much for your support!

Best,

HHHedo commented 2 months ago

Hi @andrewsong90, I also have a question about k-means. In your settings, L2 distance is used for k-means, while most foundation models (e.g., UNI) are pre-trained with DINOv2. Intuitively, cosine distance seems the more reasonable choice. I tried running PANTHER with Gigapath as the encoder; here is one example of the prototypical assignment map. Using cosine distance gave better performance, both qualitatively and quantitatively.
[image: example prototypical assignment map]
BTW, one question remains, since L2 distance is also used in the GMM.
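The link between the two metrics can be made concrete: for unit-norm vectors a and b, ||a - b||^2 = 2 * (1 - cos(a, b)), so after L2 normalization, nearest-centroid assignment by Euclidean distance matches assignment by cosine similarity, provided the centroids are unit-norm too. The NumPy check below uses random vectors as stand-ins for patch features; `l2_normalize` is a hypothetical helper. Note that k-means centroids of normalized features are generally not unit-norm themselves, which is why spherical k-means renormalizes them at every iteration.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    # Scale each row (one patch feature) to unit L2 norm.
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


rng = np.random.default_rng(0)
a = l2_normalize(rng.normal(size=(5, 8)))
b = l2_normalize(rng.normal(size=(5, 8)))

sq_euclid = ((a - b) ** 2).sum(axis=1)
cosine = (a * b).sum(axis=1)  # rows are unit-norm, so the dot product is the cosine
print(np.allclose(sq_euclid, 2.0 * (1.0 - cosine)))  # True

# With unit-norm centroids, both criteria pick the same nearest centroid.
centroids = l2_normalize(rng.normal(size=(4, 8)))
by_l2 = ((a[:, None, :] - centroids[None]) ** 2).sum(axis=-1).argmin(axis=1)
by_cos = (a @ centroids.T).argmax(axis=1)
print((by_l2 == by_cos).all())
```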

andrewsong90 commented 2 months ago

Hi @bryanwong17,

I have not uploaded the code for inference or cluster assignment, but it should be easy to write (your code for sklearn/FAISS/inference looks correct to me).

As for answers to some of your questions

  1. I think it is a design choice, and the answer is to try both, as @HHHedo kindly illustrates with an example above. While classical clustering approaches would advocate normalization and centering, I opted not to normalize, for two reasons.

    • I observed that patch features had a range of L2 norm values, and I hypothesized that the norm also carries important information about the patch.
    • Supervised baselines (ABMIL, TransMIL) typically do not include a feature unit-normalization step for downstream tasks, so I wanted to follow current practice.
  2. I do not think PCA is necessary.

  3. Sampling the same number of patches per WSI was just for convenience; if you can, yes, probably use all patches!

  4. Probably the larger the better, but as you might have observed already, beyond a certain point the impact on the downstream task was minimal.

  5. Prototypes are initialized separately for each dataset, but I would love to see how pan-cancer prototype initialization helps. Make sure to increase the number of prototypes in that case, since different cancers would have non-overlapping prototypes.
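A minimal sketch of the sampling discussed in answers 3 and 4 above: draw up to a fixed number of patches from each training WSI and pool them for k-means, so that slides with many patches do not dominate the prototypes. The helper `sample_patch_pool` is hypothetical, and deriving the per-slide budget as `total_patches // num_slides` is an assumption about how a total such as num_proto_patches could be split, not a description of PANTHER's script.

```python
import numpy as np


def sample_patch_pool(slide_feats, total_patches, seed=0):
    """Pool up to total_patches // len(slide_feats) patches from each slide.

    Hypothetical helper: slide_feats is a list of (num_patches_i, dim) arrays.
    Slides with fewer patches than the per-slide budget contribute all of them.
    """
    rng = np.random.default_rng(seed)
    per_slide = max(1, total_patches // len(slide_feats))
    pooled = []
    for feats in slide_feats:
        n = min(per_slide, len(feats))
        idx = rng.choice(len(feats), size=n, replace=False)  # sample without replacement
        pooled.append(feats[idx])
    return np.concatenate(pooled, axis=0)


# Toy usage: three slides of different sizes and a total budget of 300 patches.
rng = np.random.default_rng(1)
slides = [rng.normal(size=(n, 16)) for n in (1000, 50, 400)]
pool = sample_patch_pool(slides, total_patches=300)
print(pool.shape)  # (250, 16): 100 + 50 (all of slide 2) + 100
```

Using all patches instead, as suggested in answer 3, amounts to `np.concatenate(slides, axis=0)`.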

Thank you for helping me improve PANTHER further!

andrewsong90 commented 2 months ago

Hi @HHHedo

Thank you very much for providing an illustrative example. As I explained above, cosine similarity did cross my mind at some point, in line with the "right way" of doing clustering. But it is interesting/surprising that feature normalization indeed boosts performance further.

Maybe I should update the code base to include an L2 normalization step before everything; the GMM can still operate in the L2-normalized feature space!
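A sketch of what "L2 normalization before everything" could look like, using scikit-learn's `KMeans` and `GaussianMixture` only as stand-ins for PANTHER's own code: the same normalization is applied to training and test features (which also bears on question 1 above), k-means prototypes are learned in the normalized space, and a per-slide diagonal-covariance mixture is fitted with its means seeded by those prototypes. The feature arrays are random placeholders, and the sizes and settings are assumptions.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture


def l2_normalize(x, eps=1e-12):
    # Unit L2 norm per row; applied identically to train and test features.
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


rng = np.random.default_rng(0)
num_proto, dim = 4, 16
train_feats = rng.normal(size=(2000, dim))  # placeholder pooled training patches
wsi_feats = rng.normal(size=(300, dim))     # placeholder patches of one test WSI

# 1) Prototypes: k-means on L2-normalized training features.
kmeans = KMeans(n_clusters=num_proto, n_init=10, random_state=0)
prototypes = kmeans.fit(l2_normalize(train_feats)).cluster_centers_

# 2) Per-slide mixture in the same normalized space, means seeded by the prototypes.
wsi_norm = l2_normalize(wsi_feats)
gmm = GaussianMixture(n_components=num_proto, covariance_type="diag",
                      means_init=prototypes, random_state=0)
gmm.fit(wsi_norm)
posteriors = gmm.predict_proba(wsi_norm)  # (num_patches, num_proto) soft assignments
print(posteriors.shape, posteriors.sum(axis=1)[:3])
```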

Thank you so much

bryanwong17 commented 2 months ago

Hi @andrewsong90,

Thank you for the detailed explanations. They really help me understand your work better, especially the clustering part!