microsoft / ProDA

Prototypical Pseudo Label Denoising and Target Structure Learning for Domain Adaptive Semantic Segmentation (CVPR 2021)
https://arxiv.org/abs/2101.10979
MIT License

The kl_div loss of self distillation #43

Open · luyvlei opened this issue 2 years ago

luyvlei commented 2 years ago

The following code computes the kl_div loss between the stage-1 teacher and the student model, but the student output is only passed through softmax, not log_softmax, even though F.kl_div expects log-probabilities as its first argument. Is this a mistake?

    student = F.softmax(target_out['out'], dim=1)
    with torch.no_grad():
        teacher_out = self.teacher_DP(target_imageS)
        teacher_out['out'] = F.interpolate(teacher_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        teacher = F.softmax(teacher_out['out'], dim=1)

    loss_kd = F.kl_div(student, teacher, reduction='none')
    mask = (teacher != 250).float()
    loss_kd = (loss_kd * mask).sum() / mask.sum()
    loss = loss + self.opt.distillation * loss_kd   
panzhang0104 commented 2 years ago

Yeah, it should be log_softmax.
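
For reference, a minimal self-contained sketch of the corrected term: F.kl_div expects log-probabilities as its first argument and probabilities as its second, so only the student side changes to log_softmax, while the masking with 250 is kept as in the snippet above. The standalone function distillation_kl and the dummy tensors below are only illustrative, not the repository's code:

    import torch
    import torch.nn.functional as F

    def distillation_kl(student_logits, teacher_logits, ignore_value=250):
        # F.kl_div expects log-probabilities for its first argument and
        # probabilities for its second, so the student goes through log_softmax.
        student = F.log_softmax(student_logits, dim=1)
        teacher = F.softmax(teacher_logits, dim=1)
        loss_kd = F.kl_div(student, teacher, reduction='none')
        # Same masking scheme as the snippet above (entries equal to 250 are ignored).
        mask = (teacher != ignore_value).float()
        return (loss_kd * mask).sum() / mask.sum()

    # Dummy logits shaped (batch, classes, H, W), e.g. 19 Cityscapes classes.
    student_logits = torch.randn(2, 19, 8, 8)
    teacher_logits = torch.randn(2, 19, 8, 8)
    print(distillation_kl(student_logits, teacher_logits))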