megvii-research / mdistiller

The official implementation of [CVPR2022] Decoupled Knowledge Distillation https://arxiv.org/abs/2203.08679 and [ICCV2023] DOT: A Distillation-Oriented Trainer https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf
807 stars 123 forks source link

关于_get_gt_mask() #24

Closed wxfaaaaa closed 1 year ago

wxfaaaaa commented 2 years ago

作者您好,我不太理解您的您的_get_gt_mask()这个方法,希望您可以帮我解疑答惑! 您的代码中: def _get_gt_mask(logits, target): target = target.reshape(-1) mask = torch.zeroslike(logits).scatter(1, target.unsqueeze(1), 1).bool() return mask 但是 def scatter_(self, dim, index, src, reduce=None): scatter_的src是一个tensor,为什么您这里只写了一个1,而且我运行到这里的时候也会报错,还有为什么要target = target.reshape(-1).unsqueeze(1),这样的话target和logits的shape不就不一样了吗?还请您解答。

Zzzzz1 commented 2 years ago

这个method最终返回就是一个和logits的shape相同的mask,1代表是gt class,其余全0。

  1. target其实是每个样本的label index,因此shape是B(batch size), unsqueeze(1)目的是让target和logits有一样的number of dimensions;
  2. torch1.9版本里面,scatter的src可以是一个值的,这个如果会出现报错可能和torch的版本有关。