CaptainEven / MCMOT

Real time one-stage multi-class & multi-object tracking based on anchor-free detection and ReID

Should IDs be generated jointly across all classes for a multi-class dataset? #96

Open upsx opened 2 years ago

upsx commented 2 years ago

I have looked at the MotLoss function: the re-ID loss is implemented with a classifier. If IDs are generated separately for each class in a multi-class dataset, then objects of different classes can carry the same ID, so the classifier treats two objects of different classes with the same ID as the same identity. Therefore, I think this is an error. @Even

class MotLoss(torch.nn.Module):                                                         # loss module
    def __init__(self, opt):
        super(MotLoss, self).__init__()
        self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()                 # heatmap loss: FocalLoss
        self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
            RegLoss() if opt.reg_loss == 'sl1' else None                                # reg loss: RegL1Loss
        self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
            NormRegL1Loss() if opt.norm_wh else \
                RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg               # wh loss: RegL1Loss
        self.opt = opt
        self.emb_dim = opt.reid_dim                                                     # 64 or 128
        self.nID = opt.nID                                                              # number of identities in the dataset
        self.classifier = nn.Linear(self.emb_dim, self.nID)                             # re-ID classifier: emb_dim -> nID
        if opt.id_loss == 'focal':                                                      # default id_loss is 'ce', so this branch is skipped
            torch.nn.init.normal_(self.classifier.weight, std=0.01)
            prior_prob = 0.01
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            torch.nn.init.constant_(self.classifier.bias, bias_value)
        self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)                              # re-ID loss: CrossEntropyLoss, ignoring samples with id == -1
        self.emb_scale = math.sqrt(2) * math.log(self.nID - 1)                          # scale factor for the L2-normalized embedding (as in JDE/FairMOT)
        self.s_det = nn.Parameter(-1.85 * torch.ones(1))                                # detect loss weight
        self.s_id = nn.Parameter(-1.05 * torch.ones(1))                                 # re-id loss weight      

    def forward(self, outputs, batch):                                                  # outputs: the model's head outputs; batch: {'input': imgs, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh, 'reg': reg, 'ids': ids, 'bbox': bbox_xys}
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):                                                 # opt.num_stacks = 1
            output = outputs[s]                                                         # one stack's head outputs: {'hm': hm, 'wh': wh, 'reg': reg, 'id': id}
            if not opt.mse_loss:                                                        # True when FocalLoss is used
                output['hm'] = _sigmoid(output['hm'])                                   # apply sigmoid to the heatmap

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks            # hm loss
            if opt.wh_weight > 0:                                                       # wh loss: 0.1
                wh_loss += self.crit_reg(
                    output['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:                                   # reg loss: True and 1
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'], batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:                                                       # re-id loss: 1
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]

                id_output = self.classifier(id_head).contiguous()
                if self.opt.id_loss == 'focal':
                    id_target_one_hot = id_output.new_zeros((id_head.size(0), self.nID))
                    id_target_one_hot.scatter_(1, id_target.long().view(-1, 1), 1)
                    id_loss += sigmoid_focal_loss_jit(id_output, id_target_one_hot,
                                                      alpha=0.25, gamma=2.0, reduction="sum"
                                                      ) / id_output.size(0)
                else:
                    id_loss += self.IDLoss(id_output, id_target)                        # True

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss    # detection loss = heatmap loss + wh loss + offset loss (why not make the per-head weights learnable as well?)
        if opt.multi_loss == 'uncertainty':                                                         # True: uncertainty-based weighting of det_loss and id_loss
            loss = torch.exp(-self.s_det) * det_loss + torch.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)   # learnable task weights
            loss *= 0.5
        else:
            loss = det_loss + 0.1 * id_loss                                                                         # fixed task weights

        loss_stats = {'loss': loss, 'hm_loss': hm_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss, 'id_loss': id_loss}                                 # loss: final total loss; loss_stats: per-head losses
        return loss, loss_stats                                                                                     
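
For example (a minimal sketch with hypothetical per-class IDs, not the dataset's real ones): with one shared classifier, a car and a person that were both assigned per-class ID 5 become the same classification target.

import torch
import torch.nn as nn

classifier = nn.Linear(64, 10)                 # one classifier shared by all classes, 10 IDs total
ce = nn.CrossEntropyLoss(ignore_index=-1)

emb_car = torch.randn(1, 64)                   # a car whose per-class ID is 5
emb_person = torch.randn(1, 64)                # a person whose per-class ID is also 5
targets = torch.tensor([5, 5])                 # both map to the same classifier target ...
loss = ce(classifier(torch.cat([emb_car, emb_person], dim=0)), targets)
# ... so the loss pulls a car and a person toward the same identity.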
CaptainEven commented 2 years ago

@upsx You have misunderstood. We use multiple classifiers rather than a single classifier, so different classes use different ID spaces.
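
For illustration, a minimal sketch of such a per-class classifier setup (the ID counts and the id_loss helper below are hypothetical, not this repo's actual code):

import torch
import torch.nn as nn

emb_dim = 128
nID_per_class = {0: 1000, 1: 500}              # hypothetical ID counts per object class

# One re-ID classifier per object class, so ID spaces never overlap.
classifiers = nn.ModuleDict({str(c): nn.Linear(emb_dim, n)
                             for c, n in nID_per_class.items()})
ce = nn.CrossEntropyLoss(ignore_index=-1)

def id_loss(embs, ids, cls_ids):
    # embs: (N, emb_dim) embeddings; ids: (N,) per-class IDs; cls_ids: (N,) class labels
    loss = embs.new_zeros(())
    for c in cls_ids.unique():
        m = cls_ids == c
        loss = loss + ce(classifiers[str(int(c))](embs[m]), ids[m])
    return loss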

upsx commented 2 years ago

> @upsx You have misunderstood. We use multiple classifiers rather than a single classifier, so different classes use different ID spaces.

Oh, thanks for your reply. An additional question: have you run experiments comparing the following two options to see which is better?

  1. generate IDs jointly across all classes and use a single classifier (see the sketch after this list)
  2. generate IDs separately for each class and use one classifier per class
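
A sketch of option 1 (hypothetical ID counts): per-class IDs are folded into one joint ID space by offsetting each class's IDs, and a single classifier covers the total.

nID_per_class = [1000, 500, 200]               # hypothetical per-class ID counts

offsets = [0]
for n in nID_per_class[:-1]:
    offsets.append(offsets[-1] + n)            # offsets = [0, 1000, 1500]

def joint_id(cls_id, local_id):
    return offsets[cls_id] + local_id          # globally unique classifier target

nID_total = sum(nID_per_class)                 # single classifier output size: 1700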
CaptainEven commented 2 years ago

@upsx I think it may not make much difference from the MOT point of view, but it may cause confusion if all classes share one ID space.

upsx commented 2 years ago

> @upsx I think it may not make much difference from the MOT point of view, but it may cause confusion if all classes share one ID space.

OK, thanks :sparkling_heart:. I plan to run an experiment to verify this.