justchenhao / STANet

Official implementation of the spatial-temporal attention neural network (STANet) for remote sensing image change detection.
BSD 2-Clause "Simplified" License

Issue in BCL loss #32

Closed: divyanshj16 closed this issue 3 years ago

divyanshj16 commented 4 years ago

@justchenhao

```python
import torch
import torch.nn as nn


class BCL(nn.Module):
    """
    batch-balanced contrastive loss
    no-change,1
    change,-1
    """

    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        label[label==255] = 1
        mask = (label != 255).float()
        distance = distance * mask
        pos_num = torch.sum((label==1).float())+0.0001
        neg_num = torch.sum((label==-1).float())+0.0001

        loss_1 = torch.sum((1+label) / 2 * torch.pow(distance, 2)) /pos_num
        loss_2 = torch.sum((1-label) / 2 *
            torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        ) / neg_num
        loss = loss_1 + loss_2
        return loss
```
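For reference, this is the forward pass of the BCL code above restated as a formula. Here y_i is the label of pixel i after the 255 → 1 replacement, d̃_i = d_i · mask_i is the masked distance, m is the margin (default 2.0) and ε = 0.0001. It only restates the code and adds nothing to it:

```math
\mathcal{L} = \frac{\sum_i \frac{1 + y_i}{2}\,\tilde d_i^{\,2}}{N_{1} + \epsilon}
+ \frac{\sum_i \frac{1 - y_i}{2}\,\max\left(m - \tilde d_i,\, 0\right)^{2}}{N_{-1} + \epsilon},
\qquad N_{1} = \left|\{i : y_i = 1\}\right|,\quad N_{-1} = \left|\{i : y_i = -1\}\right|
```

With y_i ∈ {-1, 1}, the weights (1 + y_i)/2 and (1 - y_i)/2 act as 0/1 indicators. The first term pulls no-change distances toward zero, and the second pushes change distances beyond the margin m.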

In this code, why do you set `label[label==255] = 1`? The labels have already been transformed to -1 and 1. And because no element can equal 255 after that assignment, the mask computed on the next line will always be all ones.
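A quick check of this observation, assuming a toy 2×2 label tensor that contains -1, 1 and 255 (the values are illustrative only). It runs the two lines from `BCL.forward` in their original order:

```python
import torch

# Toy label map: 1 = no-change, -1 = change, 255 = value from augmentation
label = torch.tensor([[1.0, -1.0], [255.0, 1.0]])

# The two lines from BCL.forward, in the same order
label[label == 255] = 1
mask = (label != 255).float()

print(label)  # the former 255 entry is now 1, i.e. treated as no-change
print(mask)   # every entry is 1.0, so `distance * mask` leaves distance unchanged
```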

justchenhao commented 3 years ago

Thanks for your attention. The label mask may contain the value 255 because of the rotation augmentation.
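The original code counts those 255 pixels as no-change (label 1). If the aim were instead to exclude them from the loss, one option is to weight each term with `label == 1` and `label == -1` indicators rather than with (1 + label)/2 and (1 - label)/2. Those two expressions would give a 255 pixel weights of 128 and -127, which is presumably why the relabeling is needed. A 255 pixel matches neither indicator, so it drops out of both terms and both counts without a separate mask. This is a minimal sketch under that assumption and is not the repository's code. The class name `MaskedBCL` is hypothetical, and the toy tensors are illustrative only:

```python
import torch
import torch.nn as nn


class MaskedBCL(nn.Module):
    """Hypothetical BCL variant: 1 = no-change, -1 = change, any other value is ignored."""

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, distance, label):
        # 0/1 indicators; pixels labelled 255 fall into neither set.
        # The label tensor is not modified in place, unlike the original forward().
        pos = (label == 1).float()   # no-change pixels
        neg = (label == -1).float()  # change pixels
        pos_num = pos.sum() + 0.0001
        neg_num = neg.sum() + 0.0001

        # Pull no-change distances toward 0, push change distances past the margin.
        loss_1 = torch.sum(pos * distance.pow(2)) / pos_num
        loss_2 = torch.sum(
            neg * torch.clamp(self.margin - distance, min=0.0).pow(2)
        ) / neg_num
        return loss_1 + loss_2


# Toy usage: the 255 pixel at [1, 0] contributes nothing, whatever its distance.
loss_fn = MaskedBCL(margin=2.0)
distance = torch.tensor([[0.5, 1.0], [3.0, 0.0]])
label = torch.tensor([[1.0, -1.0], [255.0, 1.0]])
loss = loss_fn(distance, label)
print(loss.item())
```

Leaving `label` untouched also removes a side effect of the original, which overwrites the caller's label tensor in place.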