mehta-lab / VisCy

computer vision models for single-cell phenotyping
https://pypi.org/project/viscy/
BSD 3-Clause "New" or "Revised" License

Update the projection head (normalization and size). #139

Open mattersoflight opened 4 weeks ago

mattersoflight commented 4 weeks ago

TL;DR: the current projection head doesn't do what it is supposed to do. It should include batch norm, and the sizes of the features and projections could be reduced further. In our previous implementation of contrastive learning models (dynacontrast), we used batch norm after each linear layer of the projection head:

```
(projection): Sequential(
    (fc1): Linear(in_features=2048, out_features=2048, bias=False)
    (bn1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=2048, out_features=128, bias=False)
    (bn2): BatchNorm1dNoBias(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
```
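For reference, a head with the same layout as the printout can be sketched in PyTorch as follows. This is a minimal sketch, not VisCy's or dynacontrast's implementation: the function name `make_projection_head` is hypothetical, and we assume `BatchNorm1dNoBias` means a `BatchNorm1d` whose learnable shift is frozen at zero. Dropping the `Linear` biases costs nothing here, because each linear layer is followed by batch norm, which subtracts the batch mean anyway.

```python
from collections import OrderedDict

import torch
from torch import nn


def make_projection_head(
    in_features: int = 2048, hidden_features: int = 2048, out_features: int = 128
) -> nn.Sequential:
    """Two-layer MLP projection head with batch norm after each linear layer."""
    bn2 = nn.BatchNorm1d(out_features)
    # Assumption: "NoBias" means the learnable shift of the last BN stays at zero.
    bn2.bias.requires_grad_(False)
    return nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(in_features, hidden_features, bias=False)),
                ("bn1", nn.BatchNorm1d(hidden_features)),
                ("relu1", nn.ReLU()),
                ("fc2", nn.Linear(hidden_features, out_features, bias=False)),
                ("bn2", bn2),
            ]
        )
    )


torch.manual_seed(0)
head = make_projection_head()
features = torch.randn(32, 2048)  # a batch of 32 backbone feature vectors
projections = head(features)
print(projections.shape)  # torch.Size([32, 128])
```

In training mode the final batch norm centers every projection dimension over the batch, so the ReLU never acts directly on the projections.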

This paper also recommends a non-linear projection head with batch norm. It evaluates different projection heads by comparing the rank of the representations (the number of linearly independent dimensions) before and after the projection head.

As expected, rank(projections) << rank(features).

Our current model shows the opposite: rank(projections) > rank(features), as seen by examining the principal components of each.
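The rank comparison can be approximated with a singular value decomposition of the centered embeddings. A minimal NumPy sketch, assuming the embeddings are `(n_samples, n_dims)` arrays such as `embedding_dataset["projections"].values`; the helper names and the relative threshold of `1e-3` are illustrative choices, not the ones used in the paper.

```python
import numpy as np


def numerical_rank(x: np.ndarray, rel_tol: float = 1e-3) -> int:
    """Count singular values of the centered data above rel_tol * the largest one."""
    centered = x - x.mean(axis=0)
    s = np.linalg.svd(centered, compute_uv=False)  # sorted in descending order
    return int(np.sum(s > rel_tol * s[0]))


def n_components_for_variance(x: np.ndarray, fraction: float = 0.95) -> int:
    """Number of principal components needed to explain `fraction` of the variance."""
    centered = x - x.mean(axis=0)
    s = np.linalg.svd(centered, compute_uv=False)
    explained = s**2 / np.sum(s**2)
    return int(np.searchsorted(np.cumsum(explained), fraction) + 1)


# Synthetic check: 64-D vectors that only span a 10-D subspace.
rng = np.random.default_rng(0)
features = rng.normal(size=(500, 10)) @ rng.normal(size=(10, 64))
print(numerical_rank(features))  # 10
```

On real embeddings, comparing `numerical_rank` (or `n_components_for_variance`) for the features and for the projections gives the same kind of comparison as the paper's figure.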


This appears to be a consequence of the projections being clipped, likely because ReLU is applied without normalization.

```python
plt.plot(np.mean(embedding_dataset["projections"].values, axis=1))  # mean over projection dims, per sample
```

```python
plt.plot(np.std(embedding_dataset["projections"].values, axis=1))  # std over projection dims, per sample
```
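One way to test the clipping hypothesis is to count how many entries are exactly zero after the ReLU. The toy NumPy example below uses synthetic pre-activations with an arbitrary negative offset (an assumption, not measured from our model) and compares applying the ReLU directly against first standardizing each feature, as batch norm does in training mode without its affine parameters.

```python
import numpy as np


def relu(a: np.ndarray) -> np.ndarray:
    return np.maximum(a, 0.0)


def fraction_clipped(x: np.ndarray) -> float:
    """Fraction of entries that are exactly zero, i.e. clipped by the ReLU."""
    return float(np.mean(x == 0.0))


rng = np.random.default_rng(0)
# Hypothetical unnormalized pre-activations, shifted toward negative values.
pre = rng.normal(loc=-2.0, scale=1.0, size=(1000, 128))

without_norm = relu(pre)
# Per-feature standardization (batch norm without the affine part) before the ReLU.
with_norm = relu((pre - pre.mean(axis=0)) / pre.std(axis=0))

print(fraction_clipped(without_norm))  # about 0.98
print(fraction_clipped(with_norm))  # about 0.5
```

Applied to `embedding_dataset["projections"].values`, a large `fraction_clipped` would be consistent with the flat mean and standard deviation profiles.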

ziw-liu commented 3 weeks ago

@mattersoflight In SimCLR and other methods that use InfoNCE-style losses, the projections $\mathbf{z}$ are implicitly L2-normalized inside the loss function, because cosine similarity is used as the similarity measure. The triplet margin loss uses L2 distance, which for unit vectors is fully determined by the cosine similarity (and vice versa). I still need to find a reference implementation, but the planned removal of L2 normalization might not be needed.
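The equivalence follows from expanding the squared distance: for unit vectors $\mathbf{u}$ and $\mathbf{v}$, $\|\mathbf{u} - \mathbf{v}\|^2 = \|\mathbf{u}\|^2 + \|\mathbf{v}\|^2 - 2\,\mathbf{u}^\top\mathbf{v} = 2 - 2\cos\theta$. A quick numerical check in NumPy, using arbitrary random vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
u, v = rng.normal(size=(2, 128))
# L2-normalize so that both are unit vectors.
u /= np.linalg.norm(u)
v /= np.linalg.norm(v)

cos = float(u @ v)
sq_dist = float(np.sum((u - v) ** 2))
# For unit vectors: ||u - v||^2 = 2 - 2 cos(u, v)
print(np.isclose(sq_dist, 2.0 - 2.0 * cos))  # True
```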

mattersoflight commented 3 weeks ago

@ziw-liu L2 normalization is indeed equivalent to converting feature vectors into unit vectors. Given the SimCLR paper, it also makes sense that the loss is computed between unit vectors (with either cosine similarity or Euclidean distance).

I agree that L2 normalization of projections doesn't need to be removed. There is no need to implement it as an argument.

[TripletMarginLoss](https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html) provided by torch.nn doesn't normalize its inputs: it computes the $p$-norm distance ($p = 2$ by default) on the raw vectors.
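If we keep the triplet margin loss, one way to make it operate on unit vectors is to L2-normalize the projections before passing them in. A minimal sketch, assuming `(batch, dim)` projections; the helper `normalized_triplet_loss`, the margin of 0.5, and the random tensors are placeholders for illustration, not the VisCy training code.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
# Placeholder projections: a batch of 16 triplets of 128-D vectors.
anchor, positive, negative = torch.randn(3, 16, 128).unbind(0)

triplet_loss = nn.TripletMarginLoss(margin=0.5, p=2.0)


def normalized_triplet_loss(a: torch.Tensor, p: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    """Triplet margin loss computed on L2-normalized (unit) vectors."""
    return triplet_loss(F.normalize(a, dim=1), F.normalize(p, dim=1), F.normalize(n, dim=1))


loss = normalized_triplet_loss(anchor, positive, negative)
# Rescaling the raw projections has no effect once they are normalized.
scaled = normalized_triplet_loss(10 * anchor, 10 * positive, 10 * negative)
print(torch.allclose(loss, scaled))  # True
```

Because every distance is then bounded by 2, the margin has a fixed scale regardless of the magnitude of the projections.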

mattersoflight commented 1 week ago

@ziw-liu The paper that compares different types of projection heads uses the InfoNCE loss (which has the same form as the NT-Xent loss in SimCLR). It may be that the NT-Xent loss promotes higher-rank embeddings.
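For comparison, a minimal PyTorch sketch of the NT-Xent loss as described in the SimCLR paper: the two augmented views of each sample form the positive pair, and the other 2N - 2 projections in the batch act as negatives. The function name `nt_xent_loss` and the temperature of 0.5 are assumptions for illustration, not VisCy code.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss; z1[i] and z2[i] are projections of two views of sample i."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2n, d) unit vectors
    sim = z @ z.T / temperature  # temperature-scaled cosine similarities
    # A sample is never compared with itself in the softmax denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + n, and for row i + n it is row i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
z1 = torch.randn(8, 128)
z2 = z1 + 0.01 * torch.randn(8, 128)  # nearly identical views
unrelated = torch.randn(8, 128)  # views with no shared content
print(nt_xent_loss(z1, z2).item() < nt_xent_loss(z1, unrelated).item())  # True
```

Unlike the triplet loss, every other sample in the batch contributes a repulsive term here, which is one plausible reason the two losses could shape the embedding rank differently.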

ziw-liu commented 5 days ago

The model changes were implemented in #145, but #154 could potentially fix the low-rank feature map.