BBBBchan / CorrMatch

Official code for "CorrMatch: Label Propagation via Correlation Matching for Semi-Supervised Semantic Segmentation"

Question about some formula #3

Closed: Hugo-cell111 closed this issue 1 year ago.

Hugo-cell111 commented 1 year ago

Hi! Equation (4) is written as KL(logits(weak), logits(strong)). However, in KL(P||Q), P usually represents the real distribution, while Q represents the optimal/ideal distribution, so do we need to swap the order of the two kinds of logits? Also, in the first term of formula (9), does the pseudo-label of the weakly augmented branch serve as the pseudo-supervision for its own correlation map? Thanks!
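For reference, under the usual convention the first argument of the Kullback-Leibler divergence is the reference (target) distribution $P$ and the second is the distribution $Q$ being fit to it, with the sum running over classes $c$:

$$
D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{c} P(c) \log \frac{P(c)}{Q(c)}
$$

Read this way, KL(weak, strong) in Eq. (4) treats the weak-view prediction as the target $P$ and the strong-view prediction as $Q$. In the code quoted in the reply, both are softmax probabilities rather than raw logits.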

BBBBchan commented 1 year ago

Thanks!

  1. Sorry for the confusion. We do not actually swap the order of the two kinds of logits: the KL loss is computed from logits(strong) to logits(weak).

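            # Target: softmax of the weak-view logits, detached so no gradient flows through it
            softmax_pred_u_w = F.softmax(pred_u_w.detach(), dim=1)
            # Input: log-softmax of the strong-view logits
            logsoftmax_pred_u_s1 = F.log_softmax(pred_u_s1, dim=1)

            # Per-class KL terms, summed over the class dimension and masked by conf_fliter_u_w
            loss_u_kl_sa2wa = criterion_kl(logsoftmax_pred_u_s1, softmax_pred_u_w)
            loss_u_kl_sa2wa = torch.sum(loss_u_kl_sa2wa, dim=1) * conf_fliter_u_w
            # Normalize by the number of pixels not ignored (label 255) in ignore_mask_cutmixed1
            loss_u_kl_sa2wa = torch.sum(loss_u_kl_sa2wa) / torch.sum(ignore_mask_cutmixed1 != 255).item()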
  2. The pseudo-label $p^w$ from the weakly augmented branch serves as the supervision for the correlation branch.
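            # Per-pixel loss of the correlation-branch prediction against the weak-branch pseudo-label
            loss_u_corr_w = criterion_u(pred_u_w_corr, mask_u_w)
            # Keep only pixels that pass the confidence threshold and are not ignored
            loss_u_corr_w = loss_u_corr_w * ((conf_u_w >= thresh_global) & (ignore_mask != 255))
            # Normalize by the number of non-ignored pixels
            loss_u_corr_w = torch.sum(loss_u_corr_w) / torch.sum(ignore_mask != 255).item()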
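The direction of the KL term in item 1 can be checked numerically. This is a minimal sketch that assumes `criterion_kl` is `torch.nn.KLDivLoss(reduction="none")`, since the reply does not show its definition. That loss takes log-probabilities as `input` and probabilities as `target` and computes `target * (log(target) - input)` pointwise, so passing the strong-view log-softmax as input and the detached weak-view softmax as target should give KL(weak || strong). The helper names and tensor shapes below are illustrative, not taken from the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumption: criterion_kl is an unreduced KL-divergence loss (its definition is not shown in the reply).
criterion_kl = nn.KLDivLoss(reduction="none")


def kl_like_snippet(pred_u_w: torch.Tensor, pred_u_s1: torch.Tensor) -> torch.Tensor:
    """Per-pixel KL term built as in item 1: (N, C, H, W) logits -> (N, H, W)."""
    softmax_pred_u_w = F.softmax(pred_u_w.detach(), dim=1)  # target P: weak view, no gradient
    logsoftmax_pred_u_s1 = F.log_softmax(pred_u_s1, dim=1)  # input log Q: strong view
    return torch.sum(criterion_kl(logsoftmax_pred_u_s1, softmax_pred_u_w), dim=1)


def kl_manual(pred_p: torch.Tensor, pred_q: torch.Tensor) -> torch.Tensor:
    """Reference KL(P || Q) = sum_c P * (log P - log Q), with P, Q the softmax of the given logits."""
    log_p = F.log_softmax(pred_p, dim=1)
    log_q = F.log_softmax(pred_q, dim=1)
    return torch.sum(log_p.exp() * (log_p - log_q), dim=1)


torch.manual_seed(0)
pred_u_w = torch.randn(2, 4, 3, 3)  # illustrative logits: batch 2, 4 classes, 3x3 pixels
pred_u_s1 = torch.randn(2, 4, 3, 3, requires_grad=True)

kl = kl_like_snippet(pred_u_w, pred_u_s1)
print("matches KL(weak || strong):", torch.allclose(kl, kl_manual(pred_u_w, pred_u_s1), atol=1e-5))
print("matches KL(strong || weak):", torch.allclose(kl, kl_manual(pred_u_s1, pred_u_w), atol=1e-5))

kl.sum().backward()  # gradient reaches only the strong-view logits
```

Because the weak-view target is detached, the loss pulls the strong-view prediction toward the weak-view one and not the other way around.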

You can refer to our code for more details.
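The masking and normalization in item 2 can be sketched as a standalone function. `masked_corr_loss` is a hypothetical helper, not part of the repository, and the sketch assumes that `criterion_u` is `torch.nn.CrossEntropyLoss(reduction="none")`, that `mask_u_w` and `conf_u_w` are the argmax and max of the weak-branch softmax, and that 255 marks ignored pixels. The shapes and threshold values are illustrative only.

```python
import torch
import torch.nn as nn

# Assumption: criterion_u is an unreduced per-pixel cross-entropy.
criterion_u = nn.CrossEntropyLoss(reduction="none")


def masked_corr_loss(pred_u_w_corr, mask_u_w, conf_u_w, ignore_mask, thresh_global):
    """Hypothetical helper mirroring the three lines of item 2.

    pred_u_w_corr: (N, C, H, W) correlation-branch logits
    mask_u_w:      (N, H, W) pseudo-labels from the weak branch (class indices)
    conf_u_w:      (N, H, W) confidence of those pseudo-labels
    ignore_mask:   (N, H, W), where 255 marks ignored pixels
    """
    loss = criterion_u(pred_u_w_corr, mask_u_w)  # (N, H, W) per-pixel loss
    keep = (conf_u_w >= thresh_global) & (ignore_mask != 255)  # confident and not ignored
    loss = loss * keep  # bool mask is promoted to float; dropped pixels become 0
    # Denominator counts every non-ignored pixel, including low-confidence ones
    return torch.sum(loss) / torch.sum(ignore_mask != 255).item()


torch.manual_seed(0)
pred_u_w = torch.randn(2, 4, 3, 3)  # illustrative weak-branch logits
conf_u_w, mask_u_w = pred_u_w.softmax(dim=1).max(dim=1)  # assumed: confidence = max softmax prob
pred_u_w_corr = torch.randn(2, 4, 3, 3)  # illustrative correlation-branch logits
ignore_mask = torch.zeros(2, 3, 3, dtype=torch.long)
ignore_mask[:, 0, :] = 255  # ignore the first row of every image

loss_u_corr_w = masked_corr_loss(pred_u_w_corr, mask_u_w, conf_u_w, ignore_mask, thresh_global=0.5)
print(loss_u_corr_w)
```

Since the denominator includes pixels rejected by the confidence threshold, raising `thresh_global` can only lower this loss (the per-pixel cross-entropy is non-negative), rather than re-averaging over fewer pixels.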