divyanshj16 closed this issue 3 years ago.
@justchenhao
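```python
import torch
import torch.nn as nn


class BCL(nn.Module):
    """
    batch-balanced contrastive loss
    no-change, 1
    change, -1
    """

    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        # Remap pixels marked 255 to 1 (no-change); this modifies `label` in place.
        label[label == 255] = 1
        # Mask of pixels whose label is not 255 (computed after the remap above).
        mask = (label != 255).float()
        distance = distance * mask
        # Per-batch pixel counts for each class; 0.0001 avoids division by zero.
        pos_num = torch.sum((label == 1).float()) + 0.0001
        neg_num = torch.sum((label == -1).float()) + 0.0001
        # No-change pixels (label 1): penalize the squared distance.
        loss_1 = torch.sum((1 + label) / 2 * torch.pow(distance, 2)) / pos_num
        # Change pixels (label -1): penalize distances that fall below the margin.
        loss_2 = torch.sum(
            (1 - label) / 2 * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        ) / neg_num
        loss = loss_1 + loss_2
        return loss
```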
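For reference, the quantity computed by `forward` can be written as follows. The symbols are introduced here only for illustration: $y_i \in \{1, -1\}$ is the label of pixel $i$ (1 = no-change, -1 = change), $d_i$ is the masked distance, $m$ is `margin`, and $\epsilon = 10^{-4}$ is the `0.0001` constant.

$$
\mathcal{L} = \frac{1}{N_{+}} \sum_i \frac{1 + y_i}{2}\, d_i^2 \;+\; \frac{1}{N_{-}} \sum_i \frac{1 - y_i}{2}\, \max(m - d_i,\, 0)^2
$$

$$
N_{+} = \sum_i \mathbb{1}[y_i = 1] + \epsilon, \qquad N_{-} = \sum_i \mathbb{1}[y_i = -1] + \epsilon
$$

No-change pixels are pulled toward distance 0, change pixels are pushed beyond the margin $m$, and each term is averaged over its own class count, which is what makes the loss batch-balanced.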
In this code, why do you do `label[label==255] = 1` when the label has already been transformed to -1 and 1? The `mask` tensor computed just below it will then always be all ones.
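A minimal sketch of the point raised in the question, written in plain Python so it runs without PyTorch (lists stand in for flattened label tensors; both helper names are hypothetical). Because the remap runs first, no 255 is left when the mask is built. Building the mask before the remap is shown only as an alternative ordering, not as the author's intended code.

```python
IGNORE = 255  # the special label value discussed in this thread


def mask_after_remap(labels):
    """Mirror the order used in BCL.forward: remap 255 -> 1, then build the mask."""
    remapped = [1 if y == IGNORE else y for y in labels]
    mask = [0.0 if y == IGNORE else 1.0 for y in remapped]
    return remapped, mask


def mask_before_remap(labels):
    """Hypothetical alternative ordering: build the mask first, then remap."""
    mask = [0.0 if y == IGNORE else 1.0 for y in labels]
    remapped = [1 if y == IGNORE else y for y in labels]
    return remapped, mask


labels = [1, -1, 255, 1]
print(mask_after_remap(labels))   # ([1, -1, 1, 1], [1.0, 1.0, 1.0, 1.0])
print(mask_before_remap(labels))  # ([1, -1, 1, 1], [1.0, 1.0, 0.0, 1.0])
```

With the second ordering the 255 pixel gets distance 0 in the loss, although it would still be counted in `pos_num` because it is remapped to 1.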
Thanks for your interest. The label mask may contain the value 255 because of the rotation augmentation.
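The reply does not say which library or fill value the rotation augmentation uses. The sketch below assumes a nearest-neighbour rotation that fills uncovered pixels with 255; `rotate_label` is a hypothetical helper written in plain Python to stay self-contained. It illustrates how rotating a label map can introduce 255 pixels at the borders, which `BCL.forward` as quoted then remaps to 1.

```python
import math

IGNORE = 255  # assumed fill value for pixels with no source inside the image


def rotate_label(label, angle_deg, fill=IGNORE):
    """Nearest-neighbour rotation of a 2-D label map about its centre.

    Output pixels whose source location falls outside the input get `fill`.
    """
    h, w = len(label), len(label[0])
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    out = [[fill] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            # Map each output pixel back to a source pixel in the input image.
            dy, dx = i - cy, j - cx
            si = round(cos_t * dy - sin_t * dx + cy)
            sj = round(sin_t * dy + cos_t * dx + cx)
            if 0 <= si < h and 0 <= sj < w:
                out[i][j] = label[si][sj]
    return out


label = [[1] * 5 for _ in range(5)]  # an all "no-change" 5x5 label map
for row in rotate_label(label, 45):
    print(row)  # the four corner pixels come out as 255
```

Libraries such as torchvision expose a comparable fill value on their rotation transforms; whichever is used, the fill value has to be handled in the loss.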