CircleRadon / APro

The code for "Label-efficient Segmentation via Affinity Propagation" [NeurIPS 2023].
Apache License 2.0

Help on the weight of aproloss and partial cross entropy loss? #5

Open lauraset opened 7 months ago

lauraset commented 7 months ago

Hi, @CircleRadon. Thank you for your great work. I am not clear about the weighting of the APro loss terms or about their implementation. According to issue #3, the implementation of the APro loss is:

import torch
import torch.nn as nn

# Global_APro, Local_APro and MinimumSpanningTree are the classes from this repo;
# the exact import path below is my assumption, adjust it to the repo layout.
from apro import Global_APro, Local_APro, MinimumSpanningTree


class AproLoss(nn.Module):
    def __init__(self, ignore_index=255):
        super().__init__()
        # partial cross entropy
        self.partialCE = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # apro
        self.global_apro = Global_APro()
        self.local_apro = Local_APro(kernel_size=5, zeta_s=0.15)  # set kernel_size and zeta_s
        self.mst = MinimumSpanningTree(Global_APro.norm2_distance)
        # pca n_component
        # self.q = 1
        self.ignore_index = ignore_index

    def forward(self, x, y_hat, y):
        # x: B, C, H, W
        # y_hat B, classes, H, W
        # partial cross entropy
        partial = self.partialCE(y_hat, y)
        # compute PCA
        # B, 1, H, W
        # pca_imgs = self.compute_pca(x)

        # compute image tree
        # I think directly using x is also fine
        img_mst_tree = self.mst(x)
        # img_mst_tree = self.mst(pca_imgs)

        # y: B, H, W
        # y = y.float()
        y_hat = torch.softmax(y_hat, dim=1) # convert to probability [0,1]

        # pseudo label for global info
        # first pass: propagate with the low-level feature (the raw image x)
        soft_pseudo = self.global_apro(y_hat, x, img_mst_tree, zeta_g=0.001)
        # second pass: propagate with the deep feature (here simply y_hat itself)
        soft_pseudo = self.global_apro(soft_pseudo, y_hat, img_mst_tree, zeta_g=0.05)

        # unlabelled region only
        unlabelled_regions = (y.unsqueeze(1) == self.ignore_index)

        # compute the difference between the generated pseudo labels and the prediction
        loss_global_term = torch.abs(soft_pseudo-y_hat) * unlabelled_regions
        # normalize the loss
        n_regions = unlabelled_regions.sum().clamp(min=1)
        loss_global = loss_global_term.sum() / n_regions

        # local term
        # local_apro originally took pca_imgs, but the PCA step is commented out
        # above, so pass x directly (consistent with the global term)
        soft_pseudo = self.local_apro(x, y_hat)
        loss_local_term = torch.abs(y_hat - soft_pseudo) * unlabelled_regions
        loss_local = loss_local_term.sum() / unlabelled_regions.sum().clamp(min=1)
        return partial + loss_global + loss_local
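
For reference, here is a minimal usage sketch of the class above; the shapes follow the comments in forward, and the tensors, the 21-class setting, and the sparse label patch are dummy values I made up:

# dummy batch: 255 marks unlabelled pixels, matching ignore_index
criterion = AproLoss(ignore_index=255)
x = torch.randn(2, 3, 64, 64)                           # images: B, C, H, W
y_hat = torch.randn(2, 21, 64, 64, requires_grad=True)  # logits: B, classes, H, W
y = torch.full((2, 64, 64), 255, dtype=torch.long)      # mostly unlabelled labels
y[:, 30:34, 30:34] = 1                                  # a few sparse annotations
loss = criterion(x, y_hat, y)
loss.backward()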

I have several questions:

  1. How should the weights of the partial cross entropy and the global/local APro losses be set? (A weighted variant of the return statement is sketched after this list to show what I mean.)
  2. For the global APro term, the deep feature is set directly to y_hat. Is this the default setting in your paper? Why should it be the last feature map from the segmentation network?
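
For question 1, concretely I mean something like the sketch below, where lambda_g and lambda_l are hypothetical weighting hyperparameters (the names are mine, not from the paper):

# hypothetical weights for the global/local terms; I do not know the defaults
lambda_g, lambda_l = 1.0, 1.0
loss = partial + lambda_g * loss_global + lambda_l * loss_local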

Thank you in advance.