yiskw713 / boundary_loss_for_remote_sensing

PyTorch re-implementation of the boundary loss proposed in "Boundary Loss for Remote Sensing Imagery Semantic Segmentation".
98 stars 13 forks source link

one-hot function #4

Open xuyingxiao opened 4 years ago

xuyingxiao commented 4 years ago

Hi, thanks for your re-implementation. It seems to work well, but I can't understand the one-hot function; it looks like a bitwise_not function. In lines 9-10: `one_hot_label = torch.eye(n_classes, device=device, requires_grad=requires_grad)[label]`. For binary segmentation, what is the meaning of `tensor([[1.]], device='cuda:0', requires_grad=True)[label]` (where `label` is a 512*512 binary mask)?

ashnair1 commented 4 years ago

@yiskw713

A further question I had regarding the one-hot label: why use

one_hot_label = torch.eye(
        n_classes, device=device, requires_grad=requires_grad)[label]

when you could simply use

one_hot_label = F.one_hot(label, n_classes) ?

SonwYang commented 2 years ago

class BoundaryLoss(nn.Module):
    """Boundary Loss proposed in:
    Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery
    Semantic Segmentation, https://arxiv.org/abs/1905.07852

    The loss is ``1 - BF1``. BF1 is a boundary F1 score computed per sample
    and per class from a soft predicted boundary map and a hard ground-truth
    boundary map. It is then averaged over classes and the mini-batch.

    Args:
        theta0: max-pool kernel size used to extract boundaries.
        theta: max-pool kernel size used to extend (dilate) boundaries when
            computing precision and recall.

    Both kernel sizes must be positive odd integers. The padding
    ``(k - 1) // 2`` only preserves the spatial size for odd ``k``.

    Raises:
        ValueError: if ``theta0`` or ``theta`` is not a positive odd integer.
    """

    def __init__(self, theta0=3, theta=5):
        super().__init__()

        # Fail fast: an even kernel shrinks the pooled map by one pixel and
        # would otherwise surface later as an opaque shape-mismatch error.
        for name, value in (("theta0", theta0), ("theta", theta)):
            if value < 1 or value % 2 == 0:
                raise ValueError(
                    f"{name} must be a positive odd integer, got {value!r}")

        self.theta0 = theta0
        self.theta = theta

    def crop(self, w, h, target):
        """Center-crop ``target`` of shape (N, Ht, Wt) to at most (N, h, w).

        Each spatial axis is cropped independently, and only when it is
        larger than the requested size. Slicing with explicit end indices
        yields exactly ``h`` / ``w`` even when the size difference is odd.
        """
        _, ht, wt = target.size()
        if ht > h:
            top = (ht - h) // 2
            target = target[:, top:top + h, :]
        if wt > w:
            left = (wt - w) // 2
            target = target[:, :, left:left + w]

        return target

    def to_one_hot(self, target, size):
        """Convert a label map to a one-hot float mask.

        Args:
            target: label map of shape (N, H, W). Labels are clamped to
                ``[0, C - 1]``, so out-of-range values (e.g. an ignore index)
                are mapped to the nearest valid class.
            size: desired output size ``(N, C, H, W)``.

        Returns:
            A float32 tensor of shape ``size`` on ``target``'s device. It holds
            1.0 in each pixel's class channel and 0.0 elsewhere, and does not
            require grad.
        """
        n, c, h, w = size

        # unsqueeze (not view) so that a non-contiguous, cropped target works
        index = torch.clamp(target.detach(), 0, c - 1).long().unsqueeze(1)
        one_hot = torch.zeros(size, dtype=torch.float32, device=target.device)
        one_hot.scatter_(1, index, 1.0)

        return one_hot

    def forward(self, pred, gt):
        """
        Input:
            - pred: the output from model (before softmax)
                    shape (N, C, H, W)
            - gt: ground truth map
                    shape (N, H, W) or (N, 1, H, W); center-cropped if it is
                    spatially larger than pred
        Return:
            - boundary loss, averaged over classes and mini-batch
        """
        # Drop only the channel axis. A bare squeeze() would also drop the
        # batch axis when N == 1 and break the (N, H, W) layout used below.
        if gt.dim() == 4 and gt.size(1) == 1:
            gt = gt[:, 0]

        n, c, h, w = pred.shape

        # softmax so that predicted map can be distributed in [0, 1]
        pred = torch.softmax(pred, dim=1)

        # one-hot vector of ground truth, shape (N, C, H, W)
        gt = self.crop(w, h, gt)
        one_hot_gt = self.to_one_hot(gt, pred.size())

        # Boundary map: max-pooling the complement dilates the "outside"
        # region; subtracting the complement leaves the band of pixels just
        # inside each class region. Out-of-place ops keep autograd simple.
        pad0 = (self.theta0 - 1) // 2
        gt_bg = 1 - one_hot_gt
        pred_bg = 1 - pred
        gt_b = F.max_pool2d(
            gt_bg, kernel_size=self.theta0, stride=1, padding=pad0) - gt_bg
        pred_b = F.max_pool2d(
            pred_bg, kernel_size=self.theta0, stride=1, padding=pad0) - pred_bg

        # extended boundary map
        pad = (self.theta - 1) // 2
        gt_b_ext = F.max_pool2d(
            gt_b, kernel_size=self.theta, stride=1, padding=pad)
        pred_b_ext = F.max_pool2d(
            pred_b, kernel_size=self.theta, stride=1, padding=pad)

        # flatten spatial dims; reshape (not view) also accepts
        # non-contiguous inputs such as channels_last tensors
        gt_b = gt_b.reshape(n, c, -1)
        pred_b = pred_b.reshape(n, c, -1)
        gt_b_ext = gt_b_ext.reshape(n, c, -1)
        pred_b_ext = pred_b_ext.reshape(n, c, -1)

        # Precision, Recall (per sample and class)
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)

        # Boundary F1 Score
        BF1 = 2 * P * R / (P + R + 1e-7)

        # mean of (1 - BF1) over every (sample, class) pair
        loss = torch.mean(1 - BF1)

        return loss
