tztztztztz / eqlv2

The official implementation of Equalization Loss v1 & v2 (CVPR 2020, 2021) based on MMDetection. https://arxiv.org/abs/2012.08548 https://arxiv.org/abs/2003.05176
Apache License 2.0

Is there a syntax error in eqlv2.py? #30

Closed: xiaoche-24 closed this issue 7 months ago

xiaoche-24 commented 7 months ago

In the `expand_label` function that the arrow below points to, `target` is first created as an all-zero (N, C) matrix. Is the next line meant to set every value in `target` to 1? If so, why does it index with `gt_classes`? Isn't `gt_classes` also an (N, C) matrix? This looks like a syntax error to me.

```python
import torch
import torch.nn.functional as F


def forward(self, cls_score, label, weight=None, avg_factor=None,
            reduction_override=None, **kwargs):
    self.n_i, self.n_c = cls_score.size()

    self.gt_classes = label
    self.pred_class_logits = cls_score

    def expand_label(pred, gt_classes):
        target = pred.new_zeros(self.n_i, self.n_c)
        target[torch.arange(self.n_i), gt_classes] = 1
        return target

    target = expand_label(cls_score, label)  # <---

    pos_w, neg_w = self.get_weight(cls_score)

    weight = pos_w * target + neg_w * (1 - target)

    cls_loss = F.binary_cross_entropy_with_logits(cls_score, target,
                                                  reduction='none')
    cls_loss = torch.sum(cls_loss * weight) / self.n_i

    self.collect_grad(cls_score.detach(), target.detach(), weight.detach())

    return self.loss_weight * cls_loss
```
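Here is one possible reading of the indexing line. It assumes that `label` (passed in as `gt_classes`) is a 1-D tensor of N integer class indices rather than an (N, C) matrix, which is what pairing it with `torch.arange(self.n_i)` implies. Under that assumption, the line is valid PyTorch advanced indexing. The two index tensors are paired element-wise, so row `i` receives a 1 only in column `gt_classes[i]`. The sketch below is a standalone copy of the helper with made-up sizes and labels, checked against `torch.nn.functional.one_hot`:

```python
import torch
import torch.nn.functional as F


def expand_label(pred, gt_classes):
    # Standalone copy of the nested helper.
    # pred: (N, C) logits; gt_classes: (N,) integer class indices (assumed).
    n_i, n_c = pred.size()
    # new_zeros keeps pred's dtype and device.
    target = pred.new_zeros(n_i, n_c)
    # Advanced indexing pairs the two index tensors element-wise,
    # i.e. it sets target[i, gt_classes[i]] = 1 for each row i.
    target[torch.arange(n_i), gt_classes] = 1
    return target


# Hypothetical example: 4 samples, 5 classes, one label per sample.
cls_score = torch.randn(4, 5)
labels = torch.tensor([2, 0, 4, 2])
target = expand_label(cls_score, labels)
print(target)
# tensor([[0., 0., 1., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1.],
#         [0., 0., 1., 0., 0.]])

# Identical to an explicit one-hot encoding of the labels.
assert torch.equal(target, F.one_hot(labels, num_classes=5).float())
```

So the line writes exactly N ones, one per row. It does not fill all of `target`.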
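The rest of `forward` turns the one-hot `target` into a per-element weight. `pos_w` applies where `target` is 1 and `neg_w` where it is 0. The weighted sigmoid binary cross-entropy is then summed and divided by the number of samples. The following is a minimal sketch of just that arithmetic. It assumes constant, made-up `pos_w`/`neg_w` tensors in place of whatever `self.get_weight` returns; EQLv2 derives those from its gradient statistics, which are not reproduced here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_i, n_c = 4, 5
cls_score = torch.randn(n_i, n_c)                    # hypothetical logits
labels = torch.tensor([2, 0, 4, 2])                  # hypothetical labels
target = F.one_hot(labels, num_classes=n_c).float()  # same as expand_label

# Hypothetical placeholder weights of shape (N, C), not EQLv2's real values.
pos_w = torch.full((n_i, n_c), 1.0)
neg_w = torch.full((n_i, n_c), 0.5)

# pos_w applies to each row's positive entry, neg_w to its negatives.
weight = pos_w * target + neg_w * (1 - target)

# Element-wise sigmoid BCE, weighted, summed, then averaged over samples.
cls_loss = F.binary_cross_entropy_with_logits(cls_score, target,
                                              reduction='none')
cls_loss = torch.sum(cls_loss * weight) / n_i
print(cls_loss.item())
```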