irfanICMLL / structure_knowledge_distillation

The official code for the paper 'Structured Knowledge Distillation for Semantic Segmentation' (CVPR 2019 oral), with extensions to other tasks.
BSD 2-Clause "Simplified" License

Hi #2

Closed · monk42 closed this issue 5 years ago

monk42 commented 5 years ago

Which part of the code is the pixel-wise distillation? Thanks!

irfanICMLL commented 5 years ago

Just insert this part in your training code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CriterionKD(nn.Module):
    '''Pixel-wise knowledge distillation loss.'''

    def __init__(self, ignore_index=255, upsample=False, use_weight=True, T=1, sp=0, pp=0):
        super(CriterionKD, self).__init__()
        self.ignore_index = ignore_index
        self.use_weight = use_weight
        self.upsample = upsample
        self.soft_p = sp  # index of the teacher (soft) output to distill from
        self.pred_p = pp  # index of the student prediction to distill to
        self.T = T        # softmax temperature
        if use_weight:
            # Per-class weights for the 19 Cityscapes classes, used by the
            # cross-entropy criterion below (defined here but not called in
            # this forward pass).
            weight = torch.FloatTensor(
                [0.8194, 0.8946, 0.9416, 1.0091, 0.9925, 0.9740, 1.0804, 1.0192, 0.8528,
                 0.9771, 0.9139, 0.9744, 1.1098, 0.8883, 1.0639, 1.2476, 1.0729, 1.1323, 1.0365])
            self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
        else:
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

        self.criterion_kd = torch.nn.KLDivLoss()

    def forward(self, preds, soft):
        # The network outputs are at 1/8 of the input resolution, so the
        # teacher soft target (and optionally the student prediction) is
        # upsampled by 8x. F.interpolate replaces the deprecated F.upsample.
        h, w = soft[self.soft_p].size(2), soft[self.soft_p].size(3)
        if self.upsample:
            scale_pred = F.interpolate(input=preds[self.pred_p], size=(h * 8, w * 8), mode='bilinear', align_corners=True)
        else:
            scale_pred = preds[self.pred_p]
        scale_soft = F.interpolate(input=soft[self.soft_p], size=(h * 8, w * 8), mode='bilinear', align_corners=True)
        # KL divergence between the temperature-softened student and teacher
        # class distributions at every pixel.
        loss2 = self.criterion_kd(F.log_softmax(scale_pred / self.T, dim=1), F.softmax(scale_soft / self.T, dim=1))
        return loss2
```
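
For completeness, a minimal sketch of how this criterion might be wired into a training step; `teacher_net`, `student_net`, `images`, and `optimizer` are placeholder names, not from this repo, and both networks are assumed to return a list of logit maps, which is what the indexing in `forward` expects:

```python
# Minimal usage sketch; teacher_net, student_net, images and optimizer are
# placeholders, not names from this repository.
criterion_kd = CriterionKD(upsample=True, use_weight=False, T=1)

with torch.no_grad():            # keep the teacher frozen
    soft = teacher_net(images)   # list of teacher logit maps
preds = student_net(images)      # list of student logit maps

loss = criterion_kd(preds, soft)  # pixel-wise distillation term
loss.backward()
optimizer.step()
```

One caveat: with the default `KLDivLoss` reduction the loss is averaged over every element; the usual Hinton-style formulation uses `reduction='batchmean'` and multiplies the result by `T ** 2`, which only matters when `T != 1`.
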
monk42 commented 5 years ago

Thanks, I will try it!