vturrisi / solo-learn

solo-learn: a library of self-supervised methods for visual representation learning powered by Pytorch Lightning
MIT License

Help request to create new loss #371

Closed miguel-arrf closed 9 months ago

miguel-arrf commented 9 months ago

Hi!

I understand if this issue gets closed, as it is not directly related to solo-learn. However, I'm building a custom loss for SimCLR, inspired by https://arxiv.org/abs/2006.10511. Essentially, I'm reusing all of the solo-learn code, but with a custom BatchSampler where I define my own batches, and I've also created a custom loss function.

Essentially, I load an NxN matrix that defines, for each pair of images (x_i, x_j): -1 if the pair should be treated as a negative pair in the loss for x_i, 0 if the pair should be ignored, and 1 if it should be treated as an additional positive pair (so the positives are not only x_i and its augmented version, but also other images x_n).
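For illustration, this is roughly the kind of matrix I mean (a toy sketch; my real matrix is precomputed offline, and group_ids here is just a hypothetical stand-in for whatever defines the extra positives):

import torch

def build_label_matrix(group_ids: torch.Tensor) -> torch.Tensor:
    # group_ids: hypothetical [N] tensor; images sharing a group id are extra positives
    same_group = group_ids.unsqueeze(0) == group_ids.unsqueeze(1)
    matrix = same_group.long() * 2 - 1   # 1 for same group, -1 otherwise
    matrix.fill_diagonal_(0)             # an image with itself is ignored in this sketch
    return matrix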

Here's the modified training_step function:

def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
    out = super().training_step(batch, batch_idx)
    class_loss = out["loss"]  # loss from the parent training_step (not added to the returned loss)
    z = torch.cat(out["z"])   # projected features of both views, concatenated along the batch dim

    # NxN label matrix defining which pairs are positive / negative / ignored
    matrix = torch.load(f"{self.matrix_file}").to(self.device)

    nce_loss = simclr_loss_func(
        z,
        similarity_labels=matrix,
        temperature=self.temperature,
    )

    return nce_loss

And the simclr_loss_func function:

import torch
import torch.nn.functional as F


def simclr_loss_func(representations, similarity_labels, temperature=0.1):
    # L2-normalize so the dot products below are cosine similarities
    representations = F.normalize(representations, dim=1)
    similarity_matrix = representations @ representations.t()

    # masks derived from the label matrix
    positive_mask = (similarity_labels == 1).float()
    negative_mask = (similarity_labels == 0).float()
    ignore_mask = (similarity_labels == -1).float()  # computed but not used below

    # keep only the similarities of the pairs marked as negatives
    negatives = similarity_matrix * negative_mask

    # per-row denominator: sum of exponentiated negative similarities
    denominators = torch.sum(torch.exp(negatives / temperature) * negative_mask, dim=1)
    # exponentiated similarities of the positive pairs
    positives = torch.exp((similarity_matrix * positive_mask) / temperature) * positive_mask

    # each positive divided by (its row's negatives + itself)
    division_of_each_row_by_denominator = positives / (denominators.view(-1, 1) + positives)

    # -log of the ratios, zeroing the inf entries coming from the masked-out positions
    sum_of_log = torch.sum(
        torch.nan_to_num(-torch.log(division_of_each_row_by_denominator), posinf=0, neginf=0)
    )

    # average over the number of positive pairs
    loss = sum_of_log / torch.sum(positive_mask)
    return loss
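Here is the quick standalone check I run to make sure the function at least returns a scalar connected to the graph (random inputs, shapes matching my batch; not part of the training code):

z = torch.randn(80, 256, requires_grad=True)
labels = torch.randint(-1, 2, (80, 80))  # random -1 / 0 / 1 labels, just for the check
loss = simclr_loss_func(z, labels, temperature=0.1)
print(loss, loss.grad_fn)  # expect a finite scalar with a non-None grad_fn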

I know this is unusual, but would anyone mind giving me a hint as to why my loss is practically not changing (it stays between 4.15 and 4.2)?

Could it be that the gradient flow is broken somewhere during backpropagation? Is there an issue with my loss function? Essentially, I want to be the one defining which pairs are positive and which are negative for the SimCLR loss calculation.
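To see whether any gradient reaches the parameters at all, I was planning to add a debugging hook like this to the method class (a rough sketch using PyTorch Lightning's on_after_backward hook):

def on_after_backward(self):
    # debugging only: total gradient norm across all parameters after backward
    total = sum(
        p.grad.detach().norm().item()
        for p in self.parameters()
        if p.grad is not None
    )
    print(f"total grad norm: {total:.6f}")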

My batch is of shape [80, 256].

Thank you for the help!