lhoyer / DAFormer

[CVPR22] Official Implementation of DAFormer: Improving Network Architectures and Training Strategies for Domain-Adaptive Semantic Segmentation

feat_loss.backward() and Thing-Class ImageNet Feature Distance (FD) #37

Closed · creater-zq closed 2 years ago

creater-zq commented 2 years ago

Hello!

I have a question about feat_loss.backward() and the Thing-Class ImageNet Feature Distance (FD). In the code below, I have marked the lines I do not understand with # ?: in particular, why are the ImageNet features detached, and what exactly goes into f1 and f2 in masked_feat_dist? Thanks!

            feat_loss, feat_log = self.calc_feat_dist(img, gt_semantic_seg,
                                                      src_feat)
            feat_loss.backward()
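
For reference, my understanding is that the FD loss is backpropagated on its own and its gradients simply accumulate with those of the other losses before the optimizer step. A minimal standalone illustration of that accumulation pattern (toy parameter and losses, my own assumption about the intent):

    import torch

    # One toy parameter; two losses backpropagated separately, the way
    # loss.backward() and feat_loss.backward() are called one after another.
    w = torch.zeros(1, requires_grad=True)
    loss_a = (w - 1).pow(2).sum()
    loss_b = (w - 3).pow(2).sum()
    loss_a.backward()
    loss_b.backward()  # gradients accumulate in w.grad
    print(w.grad)      # tensor([-8.]) = 2*(0-1) + 2*(0-3)

And this is the FD code I am asking about: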
    def calc_feat_dist(self, img, gt, feat=None):
        assert self.enable_fdist
        with torch.no_grad():
            self.get_imnet_model().eval()
            feat_imnet = self.get_imnet_model().extract_feat(img)
            feat_imnet = [f.detach() for f in feat_imnet]  # ?
        lay = -1
        if self.fdist_classes is not None:
            fdclasses = torch.tensor(self.fdist_classes, device=gt.device)
            scale_factor = gt.shape[-1] // feat[lay].shape[-1]
            gt_rescaled = downscale_label_ratio(gt, scale_factor,
                                                self.fdist_scale_min_ratio,
                                                self.num_classes,
                                                255).long().detach()
            fdist_mask = torch.any(gt_rescaled[..., None] == fdclasses, -1)
            # ?
            feat_dist = self.masked_feat_dist(feat[lay], feat_imnet[lay],
                                              fdist_mask)
        else:
            feat_dist = self.masked_feat_dist(feat[lay], feat_imnet[lay])
        feat_dist = self.fdist_lambda * feat_dist
        feat_loss, feat_log = self._parse_losses(
            {'loss_imnet_feat_dist': feat_dist})
        feat_log.pop('loss', None)
        return feat_loss, feat_log

    def masked_feat_dist(self, f1, f2, mask=None):
        feat_diff = f1 - f2
        pw_feat_dist = torch.norm(feat_diff, dim=1, p=2)  # f1, f2?
        if mask is not None:
            pw_feat_dist = pw_feat_dist[mask.squeeze(1)]
        return torch.mean(pw_feat_dist)
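
To make the masked_feat_dist question concrete, here is my toy reading of the shapes involved (dummy sizes, my own assumptions rather than the repo's real configuration):

    import torch

    # Hypothetical bottleneck features: (batch, channels, h, w).
    f1 = torch.randn(2, 512, 16, 16)  # student features?
    f2 = torch.randn(2, 512, 16, 16)  # ImageNet model features?

    # Channel-wise L2 norm gives one distance per pixel: (2, 16, 16).
    pw_feat_dist = torch.norm(f1 - f2, dim=1, p=2)

    # Boolean thing-class mask (batch, 1, h, w); squeezing the channel
    # dim and boolean indexing keep only the masked pixel distances.
    mask = torch.rand(2, 1, 16, 16) > 0.5
    masked = pw_feat_dist[mask.squeeze(1)]
    print(pw_feat_dist.shape, masked.shape)  # (2, 16, 16) and a 1D tensor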
lhoyer commented 2 years ago

Yes, the features of the ImageNet model are intentionally detached from the compute graph so that the ImageNet model is not updated. We want to calculate the distance of the student features to the features of the original/frozen ImageNet model.
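
As a minimal sketch of what the detach achieves (made-up tensors standing in for one feature level, not the actual DAFormer tensors):

    import torch

    # The student features come from the trainable encoder, so they
    # require grad; the ImageNet features are detached as in
    # calc_feat_dist, so no gradient can reach the frozen model.
    student_feat = torch.randn(2, 8, 4, 4, requires_grad=True)
    imnet_feat = torch.randn(2, 8, 4, 4).detach()

    # Per-pixel L2 distance over channels, then the mean, mirroring
    # masked_feat_dist without the thing-class mask.
    feat_dist = torch.norm(student_feat - imnet_feat, dim=1, p=2).mean()
    feat_dist.backward()

    print(student_feat.grad is not None)  # True: the student is updated
    print(imnet_feat.grad)                # None: the frozen model is not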

creater-zq commented 2 years ago

> Yes, the features of the ImageNet model are intentionally detached from the compute graph so that the ImageNet model is not updated. We want to calculate the distance of the student features to the features of the original/frozen ImageNet model.

Thanks!