Closed — Hugo-cell111 closed this issue 1 year ago
Thanks!
Sorry for the confusion. We actually do not change the order of the two kinds of logits: the KL loss is computed from logits(strong) to logits(weak), i.e., the softmax of the weak-view logits is the target distribution and the log-softmax of the strong-view logits is the input.
softmax_pred_u_w = F.softmax(pred_u_w.detach(), dim=1)      # target: weak-view probabilities (detached, no gradient)
logsoftmax_pred_u_s1 = F.log_softmax(pred_u_s1, dim=1)      # input: strong-view log-probabilities
loss_u_kl_sa2wa = criterion_kl(logsoftmax_pred_u_s1, softmax_pred_u_w)
loss_u_kl_sa2wa = torch.sum(loss_u_kl_sa2wa, dim=1) * conf_fliter_u_w                           # sum over classes, keep confident pixels
loss_u_kl_sa2wa = torch.sum(loss_u_kl_sa2wa) / torch.sum(ignore_mask_cutmixed1 != 255).item()   # average over valid pixels
loss_u_corr_w = criterion_u(pred_u_w_corr, mask_u_w)        # correlation-map prediction supervised by the weak-view pseudo label
loss_u_corr_w = loss_u_corr_w * ((conf_u_w >= thresh_global) & (ignore_mask != 255))
loss_u_corr_w = torch.sum(loss_u_corr_w) / torch.sum(ignore_mask != 255).item()
You can refer to our code for more details.
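For readers without the repository at hand, the KL term above can be reproduced in a self-contained sketch. All shapes and values here (batch/class/spatial sizes, random logits, and the 0.95 confidence threshold) are illustrative assumptions, not the authors' exact training code:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 4, 8, 8                       # assumed batch/class/spatial sizes
pred_u_w = torch.randn(B, C, H, W)            # weak-view logits
pred_u_s1 = torch.randn(B, C, H, W)           # strong-view logits

criterion_kl = torch.nn.KLDivLoss(reduction='none')

# Target = softmax of the detached weak logits; input = log-softmax of the
# strong logits, so gradients flow only through the strong branch.
softmax_pred_u_w = F.softmax(pred_u_w.detach(), dim=1)
logsoftmax_pred_u_s1 = F.log_softmax(pred_u_s1, dim=1)

loss_u_kl = criterion_kl(logsoftmax_pred_u_s1, softmax_pred_u_w)
loss_u_kl = torch.sum(loss_u_kl, dim=1)       # sum over classes -> per-pixel KL

# Keep only pixels where the weak view is confident (0.95 is an assumed value).
conf_u_w = softmax_pred_u_w.max(dim=1).values
conf_filter = (conf_u_w >= 0.95).float()
loss = torch.sum(loss_u_kl * conf_filter) / (B * H * W)
```

Summing the per-class terms before masking recovers the true per-pixel KL divergence, which is why the masked loss is guaranteed non-negative even though individual KLDivLoss entries can be negative.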
Hi! Equation (4) is written as KL(logits(weak), logits(strong)), but in KL(P||Q), P usually represents the true distribution while Q represents the approximating distribution, so shouldn't the order of the two kinds of logits be swapped? Also, in the first term of formula (9), does the pseudo label of the weakly augmented branch serve as pseudo supervision for its own correlation map? Thanks!