mhamilton723 / STEGO

Unsupervised Semantic Segmentation by Distilling Feature Correspondences
MIT License

Possible mistake in cluster probe loss function? #32

Open PeaBrane opened 2 years ago

PeaBrane commented 2 years ago

I'm under the impression that the loss function of the cluster probe is simply the entropy of the cluster probabilities.
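For reference, the entropy objective described above can be written as below. The notation is an assumption for illustration and does not come from the repository: $p_k(x)$ is the probability that location $x$ is assigned to cluster $k$, $K$ is the number of clusters, and $N$ is the number of locations (batch size times spatial positions) the loss is averaged over.

$$
\mathcal{L}_{\text{ent}} = -\frac{1}{N}\sum_{x}\sum_{k=1}^{K} p_k(x)\,\log p_k(x)
$$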

However, in the `ClusterLookup` class in `modules.py`, the loss function is defined as

```python
cluster_loss = -(cluster_probs * inner_products).sum(1).mean()
```
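A minimal sketch of what this line computes, run in isolation on random placeholder inputs. The `(B, K, H, W)` layout with clusters along dim 1, and `cluster_probs = softmax(inner_products)` (alpha = 1), are assumptions made for illustration. The one-hot case is a hypothetical comparison, not something taken from the repository. The sketch shows that the line is a probability-weighted average of the inner products rather than an entropy: for a one-hot assignment it reduces to minus the mean of the largest inner product at each location.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical scores: batch of 2, K = 3 clusters, 4x4 grid, clusters on dim 1.
inner_products = torch.randn(2, 3, 4, 4)

def cluster_loss_current(cluster_probs, inner_products):
    # The line from ClusterLookup: probability-weighted inner products,
    # summed over clusters (dim 1), averaged over batch and spatial positions.
    return -(cluster_probs * inner_products).sum(1).mean()

# Soft assignment, assuming cluster_probs = softmax(alpha * inner_products), alpha = 1.
soft_probs = F.softmax(inner_products, dim=1)
soft_loss = cluster_loss_current(soft_probs, inner_products)

# Hypothetical one-hot assignment: one_hot gives (B, H, W, K), so move K back to dim 1.
hard_idx = inner_products.argmax(dim=1)
hard_probs = F.one_hot(hard_idx, num_classes=3).permute(0, 3, 1, 2).float()
hard_loss = cluster_loss_current(hard_probs, inner_products)

# With one-hot probabilities the loss is minus the mean of the per-location maximum.
print(torch.allclose(hard_loss, -inner_products.max(dim=1).values.mean()))  # True
print(soft_loss.item(), hard_loss.item())
```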

Shouldn't this instead be

```python
cluster_loss = -(cluster_probs * cluster_probs.log()).sum(1).mean()
```
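A minimal sketch, under the same assumed `(B, K, H, W)` layout and random placeholder inputs, checking that the proposed expression matches the entropy reported by `torch.distributions.Categorical`. It also shows one caveat of calling `.log()` directly: if `cluster_probs` ever contains exact zeros (as a hypothetical one-hot assignment does), `0 * log(0)` evaluates to `nan`, whereas `torch.xlogy(p, p)` treats that term as 0. The helper names `entropy_naive` and `entropy_xlogy` are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
inner_products = torch.randn(2, 3, 4, 4)  # hypothetical (B, K, H, W) scores

def entropy_naive(cluster_probs):
    # The proposed line: -(p * log p) summed over clusters, averaged over the rest.
    return -(cluster_probs * cluster_probs.log()).sum(1).mean()

def entropy_xlogy(cluster_probs):
    # Same quantity, but torch.xlogy(p, p) returns 0 wherever p == 0.
    return -torch.xlogy(cluster_probs, cluster_probs).sum(1).mean()

soft_probs = F.softmax(inner_products, dim=1)
# Categorical expects the category axis last, so move dim 1 to the end.
reference = Categorical(logits=inner_products.movedim(1, -1)).entropy().mean()
print(torch.allclose(entropy_naive(soft_probs), reference, atol=1e-5))  # True

# A hypothetical one-hot assignment has exact zeros: log(0) = -inf and 0 * -inf = nan.
hard_probs = F.one_hot(inner_products.argmax(1), num_classes=3).permute(0, 3, 1, 2).float()
print(torch.isnan(entropy_naive(hard_probs)).item())  # True
print(entropy_xlogy(hard_probs).item())
```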

Or alternatively (assuming alpha = 1, so that `cluster_probs` is the softmax of `inner_products` over dim 1),

```python
cluster_probs_log = inner_products - inner_products.logsumexp(1, keepdim=True)
cluster_loss = -(cluster_probs * cluster_probs_log).sum(1).mean()
```
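A minimal sketch comparing the two formulations on random placeholder inputs, assuming alpha = 1 so that `cluster_probs = softmax(inner_products, dim=1)`. Because the probabilities sum to 1 over the cluster dimension, the entropy equals the current loss plus the mean log-sum-exp of the inner products, which the last check confirms numerically. `torch.logsumexp` is used here for numerical stability instead of `exp().sum().log()`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inner_products = torch.randn(2, 3, 4, 4)  # hypothetical (B, K, H, W) scores, alpha = 1

cluster_probs = F.softmax(inner_products, dim=1)

# Log-probabilities via logsumexp; this is exactly log_softmax over dim 1.
cluster_probs_log = inner_products - torch.logsumexp(inner_products, dim=1, keepdim=True)
print(torch.allclose(cluster_probs_log, F.log_softmax(inner_products, dim=1), atol=1e-6))  # True

entropy_loss = -(cluster_probs * cluster_probs_log).sum(1).mean()
current_loss = -(cluster_probs * inner_products).sum(1).mean()

# -sum p * (s - lse) = -sum p * s + lse, since sum p = 1 at every location.
lse = torch.logsumexp(inner_products, dim=1).mean()
print(torch.allclose(entropy_loss, current_loss + lse, atol=1e-5))  # True
```

Under this assumption the current loss differs from the entropy by a term that itself depends on the inner products, so minimizing one is not in general equivalent to minimizing the other.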