irfanICMLL / structure_knowledge_distillation

The official code for the paper 'Structured Knowledge Distillation for Semantic Segmentation' (CVPR 2019 oral), with extensions to other tasks.
BSD 2-Clause "Simplified" License
694 stars, 104 forks

about pair-wise and pixel-wise distillation for FCOS #27

Closed: Haiyang21 closed this issue 4 years ago

Haiyang21 commented 4 years ago

Hi, I am trying to reproduce the FCOS results, but the details are not given in your paper. Could you explain them in detail?

irfanICMLL commented 4 years ago

I added the 'pairwise' distillation term on all levels of the FPN features, which works well. In addition, the 'pixelwise' term is applied to the cls-logits.
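A pair-wise term on the FPN features can be sketched as follows. This is a minimal sketch, assuming the cosine-similarity affinity between spatial locations described in the paper and a mean squared difference between student and teacher affinity matrices; it is not the repository's CriterionSDcos_no, and the function names `pairwise_similarity`, `pairwise_loss` and `fpn_pairwise_loss` are hypothetical. The optional adaptive max pooling (`pool_size`) is also an assumption of this sketch, added because an H*W x H*W matrix on a high-resolution FPN level gets very large.

```python
import torch
import torch.nn.functional as F


def pairwise_similarity(feat):
    """Hypothetical helper: cosine similarity between every pair of locations.

    feat: (N, C, H, W) -> affinity matrix of shape (N, H*W, H*W).
    """
    n, c, h, w = feat.shape
    flat = feat.reshape(n, c, h * w)
    flat = F.normalize(flat, p=2, dim=1)  # unit-length feature vector per location
    return torch.bmm(flat.transpose(1, 2), flat)


def pairwise_loss(feat_s, feat_t, pool_size=None):
    """Hypothetical helper: mean squared difference of the two affinity matrices."""
    if pool_size is not None:
        # Assumption: pool to a pool_size x pool_size grid first to bound memory.
        feat_s = F.adaptive_max_pool2d(feat_s, pool_size)
        feat_t = F.adaptive_max_pool2d(feat_t, pool_size)
    sim_s = pairwise_similarity(feat_s)
    sim_t = pairwise_similarity(feat_t.detach())  # the teacher only supplies targets
    return ((sim_s - sim_t) ** 2).mean()


def fpn_pairwise_loss(fpn_s, fpn_t, pool_size=None):
    """Hypothetical helper: sum the pair-wise term over the FPN levels."""
    return sum(pairwise_loss(s, t, pool_size) for s, t in zip(fpn_s, fpn_t))


# Usage: student and teacher may differ in channel count, not in spatial size.
fpn_t = [torch.randn(2, 256, s, s) for s in (16, 8, 4)]
fpn_s = [torch.randn(2, 128, s, s, requires_grad=True) for s in (16, 8, 4)]
loss = fpn_pairwise_loss(fpn_s, fpn_t)
loss.backward()
```

Because each affinity matrix is (H*W) x (H*W) whatever the channel count, this kind of term needs no adapter layer between a narrow student and a wide teacher; only the spatial sizes of corresponding levels must match.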

The initial loss looks like this:

2019-07-06 17:46:22,202 maskrcnn_benchmark.trainer INFO: eta: 1 day, 14:34:09 iter: 20 loss: 3.9805 (4.8004) kd_logits: 0.0289 (0.0297) loss_centerness: 0.6665 (0.6702) loss_cls: 1.0147 (1.0264) loss_reg: 1.9654 (2.7365) sd_fea: 0.3358 (0.3376) time: 1.4797 (1.5431) data: 0.0300 (0.0717) lr: 0.003333 max mem: 19266

This class should help you add the distillation terms to the FCOS code.

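import torch
import torch.nn as nn
from maskrcnn_benchmark.modeling.detector.generalized_rcnn import GeneralizedRCNN

# CriterionKD, CriterionSDcos_no and CriterionMSE are this repository's loss
# modules (their import path is not given here). GeneralizedRCNN must be a
# modified version whose forward() returns (dict of intermediate outputs, losses),
# where the dict holds 'cls_logits' and 'fpn_features' lists.


class Sd_model(nn.Module):
    def __init__(self, s_cfg, t_cfg):
        super(Sd_model, self).__init__()
        self.teacher = GeneralizedRCNN(t_cfg)
        self.student = GeneralizedRCNN(s_cfg)
        # Freeze the teacher: it only supplies distillation targets. Note that
        # nn.Module.train() recurses into children, so if the trainer calls
        # model.train() the teacher is switched back to training mode; its
        # parameters stay frozen either way.
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False
        # Custom config key selecting which terms to compute ('KD', 'SD', 'mimic').
        self.sd_mode = s_cfg.MODEL.SD_MODE
        self.kd_loss = CriterionKD()        # pixel-wise term on cls_logits
        self.sd_loss = CriterionSDcos_no()  # pair-wise term on FPN features
        # Feature mimicking: one criterion per FPN level (FCOS uses 5 levels).
        self.mimic_loss = nn.ModuleList(
            [CriterionMSE(T_channle=t_cfg.MODEL.BACKBONE.OUT_CHANNELS,
                          S_channel=s_cfg.MODEL.BACKBONE.OUT_CHANNELS)
             for _ in range(5)])

    def forward(self, images, targets=None):
        if self.training:
            # Teacher forward pass without building an autograd graph.
            with torch.no_grad():
                sd_teacher, _ = self.teacher(images, targets)
            sd_student, losses = self.student(images, targets)
            sd_loss = {}

            if 'KD' in self.sd_mode:
                # Pixel-wise distillation, summed over the cls_logits of every level.
                logits_t = sd_teacher['cls_logits']
                logits_s = sd_student['cls_logits']
                loss = 0
                for i in range(len(logits_s)):
                    loss = loss + self.kd_loss(logits_s[i], logits_t[i])
                sd_loss.update({'kd_logits': loss})

            if 'SD' in self.sd_mode:
                # Pair-wise distillation on the FPN features; this loop starts
                # at index 1, so the first level is skipped.
                fea_t = sd_teacher['fpn_features']
                fea_s = sd_student['fpn_features']
                loss = 0
                for i in range(1, len(fea_s)):
                    loss = loss + self.sd_loss(fea_s[i], fea_t[i])
                sd_loss.update({'sd_fea': loss})

            if 'mimic' in self.sd_mode:
                # Feature mimicking with a separate criterion per FPN level.
                fea_t = sd_teacher['fpn_features']
                fea_s = sd_student['fpn_features']
                loss = 0
                for i in range(len(fea_s)):
                    loss = loss + self.mimic_loss[i](fea_s[i], fea_t[i])
                sd_loss.update({'mimic_fea': loss})

            # Merge the distillation terms into the detection-loss dict.
            losses.update(sd_loss)
            return losses

        return self.student(images)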
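The wrapper pattern in the class above (frozen teacher, distillation terms merged into the student's loss dict) can be tried without maskrcnn_benchmark using a toy model. In this minimal sketch, `ToyDetector`, its `with_outputs` flag and `DistillWrapper` are hypothetical stand-ins, and a plain MSE on the logits replaces the repository's criteria. Unlike the class above, it overrides `train()` so the teacher stays in eval mode, which is why the toy detector needs the flag to return its intermediate outputs in eval mode.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDetector(nn.Module):
    """Hypothetical stand-in for the modified GeneralizedRCNN."""

    def __init__(self, channels, num_classes=4):
        super().__init__()
        self.body = nn.Conv2d(3, channels, 3, padding=1)
        self.norm = nn.BatchNorm2d(channels)
        self.cls_logits = nn.Conv2d(channels, num_classes, 1)

    def forward(self, images, targets=None, with_outputs=False):
        fea = self.norm(self.body(images))
        logits = self.cls_logits(fea)
        if self.training or with_outputs:
            outputs = {'fpn_features': [fea], 'cls_logits': [logits]}
            # Placeholder detection loss; a real detector computes it from targets.
            return outputs, {'loss_cls': logits.pow(2).mean()}
        return logits


class DistillWrapper(nn.Module):
    """Hypothetical wrapper: frozen teacher, trainable student."""

    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.teacher.eval()

    def train(self, mode=True):
        # nn.Module.train() recurses into children, so re-freeze the teacher.
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, images, targets=None):
        if not self.training:
            return self.student(images)
        with torch.no_grad():
            out_t, _ = self.teacher(images, targets, with_outputs=True)
        out_s, losses = self.student(images, targets)
        # Plain MSE on the logits stands in for the repository's criteria.
        kd = sum(F.mse_loss(s, t) for s, t in zip(out_s['cls_logits'], out_t['cls_logits']))
        losses.update({'kd_logits': kd})
        return losses


model = DistillWrapper(ToyDetector(8), ToyDetector(16)).train()
loss_dict = model(torch.randn(2, 3, 16, 16))
total = sum(loss_dict.values())  # every entry, detection or distillation, is summed
total.backward()
```

As far as I know, the maskrcnn_benchmark trainer reduces the loss dict by summing its values in the same way, which is why kd_logits and sd_fea show up in the training log next to the detection losses.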
Haiyang21 commented 4 years ago

Thanks for your reply. I have some questions:

  1. Does this mean that kd_loss can be replaced with CriterionPixelWise, and mimic_loss with CriterionPairWiseforWholeFeatAfterPool, from the code on your master branch?
  2. What loss weights do you use for the pixel-wise and pair-wise losses?
  3. How should the pixel-wise loss be handled for a model with only one class, since the log-softmax is then always 0?
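The single-class problem in question 3 can be reproduced with a minimal sketch. It assumes the pixel-wise loss is a per-location KL divergence between the teacher's and the student's softmax class distributions, averaged over locations; this is not the repository's CriterionPixelWise, and `pixelwise_kd_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def pixelwise_kd_loss(logits_s, logits_t):
    """Hypothetical helper: KL(teacher || student) over the class channel at
    every location, averaged over batch and spatial positions. Inputs: (N, C, H, W)."""
    log_p_s = F.log_softmax(logits_s, dim=1)
    log_p_t = F.log_softmax(logits_t.detach(), dim=1)
    kl_per_location = (log_p_t.exp() * (log_p_t - log_p_s)).sum(dim=1)  # (N, H, W)
    return kl_per_location.mean()


# Several classes: the loss is positive whenever the distributions differ.
multi = pixelwise_kd_loss(torch.randn(2, 5, 8, 8), torch.randn(2, 5, 8, 8))

# One class: softmax over a single channel is always 1, so the log-softmax is 0
# everywhere and the loss is 0 whatever the logits are.
single_s, single_t = torch.randn(2, 1, 8, 8), torch.randn(2, 1, 8, 8)
zero_log_softmax = F.log_softmax(single_s, dim=1)
single = pixelwise_kd_loss(single_s, single_t)
```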
irfanICMLL commented 4 years ago
  1. Yes.
  2. 1 for both.
  3. You can replace it with MSE. In our implementation, however, the log-softmax of 'cls_logits' is not zero.
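The MSE replacement and the weight of 1 can be sketched together. This assumes the MSE is taken between sigmoid scores, since FCOS's classifier is sigmoid-based; the reply does not say whether to compare logits or probabilities. `pixel_mse_loss` and `total_loss` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def pixel_mse_loss(logits_s, logits_t):
    """Hypothetical helper: MSE between student and teacher sigmoid scores,
    which stays informative with a single class (C == 1)."""
    return F.mse_loss(torch.sigmoid(logits_s), torch.sigmoid(logits_t.detach()))


def total_loss(det_losses, kd_terms, weights=None):
    """Hypothetical helper: detection losses plus weighted distillation terms.
    Any term missing from `weights` gets weight 1."""
    weights = weights or {}
    total = sum(det_losses.values())
    for name, value in kd_terms.items():
        total = total + weights.get(name, 1.0) * value
    return total


# Single-class logits: the loss responds to differences in the scores.
s = torch.zeros(2, 1, 4, 4)         # sigmoid -> 0.5 everywhere
t = torch.full((2, 1, 4, 4), 100.0)  # sigmoid -> 1.0 everywhere
kd = pixel_mse_loss(s, t)
```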
Fly-dream12 commented 4 years ago

Thanks for your code. I have some questions:

  1. In your reply above, you said that kd_loss can be replaced with CriterionPixelWise and mimic_loss with CriterionPairWiseforWholeFeatAfterPool. I'm confused about which losses should be applied in the detection code: is CriterionPairWiseforWholeFeatAfterPool the only one used for feature mimicking, with the cls-logits loss added on top?
  2. Also, what are self.sd_loss and self.sd_mode? Could you post your cfg file, if convenient?

Thanks!
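Judging only from the Sd_model class posted above, SD_MODE just has to support Python's `in` test: the substrings 'KD', 'SD' and 'mimic' switch on the kd_logits, sd_fea and mimic_fea terms respectively. The hypothetical helper below mirrors those three checks; the value 'KD_SD' is only an example, since the actual cfg value is not given in the thread. Such a value would be consistent with the initial training log, which shows kd_logits and sd_fea but no mimic_fea.

```python
def enabled_terms(sd_mode):
    """Hypothetical helper: the loss entries Sd_model.forward would add for a
    given SD_MODE, mirroring its three `in` checks (in the same order)."""
    checks = [('KD', 'kd_logits'), ('SD', 'sd_fea'), ('mimic', 'mimic_fea')]
    return [name for key, name in checks if key in sd_mode]


example = enabled_terms('KD_SD')  # a string works; so would a list like ['KD', 'SD']
```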