Hi, @CircleRadon. Thank you for your great work. I am not clear about the weighting of the AproLoss terms and its implementation.
According to issue #3, the implementation of AproLoss is:
import torch
import torch.nn as nn
# Global_APro, Local_APro and MinimumSpanningTree are imported from the APro repo


class AproLoss(nn.Module):
    def __init__(self, ignore_index=255):
        super().__init__()
        # partial cross entropy
        self.partialCE = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # apro
        self.global_apro = Global_APro()
        self.local_apro = Local_APro(kernel_size=5, zeta_s=0.15)  # set kernel_size and zeta_s
        self.mst = MinimumSpanningTree(Global_APro.norm2_distance)
        # pca n_components
        # self.q = 1
        self.ignore_index = ignore_index

    def forward(self, x, y_hat, y):
        # x: B, C, H, W
        # y_hat: B, classes, H, W
        # y: B, H, W
        # partial cross entropy on the labelled pixels
        partial = self.partialCE(y_hat, y)
        # compute PCA -> B, 1, H, W
        # pca_imgs = self.compute_pca(x)
        # compute image tree
        # I think directly using x is also fine
        img_mst_tree = self.mst(x)
        # img_mst_tree = self.mst(pca_imgs)
        # y = y.float()
        y_hat = torch.softmax(y_hat, dim=1)  # convert to probabilities in [0, 1]
        # pseudo labels for the global term
        # using low-level feature (the image itself)
        soft_pseudo = self.global_apro(y_hat, x, img_mst_tree, zeta_g=0.001)
        # using deep feature
        soft_pseudo = self.global_apro(soft_pseudo, y_hat, img_mst_tree, zeta_g=0.05)
        # unlabelled regions only
        unlabelled_regions = (y.unsqueeze(1) == self.ignore_index)
        # difference between the generated pseudo labels and the prediction
        loss_global_term = torch.abs(soft_pseudo - y_hat) * unlabelled_regions
        # normalize the loss
        n_regions = unlabelled_regions.sum().clamp(min=1)
        loss_global = loss_global_term.sum() / n_regions
        # local term (using x directly, since pca_imgs is commented out above)
        soft_pseudo = self.local_apro(x, y_hat)
        loss_local_term = torch.abs(y_hat - soft_pseudo) * unlabelled_regions
        loss_local = loss_local_term.sum() / n_regions
        return partial + loss_global + loss_local
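For context, this is roughly how I am calling the loss in my own training loop. It is only a minimal sketch with assumed shapes; model, loader, and optimizer are placeholders from my code, not from your repo:

criterion = AproLoss(ignore_index=255)
for images, sparse_labels in loader:   # images: B, 3, H, W; sparse_labels: B, H, W (long, 255 = unlabelled)
    logits = model(images)             # logits: B, num_classes, H, W
    loss = criterion(images, logits, sparse_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()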
I have several questions:

How should the weights of the partial cross-entropy term and the global/local APro terms be set? For example, should they be combined with explicit weights, as in the sketch below?
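This is what I mean; weighted_apro_loss, lambda_g, and lambda_l are hypothetical names and values I made up to illustrate the question, not anything from the paper:

def weighted_apro_loss(partial, loss_global, loss_local, lambda_g=1.0, lambda_l=1.0):
    # hypothetical weighting of the three terms; lambda_g / lambda_l are my guesses,
    # not values reported in the paper
    return partial + lambda_g * loss_global + lambda_l * loss_local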
For the global APro, the deep feature is directly set to y_hat. Is this the default setting in your paper? Why should it be set to the last feature map from the segmentation network (see the sketch below for what I mean)?
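To make sure I understand: do you mean that an intermediate feature map of the network could be passed to the second global_apro call instead of the softmax output? Something like the hypothetical fragment below, where deep_feat would be an intermediate feature from my segmentation network and is not part of your code:

# hypothetical alternative inside forward (deep_feat is an intermediate feature map
# from the segmentation network; this line is NOT from the original implementation)
soft_pseudo = self.global_apro(soft_pseudo, deep_feat, img_mst_tree, zeta_g=0.05)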
Thank you in advance.