Closed monk42 closed 5 years ago
Just insert this part in your training code:
class CriterionKD(nn.Module):
    """Pixel-wise knowledge distillation loss.

    Computes a KL-divergence between the student's per-pixel class
    distribution (taken from ``preds``) and the teacher's soft targets
    (taken from ``soft``), both tempered by ``T``.

    Args:
        ignore_index: label value ignored by the auxiliary CE criterion.
        upsample: if True, upsample the student map by 8x to match the
            upsampled teacher map; otherwise the student map is assumed to
            already be at the (8*h, 8*w) target resolution.
        use_weight: if True, use fixed per-class CE weights (19 entries —
            presumably Cityscapes class weights; verify for other datasets).
        T: softmax temperature for distillation.
        sp: index into ``soft`` selecting which teacher output to use.
        pp: index into ``preds`` selecting which student output to use.
    """

    def __init__(self, ignore_index=255, upsample=False, use_weight=True, T=1, sp=0, pp=0):
        super(CriterionKD, self).__init__()
        self.ignore_index = ignore_index
        self.use_weight = use_weight
        self.upsample = upsample
        self.soft_p = sp    # index of the teacher output to distill from
        self.pred_p = pp    # index of the student output to distill
        self.T = T
        if use_weight:
            # Fixed per-class weights (19 classes; looks like Cityscapes —
            # TODO confirm before reusing on another dataset).
            weight = torch.FloatTensor(
                [0.8194, 0.8946, 0.9416, 1.0091, 0.9925, 0.9740, 1.0804, 1.0192, 0.8528,
                 0.9771, 0.9139, 0.9744, 1.1098, 0.8883, 1.0639, 1.2476, 1.0729, 1.1323, 1.0365])
            self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
        else:
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        # NOTE(review): self.criterion is constructed but never used in
        # forward(); kept for compatibility with the original snippet.
        # NOTE(review): the canonical KD formulation uses
        # KLDivLoss(reduction='batchmean') scaled by T*T; the default
        # 'mean' reduction only rescales the loss, so the original
        # behavior is preserved here.
        self.criterion_kd = torch.nn.KLDivLoss()

    def forward(self, preds, soft):
        """Return the KL distillation loss between preds[pp] and soft[sp].

        Both arguments are expected to be sequences of 4-D score maps
        (N, C, H, W); the teacher map is bilinearly upsampled 8x before
        the tempered-softmax comparison.
        """
        h, w = soft[self.soft_p].size(2), soft[self.soft_p].size(3)
        if self.upsample:
            # F.interpolate replaces the deprecated F.upsample; output is
            # identical for mode='bilinear' with align_corners=True.
            scale_pred = F.interpolate(input=preds[self.pred_p], size=(h * 8, w * 8),
                                       mode='bilinear', align_corners=True)
        else:
            scale_pred = preds[self.pred_p]
        scale_soft = F.interpolate(input=soft[self.soft_p], size=(h * 8, w * 8),
                                   mode='bilinear', align_corners=True)
        loss2 = self.criterion_kd(F.log_softmax(scale_pred / self.T, dim=1),
                                  F.softmax(scale_soft / self.T, dim=1))
        return loss2
Thanks, I will try it!
Which part of the code implements the pixel-wise distillation? Thanks!