IDEA-Research / detrex

detrex is a research platform for DETR-based object detection, segmentation, pose estimation and other visual recognition tasks.
https://detrex.readthedocs.io/en/latest/
Apache License 2.0

Question about two-stage topk_proposals #337

Open: nhw649 opened this issue 5 months ago


Hi, I looked at the two-stage code of Deformable-DETR and DINO, and the two compute the proposal scores differently. Which one works better?

Deformable-DETR selects the proposals with:

topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]

and the loss on the encoder outputs is:

if 'enc_outputs' in outputs:
    enc_outputs = outputs['enc_outputs']
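    # Copy the targets and map every label to 0, so the encoder proposals are
    # matched and supervised class-agnostically (the original targets are kept)
    bin_targets = copy.deepcopy(targets)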
    for bt in bin_targets:
        bt['labels'] = torch.zeros_like(bt['labels'])
    indices = self.matcher(enc_outputs, bin_targets)
    for loss in self.losses:
        if loss == 'masks':
            # Intermediate masks losses are too costly to compute, we ignore them.
            continue
        kwargs = {}
        if loss == 'labels':
            # Logging is enabled only for the last layer
            kwargs['log'] = False
        l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, indicators, **kwargs)
        l_dict = {k + f'_enc': v for k, v in l_dict.items()}
        losses.update(l_dict)
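For illustration, here is a minimal sketch of the Deformable-DETR rule on toy tensors. The shapes, logit values and `topk = 2` are assumptions chosen for the example and are not taken from either codebase. Only channel 0 of the encoder logits is ranked, and because `bin_targets` maps every ground-truth label to 0, channel 0 is the one the encoder loss trains as a class-agnostic foreground score.

```python
import copy

import torch

# Assumed toy encoder logits, shape [B=1, num_proposals=5, C=3]
enc_outputs_class = torch.tensor([[[ 2.0, -1.0, -1.0],
                                   [-1.0,  3.0, -1.0],
                                   [ 0.5,  0.0,  0.0],
                                   [ 1.0, -2.0,  4.0],
                                   [-0.5, -0.5, -0.5]]])
topk = 2  # assumed number of proposals to keep

# Deformable-DETR style: rank proposals by the channel-0 logit only
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
print(topk_proposals.tolist())  # [[0, 3]]

# Encoder targets: every ground-truth label becomes class 0 (class-agnostic)
targets = [{"labels": torch.tensor([1, 2])}]  # assumed toy ground truth
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
    bt["labels"] = torch.zeros_like(bt["labels"])
print(bin_targets[0]["labels"].tolist())  # [0, 0]
print(targets[0]["labels"].tolist())      # [1, 2], the original is untouched
```

Because `copy.deepcopy` is used, the relabeling only affects the encoder loss; the decoder losses still see the real class labels.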

DINO selects the proposals with:

topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]

and the loss on the encoder outputs is as follows, where self.two_stage_binary_cls is False:

if "enc_outputs" in outputs:
    # {"pred_logits": [B, two_stage_num_proposals, C], "pred_boxes": [B, two_stage_num_proposals, 4]}
    enc_outputs = outputs["enc_outputs"]
    if self.two_stage_binary_cls:
        # Class-agnostic variant: set every label to 0 (modifies targets in place)
        for bt in targets:
            bt["labels"] = torch.zeros_like(bt["labels"])
    indices = self.matcher(enc_outputs, targets)
    for loss in self.losses:
        l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes)
        l_dict = {k + "_enc": v for k, v in l_dict.items()}
        losses.update(l_dict)
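A similar sketch of the DINO rule, reusing the same assumed toy logits: each proposal is scored by its maximum logit over all C classes, which, with `two_stage_binary_cls=False`, are trained against the real class labels. The last lines show that the two rules can pick different proposals from identical logits. The values are illustrative only and say nothing about which rule performs better.

```python
import torch

# Same assumed toy encoder logits, shape [B=1, num_proposals=5, C=3]
enc_outputs_class = torch.tensor([[[ 2.0, -1.0, -1.0],
                                   [-1.0,  3.0, -1.0],
                                   [ 0.5,  0.0,  0.0],
                                   [ 1.0, -2.0,  4.0],
                                   [-0.5, -0.5, -0.5]]])
topk = 2  # assumed number of proposals to keep

# DINO style: score each proposal by its best class logit
scores = enc_outputs_class.max(-1)[0]  # values of the max, shape [B, num_proposals]
dino_topk = torch.topk(scores, topk, dim=1)[1]

# Deformable-DETR style, for comparison: channel 0 only
deform_topk = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]

print(scores.tolist())       # [[2.0, 3.0, 0.5, 4.0, -0.5]]
print(dino_topk.tolist())    # [[3, 1]]
print(deform_topk.tolist())  # [[0, 3]]
```

Since the sigmoid is monotonic, ranking raw logits gives the same order as ranking sigmoid probabilities in both variants. With `two_stage_binary_cls=True`, the DINO criterion above also sets every label to 0 (in place, without a copy), so its encoder targets become class-agnostic like Deformable-DETR's.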