Project-MONAI / research-contributions

Implementations of recent research prototypes/demonstrations using MONAI.
https://monai.io/
Apache License 2.0

Understanding the contrastive loss implementation #389

Open jwutsetro opened 4 months ago

jwutsetro commented 4 months ago

Dear,

I am trying to understand your custom contrastive loss class. As I understand it, it correctly computes the positives for the numerator by shifting the diagonal by +batch_size and -batch_size. But when the denominator is computed, the negative mask is defined as the inverse of torch.eye(). As I understand it, this means that only the self-similarity (which is always 1) is removed from the denominator, while the similarities between a patch and its augmented version are still included. Is that correct?
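
For concreteness, this is a minimal sketch of the behaviour I am describing, not the library's actual code; the function and variable names are my own:

    import torch
    import torch.nn.functional as F

    def ntxent_with_eye_mask(z_i, z_j, temperature=0.5):
        # z_i, z_j: [batch_size, dim] embeddings of the two augmented views
        batch_size = z_i.shape[0]
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)   # [2B, dim]
        sim = torch.mm(z, z.t()) / temperature                 # [2B, 2B] scaled cosine similarities

        # positives: the diagonals shifted by +batch_size and -batch_size
        pos = torch.cat([torch.diag(sim, batch_size),
                         torch.diag(sim, -batch_size)])        # [2B]

        # denominator mask: only the self-similarity (main diagonal) is removed,
        # so the positive pair is still summed into the denominator
        neg_mask = ~torch.eye(2 * batch_size, dtype=torch.bool)

        denom = (torch.exp(sim) * neg_mask).sum(dim=1)         # [2B]
        return (-torch.log(torch.exp(pos) / denom)).mean()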

I would personally implement it like this:

    def create_negative_mask(self, batch_size):
        # Mask over the 2B x 2B similarity matrix that keeps only true negatives.
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool)
        # Drop the self-similarities (main diagonal) ...
        mask.fill_diagonal_(0)
        # ... and the positive pairs (the diagonals shifted by +/- batch_size).
        for i in range(batch_size):
            mask[i, batch_size + i] = 0
            mask[batch_size + i, i] = 0
        return mask
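
To be explicit about what I mean, the denominator would then only run over true negatives (hypothetical wiring, reusing `sim` and `pos` from the sketch above):

    # hypothetical use of the mask above in the loss
    neg_mask = self.create_negative_mask(batch_size).to(sim.device)  # [2B, 2B]
    denom = (torch.exp(sim) * neg_mask).sum(dim=1)   # 2B - 2 true negatives per anchor
    loss = -torch.log(torch.exp(pos) / denom)        # positives no longer in the denominator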

Will this not result in unstable training? I am asking because I don't seem to be able to get the contrastive loss to decrease. Attached is my total loss for batch sizes of 12, 24, and 48. The rotational loss and reconstruction loss are 0.3 and 0.1 respectively for all models, so the total loss is dominated by the contrastive loss not going down.

[Screenshot: total training loss curves for batch sizes 12, 24, and 48]

Kindly, Joris

jwutsetro commented 4 months ago

After digging into it a bit more, it seems that implementations of SimCLR indeed follow a similar approach. But that still leaves me wondering why we would include the positives in the denominator. Additionally, if anyone has suggestions on how to further improve the contrastive loss optimisation, I would be very happy to hear them!
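
For reference, my reading of the NT-Xent loss in the SimCLR paper (Chen et al., 2020) is that the denominator runs over all $k \neq i$, so the positive pair is counted there as well:

$$
\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}
$$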