jozhang97 / DETA

Detection Transformers with Assignment
Apache License 2.0

Conflicting label assignment results #8

Closed: ligang-cs closed this issue 1 year ago

ligang-cs commented 1 year ago

Thank you for your great work! I'm confused about the label assignment code in the function `sample_topk_per_gt()`:

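```python
# for each unique GT: its id and how many predictions were assigned to it
gt_inds2, counts = gt_inds.unique(return_counts=True)
# top-k predictions by IoU with each GT, ranked over ALL predictions
scores, pr_inds2 = iou[gt_inds2].topk(k, dim=1)
gt_inds2 = gt_inds2[:, None].repeat(1, k)
# keep the first `count` of the top-k predictions for each GT
pr_inds3 = torch.cat([pr[:c] for c, pr in zip(counts, pr_inds2)])
gt_inds3 = torch.cat([gt[:c] for c, gt in zip(counts, gt_inds2)])
```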

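From the code above, it seems that one object query can be matched to multiple ground truths, resulting in conflicting label assignments: `topk` ranks every prediction by its IoU with each ground truth and never consults `pr_inds`, so a prediction that overlaps two ground truths can be selected for both.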
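To see the conflict concretely, here is a minimal pure-Python re-creation of the selection logic in the snippet above (no PyTorch), run on a hypothetical 2-by-3 IoU matrix. The helper names `topk_indices`, `select_original`, and `conflicting_predictions` are illustrative and not part of the DETA codebase; breaking ties by lower index is an assumption, since `torch.topk` does not promise an order for ties.

```python
from collections import Counter, defaultdict


def topk_indices(row, k):
    """Indices of the k largest values in row, highest first (ties: lower index first)."""
    return sorted(range(len(row)), key=lambda j: (-row[j], j))[:k]


def select_original(pr_inds, gt_inds, iou, k):
    """Plain-Python mirror of the quoted snippet (hypothetical helper).

    For each unique GT, rank ALL predictions by IoU and keep the first
    `count` of the top k. Note that pr_inds is never consulted.
    """
    counts = Counter(gt_inds)
    pr_out, gt_out = [], []
    for gt in sorted(counts):  # torch.unique returns sorted values by default
        top = topk_indices(iou[gt], k)[: counts[gt]]
        pr_out.extend(top)
        gt_out.extend([gt] * len(top))
    return pr_out, gt_out


def conflicting_predictions(pr_out, gt_out):
    """Predictions that end up paired with more than one distinct GT."""
    gts_per_pr = defaultdict(set)
    for pr, gt in zip(pr_out, gt_out):
        gts_per_pr[pr].add(gt)
    return {pr for pr, gts in gts_per_pr.items() if len(gts) > 1}


# Hypothetical IoU matrix, shape [num_gt, num_pred]: two heavily overlapping GTs.
iou = [
    [0.90, 0.80, 0.10],  # GT 0
    [0.85, 0.70, 0.20],  # GT 1
]
# Hypothetical upstream assignment: predictions 0 and 1 -> GT 0, prediction 2 -> GT 1.
pr_inds, gt_inds = [0, 1, 2], [0, 0, 1]

pr_out, gt_out = select_original(pr_inds, gt_inds, iou, k=2)
print(pr_out, gt_out)                          # [0, 1, 0] [0, 0, 1]
print(conflicting_predictions(pr_out, gt_out))  # {0}
```

Prediction 2 was the only one assigned to GT 1, yet GT 1 receives prediction 0, which GT 0 also keeps.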

jozhang97 commented 1 year ago

Hi @ligang-cs,

Sorry for the late reply. You're right, this will sometimes assign one prediction to multiple objects. Thank you for catching this bug.

I reimplemented it:

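```python
import torch


def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
    # iou: [num_gt, num_pred]; (pr_inds[i], gt_inds[i]) are the candidate pairs.
    # Each GT keeps at most k of the predictions already paired with it,
    # so no new (prediction, GT) pair is ever created.
    if len(gt_inds) == 0:
        return pr_inds, gt_inds
    pr_inds3, gt_inds3 = [], []
    for gt in gt_inds.unique():
        pr_for_gt = pr_inds[gt_inds == gt]  # predictions paired with this GT
        if len(pr_for_gt) <= k:
            # k or fewer candidates: keep all of them
            pr_inds3.append(pr_for_gt)
            gt_inds3.append(gt_inds[gt_inds == gt])
            continue
        # more than k candidates: keep the k with the highest IoU to this GT
        _, inds = iou[gt, pr_for_gt].topk(k, dim=0)
        pr_inds3.append(pr_for_gt[inds])
        gt_inds3.append(gt_inds[gt_inds == gt][:k])
    pr_inds3, gt_inds3 = torch.cat(pr_inds3), torch.cat(gt_inds3)
    # print(pr_inds.shape, pr_inds.unique().shape, pr_inds3.shape, pr_inds3.unique().shape)
    return pr_inds3, gt_inds3
```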

When I trained with this version, I got 50.2 AP on a 1x schedule (previously 50.1).
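As a quick sanity check of why the reimplementation above avoids the conflict, here is a minimal pure-Python mirror of it (no PyTorch) on a hypothetical 2-by-3 IoU matrix. The names `select_fixed`, `topk_indices`, and `conflicting_predictions` are illustrative, not DETA code, and tie-breaking by lower index is an assumption. The guarantee it shows holds only if the upstream assignment already maps each prediction to at most one GT, because the fix only ever keeps pairs that were already in `(pr_inds, gt_inds)`.

```python
from collections import defaultdict


def topk_indices(row, k):
    """Indices of the k largest values in row, highest first (ties: lower index first)."""
    return sorted(range(len(row)), key=lambda j: (-row[j], j))[:k]


def select_fixed(pr_inds, gt_inds, iou, k):
    """Plain-Python mirror of the reimplemented sample_topk_per_gt (hypothetical helper).

    Each GT only chooses among the predictions already assigned to it.
    """
    pr_out, gt_out = [], []
    for gt in sorted(set(gt_inds)):
        cands = [pr for pr, g in zip(pr_inds, gt_inds) if g == gt]
        if len(cands) > k:
            # rank this GT's own candidates by IoU and keep the top k
            order = topk_indices([iou[gt][pr] for pr in cands], k)
            cands = [cands[i] for i in order]
        pr_out.extend(cands)
        gt_out.extend([gt] * len(cands))
    return pr_out, gt_out


def conflicting_predictions(pr_out, gt_out):
    """Predictions that end up paired with more than one distinct GT."""
    gts_per_pr = defaultdict(set)
    for pr, gt in zip(pr_out, gt_out):
        gts_per_pr[pr].add(gt)
    return {pr for pr, gts in gts_per_pr.items() if len(gts) > 1}


# Hypothetical IoU matrix, shape [num_gt, num_pred]: two heavily overlapping GTs.
iou = [
    [0.90, 0.80, 0.10],  # GT 0
    [0.85, 0.70, 0.20],  # GT 1
]
# Hypothetical upstream assignment: predictions 0 and 1 -> GT 0, prediction 2 -> GT 1.
pr_inds, gt_inds = [0, 1, 2], [0, 0, 1]

print(select_fixed(pr_inds, gt_inds, iou, k=2))  # ([0, 1, 2], [0, 0, 1])
print(select_fixed(pr_inds, gt_inds, iou, k=1))  # ([0, 2], [0, 1])
print(conflicting_predictions(*select_fixed(pr_inds, gt_inds, iou, k=2)))  # set()
```

Restricting the `topk` to `pr_for_gt` trades the vectorized version for a Python loop over unique GTs; with only a handful of objects per image, that loop is usually short.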