marialeyvallina / generalized_contrastive_loss

MIT License

GCL reproduce #5

Closed. wenjie710 closed this issue 2 years ago.

wenjie710 commented 2 years ago

Hi, I am trying to reproduce the results in the paper using ResNet152 as the backbone and the GCL loss. However, I get NaN errors during training. Here is my implementation of the GCL loss:

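import torch
from IPython import embed  # interactive shell, used below to inspect NaN losses


class GCLoss(torch.nn.Module):

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, out1, out2, label):  # out1, out2 are the output features, and label is the FoV value
        dist_square = torch.sum((out1 - out2) ** 2, dim=1)  # squared Euclidean distance per pair
        dist = torch.sqrt(dist_square)  # Euclidean distance per pair
        # label-weighted attraction term plus (1 - label)-weighted hinge that pushes pairs apart up to the margin
        loss = label * 0.5 * dist_square + (1 - label) * 0.5 * torch.relu(self.margin - dist) ** 2
        if torch.isnan(loss).sum() > 0:
            embed()  # drop into a shell to inspect the NaN values
        loss_m = torch.mean(loss)
        return loss_m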

Could you please give me some advice? Am I doing something wrong?
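For reference, the loss implemented in the code above can be written as follows, where d = ||out1 - out2||_2 is the Euclidean distance between the two descriptors, psi in [0, 1] is the graded similarity label (the FoV value here) and tau is the margin (margin=0.5 in the code). This is transcribed from the snippet, not quoted from the paper.

```latex
\mathcal{L}(d, \psi) = \psi \cdot \tfrac{1}{2}\, d^{2} \;+\; (1 - \psi) \cdot \tfrac{1}{2}\, \max(\tau - d,\, 0)^{2}
```

Note that the first term grows without bound in d, so if the descriptors are not normalized, the loss scale depends directly on the raw feature magnitudes.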

wenjie710 commented 2 years ago

Do you perform L2 normalization after GeM pooling during training? What is the range of the distance in the GCL loss: [0, +inf) or [0, 2]?
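The question can be checked numerically. The sketch below assumes the common GeM formulation (generalized mean over spatial positions with a fixed exponent p = 3, activations clamped to eps first) followed by torch.nn.functional.normalize. The helper names gem and pairwise_distance_range are hypothetical, and the repository's actual pooling settings are not stated in this thread. Without normalization the distance is unbounded; with it, every pairwise distance falls in [0, 2].

```python
import torch
import torch.nn.functional as F


def gem(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical GeM pooling: (mean over H x W of x^p)^(1/p), (B, C, H, W) -> (B, C)."""
    pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), kernel_size=(x.size(-2), x.size(-1)))
    return pooled.pow(1.0 / p).flatten(1)


def pairwise_distance_range(feats1: torch.Tensor, feats2: torch.Tensor, normalize: bool):
    """Return (min, max) Euclidean distance between row-paired GeM descriptors."""
    d1, d2 = gem(feats1), gem(feats2)
    if normalize:
        # L2 normalization after GeM: every descriptor now has unit norm
        d1 = F.normalize(d1, p=2, dim=1)
        d2 = F.normalize(d2, p=2, dim=1)
    dist = torch.linalg.vector_norm(d1 - d2, dim=1)
    return dist.min().item(), dist.max().item()


torch.manual_seed(0)
# Random stand-ins for ResNet feature maps (batch 8, 2048 channels, 7 x 7)
a = torch.rand(8, 2048, 7, 7) * 10
b = torch.rand(8, 2048, 7, 7) * 10
print("raw:       ", pairwise_distance_range(a, b, normalize=False))
print("normalized:", pairwise_distance_range(a, b, normalize=True))
```

The raw distances scale with the feature magnitudes, while the normalized ones cannot exceed 2 regardless of how large the backbone activations are.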

marialeyvallina commented 2 years ago

Dear @wenjie710, at first sight your code looks fine to me, but I haven't debugged it. I don't perform the NaN check, though, since I did not get NaN values in my experiments. I do perform L2 normalization after GeM pooling, so the Euclidean distance is necessarily between 0 and 2, the maximum possible Euclidean distance between vectors of unit norm. Please let me know whether L2 normalization solves your issue.
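The bound follows from the identity ||a - b||^2 = 2 - 2 a·b for unit vectors a and b: since a·b >= -1, the distance is at most 2, reached when b = -a. Below is a minimal sketch of the loss with that normalization applied. Assumptions: NormalizedGCLoss is a hypothetical name, normalization is done inside the loss here only for compactness (in the setup described above it happens after GeM pooling in the network), and the small eps inside the square root is an extra safeguard not discussed in this thread. torch.sqrt has an infinite gradient at 0, so a pair with identical descriptors would otherwise backpropagate NaN.

```python
import torch
import torch.nn.functional as F


class NormalizedGCLoss(torch.nn.Module):
    """Hypothetical variant of the GCLoss above that L2-normalizes descriptors first."""

    def __init__(self, margin: float = 0.5, eps: float = 1e-8):
        super().__init__()
        self.margin = margin
        self.eps = eps  # assumption: keeps sqrt differentiable when two descriptors coincide

    def forward(self, out1, out2, label):
        # Unit-norm descriptors, so the Euclidean distance lies in [0, 2]
        out1 = F.normalize(out1, p=2, dim=1)
        out2 = F.normalize(out2, p=2, dim=1)
        dist_square = torch.sum((out1 - out2) ** 2, dim=1)
        dist = torch.sqrt(dist_square + self.eps)
        # label: graded similarity in [0, 1] (the FoV value in this thread)
        loss = label * 0.5 * dist_square + (1 - label) * 0.5 * torch.relu(self.margin - dist) ** 2
        return loss.mean()


torch.manual_seed(0)
criterion = NormalizedGCLoss(margin=0.5)
# Large, unnormalized features; pair 0 is deliberately identical (the sqrt(0) case)
f1 = (torch.randn(4, 2048) * 100).requires_grad_()
f2 = f1.detach().clone()
f2[1:] = torch.randn(3, 2048) * 100
f2.requires_grad_()
labels = torch.tensor([1.0, 0.0, 0.3, 0.7])

loss = criterion(f1, f2, labels)
loss.backward()
print(loss.item(), torch.isfinite(f1.grad).all().item())
```

Because each per-pair term is bounded (at most 0.5 * 2^2 for the attraction term and 0.5 * margin^2 for the hinge), the loss can no longer blow up with the magnitude of the backbone features.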

wenjie710 commented 2 years ago

Thanks a lot. Normalization solves the issue.

marialeyvallina commented 2 years ago

Great to hear! I'm closing this, but feel free to re-open or drop me an email if you have further questions.