The official implementation of [CVPR2022] Decoupled Knowledge Distillation https://arxiv.org/abs/2203.08679 and [ICCV2023] DOT: A Distillation-Oriented Trainer https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf
作者您好,我不太理解您的您的_get_gt_mask()这个方法,希望您可以帮我解疑答惑! 您的代码中: def _get_gt_mask(logits, target): target = target.reshape(-1) mask = torch.zeroslike(logits).scatter(1, target.unsqueeze(1), 1).bool() return mask 但是 def scatter_(self, dim, index, src, reduce=None): scatter_的src是一个tensor,为什么您这里只写了一个1,而且我运行到这里的时候也会报错,还有为什么要target = target.reshape(-1).unsqueeze(1),这样的话target和logits的shape不就不一样了吗?还请您解答。