Open PeaBrane opened 2 years ago
I'm under the impression that the loss function of the cluster probe is simply the entropy of the cluster probabilities.
However, in the ClusterLookup class in modules.py, the loss function is defined as
cluster_loss = -(cluster_probs * inner_products).sum(1).mean()
Shouldn't this instead be
cluster_loss = -(cluster_probs * cluster_probs.log()).sum(1).mean()
Or alternatively (assuming alpha = 1),
cluster_probs_log = inner_products - inner_products.exp().sum(1, keepdims=True).log()
cluster_loss = -(cluster_probs * cluster_probs_log).sum(1).mean()
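To make the difference between these expressions concrete, here is a small numerical sketch. It is written in NumPy rather than PyTorch so it runs standalone. It assumes `cluster_probs` is a softmax over `inner_products` along dimension 1 with alpha = 1, and it uses a made-up 4 x 3 array in place of the real inner products. None of this is taken from modules.py beyond the loss line quoted above. Under those assumptions, the two entropy forms agree, and the entropy equals the quoted loss plus the mean log-sum-exp of the inner products.

```python
import numpy as np

def softmax(x, axis=1):
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

# Hypothetical stand-in for inner_products: 4 samples, 3 clusters.
inner_products = rng.uniform(-1.0, 1.0, size=(4, 3))

# Assumption: probabilities are a softmax of the inner products (alpha = 1).
cluster_probs = softmax(inner_products, axis=1)

# Loss as quoted from ClusterLookup.
loss_quoted = -(cluster_probs * inner_products).sum(1).mean()

# Entropy, first proposed form.
loss_entropy = -(cluster_probs * np.log(cluster_probs)).sum(1).mean()

# Entropy, second proposed form (log-probabilities built from the logits).
cluster_probs_log = inner_products - np.log(
    np.exp(inner_products).sum(1, keepdims=True)
)
loss_entropy_alt = -(cluster_probs * cluster_probs_log).sum(1).mean()

# The term that separates the quoted loss from the entropy.
mean_logsumexp = np.log(np.exp(inner_products).sum(1)).mean()

print("quoted loss      :", loss_quoted)
print("entropy          :", loss_entropy)
print("entropy (alt)    :", loss_entropy_alt)
print("quoted + mean LSE:", loss_quoted + mean_logsumexp)
```

So, if the softmax assumption holds, the quoted loss is the entropy minus the log-partition term. Whether that term matters for training is the question being asked here.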