Closed xiaoche-24 closed 10 months ago
下面箭头所指的expand_label函数里,target首先是定义了一个(N,C)的全零矩阵,下面一行是要对target里面的值全部赋值为1嘛?为什么还要索引gt_classes呢?gt_class不也是一个(N,C)矩阵嘛,这个地方感觉是个语法错误啊 ` def forward(self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, **kwargs): self.n_i, self.n_c = cls_score.size()
self.gt_classes = label self.pred_class_logits = cls_score def expand_label(pred, gt_classes): target = pred.new_zeros(self.n_i, self.n_c) target[torch.arange(self.n_i), gt_classes] = 1 return target
---> target = expand_label(cls_score, label)
pos_w, neg_w = self.get_weight(cls_score) weight = pos_w * target + neg_w * (1 - target) cls_loss = F.binary_cross_entropy_with_logits(cls_score, target, reduction='none') cls_loss = torch.sum(cls_loss * weight) / self.n_i self.collect_grad(cls_score.detach(), target.detach(), weight.detach()) return self.loss_weight * cls_loss`
下面箭头所指的expand_label函数里,target首先是定义了一个(N,C)的全零矩阵,下面一行是要对target里面的值全部赋值为1嘛?为什么还要索引gt_classes呢?gt_class不也是一个(N,C)矩阵嘛,这个地方感觉是个语法错误啊 ` def forward(self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, **kwargs): self.n_i, self.n_c = cls_score.size()
---> target = expand_label(cls_score, label)