adambielski / siamese-triplet

Siamese and triplet networks with online pair/triplet mining in PyTorch
BSD 3-Clause "New" or "Revised" License

Implementation of Contrastive Loss #46

Closed ChenPaulYu closed 4 years ago

ChenPaulYu commented 4 years ago

Thanks for your great work; the code is clean and works well. However, I have a question about your implementation of the contrastive loss. Could you explain it in more detail? It looks quite different from the description in the original paper.

In the meantime, I tried a contrastive loss implementation written by someone else, but it does not work. Could you explain the difference?

# Custom Contrastive Loss
import torch
import torch.nn.functional as F


class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Per-pair Euclidean distance between the two embedding batches.
        euclidean_distance = F.pairwise_distance(output1, output2)
        # The (1 - label) term penalizes distance itself; the label term
        # penalizes pairs closer than the margin (clamp first, then square).
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive

adambielski commented 4 years ago

I don't see how my implementation differs from the paper. The snippet you provided applies the power and the clamping in a different order than the original.
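
To make the order-of-operations point concrete, here is a minimal, dependency-free sketch (plain Python floats rather than PyTorch tensors) that computes the margin term for a single pair in three ways. The function names are hypothetical and chosen only for illustration; they are not taken from this repository. The label convention (label 1 marks a dissimilar pair) is assumed from the snippet in the question, and the paper's formula as I read it is noted in a comment.

```python
import math


def euclidean_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def clamp_then_square(distance, margin):
    # Hinge first, then square: max(0, m - d) ** 2.
    # This is the form of the dissimilar-pair term in Hadsell et al. (2006):
    #   L = (1 - Y) * 1/2 * D^2 + Y * 1/2 * max(0, m - D)^2
    return max(0.0, margin - distance) ** 2


def square_then_clamp(distance, margin):
    # Hypothetical variant: clamp the difference of squares, max(0, m^2 - d^2).
    return max(0.0, margin ** 2 - distance ** 2)


def square_without_clamp(distance, margin):
    # Hypothetical variant: no hinge, so pairs beyond the margin are penalized too.
    return (margin - distance) ** 2


def contrastive_loss(a, b, label, margin=2.0):
    # Assumed convention (from the snippet above): label == 1 means dissimilar.
    # The 1/2 factor from the paper is omitted, as in the snippet.
    d = euclidean_distance(a, b)
    return (1 - label) * d ** 2 + label * clamp_then_square(d, margin)


if __name__ == "__main__":
    for d in (1.0, 3.0):
        print(d, clamp_then_square(d, 2.0), square_then_clamp(d, 2.0),
              square_without_clamp(d, 2.0))
```

The three variants agree only in special cases. For a pair already farther apart than the margin, the hinge forms give zero while the unclamped form still produces a penalty, so it is worth checking both the clamp placement and the label convention when comparing two implementations.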