sthalles / SimCLR

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
https://sthalles.github.io/simple-self-supervised-learning/
MIT License

- attempt to add support for n_views >=3 #67

Open CielAl opened 10 months ago

CielAl commented 10 months ago

@sthalles @alessiamarcolini @butyuhao Hi,

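As mentioned in #32, the current implementation of info_nce_loss may not work properly when n_views > 2 because of the additional positive pairs: each sample then has n_views - 1 positives, but only the first one is used as the target (label 0), so the remaining positives end up in the softmax denominator alongside the negatives. Here I attempt to fix it by duplicating the negatives for each additional positive pair, if I understand the mechanism of your current implementation correctly: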

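        # select the positives: each of the labels.shape[0] rows holds (n_views - 1) of them,
        # so give every (anchor, positive) pair its own row with a single positive column
        positives = similarity_matrix[labels.bool()].view(labels.shape[0] * (n_views - 1), -1)

        # select only the negatives
        # change: copy the negatives once per extra positive pair of the same image;
        # repeat_interleave keeps negatives row k on the same anchor as positives row k,
        # since boolean indexing returns the positives anchor by anchor
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1).repeat_interleave(n_views - 1, dim=0)

        logits = torch.cat([positives, negatives], dim=1)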
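For checking the row alignment outside the repository's trainer, here is a minimal, self-contained sketch of the complete loss with this change. It assumes SimCLR's view-major ordering of `features` (all first views, then all second views, and so on); the function name `info_nce_loss_multiview` and its signature are hypothetical, not the repository's API. Each (anchor, positive) pair becomes one row whose target is column 0, and the negatives are stacked with `repeat_interleave` so that row k of the negatives belongs to the same anchor as row k of the positives (boolean indexing returns the positives anchor by anchor, whereas a plain `repeat` would tile the whole negatives block).

```python
import torch
import torch.nn.functional as F


def info_nce_loss_multiview(features: torch.Tensor, batch_size: int, n_views: int,
                            temperature: float = 0.07) -> torch.Tensor:
    """InfoNCE loss with one logits row per (anchor, positive) pair.

    Hypothetical helper: `features` has shape (batch_size * n_views, dim) and is
    ordered view-major, i.e. row r is a view of image r % batch_size.
    """
    n = batch_size * n_views
    device = features.device

    # labels[i, j] is True iff rows i and j are views of the same image
    ids = torch.arange(batch_size, device=device).repeat(n_views)
    labels = ids.unsqueeze(0) == ids.unsqueeze(1)

    # cosine similarities between all pairs of views
    features = F.normalize(features, dim=1)
    similarity_matrix = features @ features.T

    # discard the main diagonal (self-similarity) from both matrices
    mask = torch.eye(n, dtype=torch.bool, device=device)
    labels = labels[~mask].view(n, -1)
    similarity_matrix = similarity_matrix[~mask].view(n, -1)

    # (n * (n_views - 1), 1): positives listed anchor by anchor
    positives = similarity_matrix[labels].view(n * (n_views - 1), 1)
    # (n, n - n_views) negatives, each row repeated once per positive of its anchor
    negatives = similarity_matrix[~labels].view(n, -1).repeat_interleave(n_views - 1, dim=0)

    # the positive sits in column 0 of every row
    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
    return F.cross_entropy(logits, targets)


if __name__ == "__main__":
    # usage: 4 images with 3 augmented views each, 128-d projections
    feats = torch.randn(4 * 3, 128)
    loss = info_nce_loss_multiview(feats, batch_size=4, n_views=3)
    print(loss.item())
```

Averaging over all N * (n_views - 1) rows weights every positive pair equally, and for n_views = 2 this reduces to the usual two-view loss. Note that the other positives of the same anchor are left out of each row's denominator, since the negatives exclude every view of the anchor's image.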