justalonelc commented 2 years ago

class BoundaryLoss(nn.Module):
    """Boundary Loss proposed in:
    Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery
    Semantic Segmentation, https://arxiv.org/abs/1905.07852

    The loss is ``1 - BF1``. BF1 is a boundary F1 score computed per sample
    and per class from a soft predicted boundary map and a hard ground-truth
    boundary map. It is then averaged over classes and the mini-batch.

    Args:
        theta0: max-pool kernel size used to extract boundaries.
        theta: max-pool kernel size used to extend (dilate) boundaries when
            computing precision and recall.

    Both kernel sizes must be positive odd integers. The padding
    ``(k - 1) // 2`` only preserves the spatial size for odd ``k``.

    Raises:
        ValueError: if ``theta0`` or ``theta`` is not a positive odd integer.
    """

    def __init__(self, theta0=3, theta=5):
        super().__init__()

        # Fail fast: an even kernel shrinks the pooled map by one pixel and
        # would otherwise surface later as an opaque shape-mismatch error.
        for name, value in (("theta0", theta0), ("theta", theta)):
            if value < 1 or value % 2 == 0:
                raise ValueError(
                    f"{name} must be a positive odd integer, got {value!r}")

        self.theta0 = theta0
        self.theta = theta

    def crop(self, w, h, target):
        """Center-crop ``target`` of shape (N, Ht, Wt) to at most (N, h, w).

        Each spatial axis is cropped independently, and only when it is
        larger than the requested size. Slicing with explicit end indices
        yields exactly ``h`` / ``w`` even when the size difference is odd.
        """
        _, ht, wt = target.size()
        if ht > h:
            top = (ht - h) // 2
            target = target[:, top:top + h, :]
        if wt > w:
            left = (wt - w) // 2
            target = target[:, :, left:left + w]

        return target

    def to_one_hot(self, target, size):
        """Convert a label map to a one-hot float mask.

        Args:
            target: label map of shape (N, H, W). Labels are clamped to
                ``[0, C - 1]``, so out-of-range values (e.g. an ignore index)
                are mapped to the nearest valid class.
            size: desired output size ``(N, C, H, W)``.

        Returns:
            A float32 tensor of shape ``size`` on ``target``'s device. It holds
            1.0 in each pixel's class channel and 0.0 elsewhere, and does not
            require grad.
        """
        n, c, h, w = size

        # unsqueeze (not view) so that a non-contiguous, cropped target works
        index = torch.clamp(target.detach(), 0, c - 1).long().unsqueeze(1)
        one_hot = torch.zeros(size, dtype=torch.float32, device=target.device)
        one_hot.scatter_(1, index, 1.0)

        return one_hot

    def forward(self, pred, gt):
        """
        Input:
            - pred: the output from model (before softmax)
                    shape (N, C, H, W)
            - gt: ground truth map
                    shape (N, H, W) or (N, 1, H, W); center-cropped if it is
                    spatially larger than pred
        Return:
            - boundary loss, averaged over classes and mini-batch
        """
        # Drop only the channel axis. A bare squeeze() would also drop the
        # batch axis when N == 1 and break the (N, H, W) layout used below.
        if gt.dim() == 4 and gt.size(1) == 1:
            gt = gt[:, 0]

        n, c, h, w = pred.shape

        # softmax so that predicted map can be distributed in [0, 1]
        pred = torch.softmax(pred, dim=1)

        # one-hot vector of ground truth, shape (N, C, H, W)
        gt = self.crop(w, h, gt)
        one_hot_gt = self.to_one_hot(gt, pred.size())

        # Boundary map: max-pooling the complement dilates the "outside"
        # region; subtracting the complement leaves the band of pixels just
        # inside each class region. Out-of-place ops keep autograd simple.
        pad0 = (self.theta0 - 1) // 2
        gt_bg = 1 - one_hot_gt
        pred_bg = 1 - pred
        gt_b = F.max_pool2d(
            gt_bg, kernel_size=self.theta0, stride=1, padding=pad0) - gt_bg
        pred_b = F.max_pool2d(
            pred_bg, kernel_size=self.theta0, stride=1, padding=pad0) - pred_bg

        # extended boundary map
        pad = (self.theta - 1) // 2
        gt_b_ext = F.max_pool2d(
            gt_b, kernel_size=self.theta, stride=1, padding=pad)
        pred_b_ext = F.max_pool2d(
            pred_b, kernel_size=self.theta, stride=1, padding=pad)

        # flatten spatial dims; reshape (not view) also accepts
        # non-contiguous inputs such as channels_last tensors
        gt_b = gt_b.reshape(n, c, -1)
        pred_b = pred_b.reshape(n, c, -1)
        gt_b_ext = gt_b_ext.reshape(n, c, -1)
        pred_b_ext = pred_b_ext.reshape(n, c, -1)

        # Precision, Recall (per sample and class)
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)

        # Boundary F1 Score
        BF1 = 2 * P * R / (P + R + 1e-7)

        # mean of (1 - BF1) over every (sample, class) pair
        loss = torch.mean(1 - BF1)

        return loss

Thank you for your code.