mattersoflight opened this issue 4 weeks ago
@mattersoflight In SimCLR and other methods using InfoNCE-style losses, there is an implicit L2 normalization of $\mathbf{z}$ happening in the loss function, since they use cosine similarity as the distance function. The triplet margin loss uses L2 distance (which, for unit vectors, is fully determined by the cosine and vice versa). I still need to find a reference implementation, but the planned removal of L2 normalization might not be needed.
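To make the equivalence explicit: for unit vectors, the squared L2 distance is an affine function of the cosine similarity, so the two induce the same ordering of pairs:

$$\|\mathbf{z}_i - \mathbf{z}_j\|_2^2 = \|\mathbf{z}_i\|^2 + \|\mathbf{z}_j\|^2 - 2\,\mathbf{z}_i^\top \mathbf{z}_j = 2 - 2\cos(\mathbf{z}_i, \mathbf{z}_j) \quad \text{when } \|\mathbf{z}_i\| = \|\mathbf{z}_j\| = 1.$$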
@ziw-liu L2 normalization is indeed equivalent to converting feature vectors into unit vectors. It also makes sense that the loss is computed between unit vectors (either cosine similarity or Euclidean distance), given the SimCLR paper.
I agree that L2 normalization of projections doesn't need to be removed. NO NEED to implement that as an argument.
[TripletMarginLoss](https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html) provided by `torch.nn` doesn't do normalization by default.
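A minimal sketch of applying the normalization explicitly before the loss; the tensor shapes and margin value here are illustrative placeholders, not the repo's settings:

```python
import torch
import torch.nn.functional as F

# Placeholder batch: 32 anchors/positives/negatives with 128-dim projections.
anchor, positive, negative = (torch.randn(32, 128) for _ in range(3))

loss_fn = torch.nn.TripletMarginLoss(margin=0.5, p=2)

# TripletMarginLoss does not normalize its inputs, so L2-normalize
# explicitly to recover the unit-sphere geometry of cosine-based losses.
loss = loss_fn(
    F.normalize(anchor, p=2, dim=1),
    F.normalize(positive, p=2, dim=1),
    F.normalize(negative, p=2, dim=1),
)
```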
@ziw-liu the paper that compares different types of projection heads uses the InfoNCE loss (which reads the same as the NT-Xent loss in SimCLR). It may be that using the NT-Xent loss promotes higher-rank embeddings.
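For reference, a minimal sketch of NT-Xent as described in the SimCLR paper; the function name, batch size, and temperature default are illustrative:

```python
import torch
import torch.nn.functional as F

def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent over two (N, D) batches of projections, where z1[i] and
    z2[i] are the two augmented views of the same sample."""
    z = F.normalize(torch.cat([z1, z2]), dim=1)  # (2N, D) unit vectors
    sim = z @ z.T / temperature                  # pairwise cosine similarities
    sim.fill_diagonal_(float("-inf"))            # exclude self-similarity
    n = z1.shape[0]
    # Positive for row i is row i+n, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)])
    return F.cross_entropy(sim, targets)
```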
Model changes were implemented in #145, but #154 can potentially fix the low-rank feature map.
TL;DR: the current projection head doesn't do what it is supposed to do. It should have batch norm, and the sizes of the features and projections may be reduced further. In previous implementations of contrastive learning models (dynacontrast), we used batch norm in the projection head after each MLP layer. This paper also recommends using a non-linear projection head with batch norm. Different projection heads are evaluated by comparing the rank of the features (the number of independent features) before and after the projection head:
As expected, rank(projections) << rank(features).

Our current model's behavior is the opposite: rank(projections) > rank(features), as seen from the examination of the principal components of each.
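A sketch of how this comparison can be made by counting significant singular values; the `features` key and the tolerance are assumptions, mirroring the `projections` access used below:

```python
import numpy as np

def effective_rank(x: np.ndarray, tol: float = 1e-3) -> int:
    """Number of singular values above tol * s_max for the centered
    (samples x dims) embedding matrix, i.e. independent dimensions."""
    s = np.linalg.svd(x - x.mean(axis=0), compute_uv=False)
    return int((s > tol * s[0]).sum())

print("rank(features):   ", effective_rank(embedding_dataset["features"].values))
print("rank(projections):", effective_rank(embedding_dataset["projections"].values))
```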
The inverted ranking seems to be a consequence of the projections being clipped, which in turn seems to be due to the use of ReLU without normalization, as the per-sample statistics below suggest:
```python
import matplotlib.pyplot as plt
import numpy as np

# Per-sample mean and std of the projections: with a trailing ReLU and
# no normalization, all values are clipped at zero, so the mean stays positive.
plt.plot(np.mean(embedding_dataset["projections"].values, axis=1))
plt.plot(np.std(embedding_dataset["projections"].values, axis=1))
```
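A sketch of the batch-normalized non-linear projection head suggested above; the layer sizes are placeholders, and ending with batch norm instead of ReLU keeps the projections from being clipped at zero:

```python
import torch.nn as nn

# Two-layer MLP head with batch norm after each linear layer.
# Hidden/output sizes are illustrative, not the repo's values.
projection_head = nn.Sequential(
    nn.Linear(768, 768),
    nn.BatchNorm1d(768),
    nn.ReLU(inplace=True),
    nn.Linear(768, 128),
    nn.BatchNorm1d(128),  # no ReLU here, so outputs are not restricted to be non-negative
)
```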