Lee-Gihun / MEDIAR

(NeurIPS 2022 CellSeg Challenge - 1st Winner) Open source code for "MEDIAR: Harmony of Data-Centric and Model-Centric for Multi-Modality Microscopy"
MIT License

knn classifier #9

Closed (Eggwardhan closed 7 months ago)

Eggwardhan commented 10 months ago

First and foremost, I would like to express my deep respect for the research presented in your paper, "MEDIAR: Harmony of Data-Centric and Model-Centric for Multi-Modality Microscopy". The approach you introduced, using model embeddings for KNN classification, has greatly intrigued me.

While trying to understand and apply this method, I have run into some difficulties with the practical implementation. To gain a deeper understanding and hands-on experience, I am writing to ask whether you could share the related code. Either a GitHub link or a zip file would be immensely helpful.

I assure you that, if granted access to your code, it will strictly be used for academic research purposes, adhering to all relevant research ethics and copyright policies. Your assistance will be invaluable in the completion of my study.

Thank you for considering my request. I look forward to your response and once again, thank you for your valuable contribution to the academic community.

Lee-Gihun commented 7 months ago

Since we didn't use a KNN classifier, I believe you're referring to our K-means approach. Unfortunately, we don't have official code for this. However, it can be easily implemented in just a few lines. Let me provide the following code, which is nearly identical to our original approach. The only difference is that we ran it multiple times with different encoder checkpoints and seeds to achieve the best results during the competition period.

The result is as follows:

[Image: bar chart of K-means cluster label counts, sorted by count in descending order]


```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from train_tools.models import MEDIARFormer
from train_tools.data_utils.datasetter import get_dataloaders_labeled

device = "cuda:7"  # change to a device available on your machine

# Load data and model
dataloaders = get_dataloaders_labeled(
    "/home/gihun/MEDIAR",
    "./train_tools/data_utils/mapping_labeled.json",
    "./train_tools/data_utils/mapping_tuning.json",
)
model = MEDIARFormer().to(device).eval()
model.load_state_dict(torch.load("./weights/pretrained/phase1.pth", map_location="cpu"))

# Extract embeddings: take the last encoder feature map and
# average it over the spatial dimensions (one vector per image)
embeddings_all = []
with torch.no_grad():
    for batch in dataloaders["train"]:
        embeddings = model.encoder(batch["img"].to(device))[-1].mean(dim=[2, 3]).cpu()
        embeddings_all.append(embeddings)
embeddings_all = torch.cat(embeddings_all, dim=0)

# Cluster embeddings and sort clusters by size (largest first)
kmeans = KMeans(n_clusters=40, random_state=0).fit(embeddings_all.numpy())
labels, counts = np.unique(kmeans.labels_, return_counts=True)
sorted_indices = np.argsort(counts)[::-1]

# Plot histogram of cluster sizes
plt.figure(figsize=(10, 6))
plt.bar(range(len(counts)), counts[sorted_indices], tick_label=labels[sorted_indices])
plt.xlabel('Cluster Label')
plt.ylabel('Count')
plt.title('KMeans Labels Histogram (Desc. Count)')
plt.show()
```
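For anyone who wants to try the pipeline without the MEDIAR weights or dataset, here is a minimal sketch on synthetic data. It mirrors the same two steps (average each feature map over its spatial dimensions, then run K-means) and also illustrates the "different seeds" remark by fitting once per seed. The helper names `pool_feature_maps` and `cluster_with_seeds` are hypothetical, the random arrays only stand in for encoder outputs, and keeping the fit with the lowest inertia is an assumption about how one might pick a run, not the criterion the authors used.

```python
import numpy as np
from sklearn.cluster import KMeans


def pool_feature_maps(feature_maps: np.ndarray) -> np.ndarray:
    """Average (N, C, H, W) feature maps over H and W to get (N, C) embeddings."""
    return feature_maps.mean(axis=(2, 3))


def cluster_with_seeds(embeddings: np.ndarray, n_clusters: int, seeds) -> KMeans:
    """Fit KMeans once per seed and return the fit with the lowest inertia.

    Selecting by inertia is an assumption made for this sketch.
    """
    best = None
    for seed in seeds:
        km = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit(embeddings)
        if best is None or km.inertia_ < best.inertia_:
            best = km
    return best


# Synthetic stand-in for encoder outputs: 3 groups of 50 "images",
# each with 8 channels on a 4x4 spatial grid.
rng = np.random.default_rng(0)
centers = rng.normal(scale=5.0, size=(3, 8))
feature_maps = np.concatenate(
    [c[None, :, None, None] + rng.normal(scale=0.1, size=(50, 8, 4, 4)) for c in centers]
)

embeddings = pool_feature_maps(feature_maps)
best = cluster_with_seeds(embeddings, n_clusters=3, seeds=[0, 1, 2])

# Same post-processing as in the snippet above: cluster sizes, largest first.
labels, counts = np.unique(best.labels_, return_counts=True)
order = np.argsort(counts)[::-1]
print(dict(zip(labels[order].tolist(), counts[order].tolist())))
```

To apply this to real data, replace `feature_maps` with the last encoder output converted to NumPy, and set `n_clusters` to the value you want to explore (the snippet above uses 40).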