Fangyi-Chen / SQR

MIT License

about Dense Query Recollection and Recurrence #12

Open. MinorityA opened this issue 4 months ago.

MinorityA commented 4 months ago

Hi, thank you for your amazing work.

I am particularly interested in the section on DQRR (Dense Query Recollection and Recurrence), which your team implemented on AdaMixer. Have you also tested it on DETR, and if so, what were the results? Any details you could share would be greatly appreciated.

Fangyi-Chen commented 4 months ago

I did not test that, but I think it will work on DETR as well. Thanks.

MinorityA commented 3 months ago

Hello! Could you please provide more information or guidance on how to implement this step correctly? I trained it on DAB-Deformable-DETR following the instructions in the paper, feeding the outputs of layer 6 back into layer 6, but when I shared the parameters of layer 6 across all layers, the evaluation AP was nearly 0.
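For reference, here is a framework-free sketch of the recurrence described in this question, under the simplifying assumption that each decoder layer is a plain function from queries to refined queries. After the normal pass through all layers, the last layer is applied again to its own output, so every extra step reuses the last layer's parameters. `decode_with_recurrence` and the toy layers are hypothetical and are not part of SQR, AdaMixer or DAB-Deformable-DETR.

```python
from typing import Callable, List, Sequence

# Hypothetical stand-in for a decoder layer: queries in, refined queries out.
Layer = Callable[[List[float]], List[float]]


def decode_with_recurrence(queries: List[float],
                           layers: Sequence[Layer],
                           num_recurrent_steps: int) -> List[List[float]]:
    """Run queries through every layer, then re-apply the last layer.

    Returns the output of every step (normal layers first, then the
    recurrent steps), e.g. so that each output can be supervised.
    """
    outputs = []
    for layer in layers:
        queries = layer(queries)
        outputs.append(queries)
    last_layer = layers[-1]  # recurrent steps share this layer's parameters
    for _ in range(num_recurrent_steps):
        queries = last_layer(queries)
        outputs.append(queries)
    return outputs
```

With six toy layers where layer k adds k, `decode_with_recurrence([0.0], layers, 1)` returns seven outputs; the last one, `[27.0]`, comes from applying layer 6 a second time to `[21.0]`.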

Fangyi-Chen commented 2 months ago

Hi, here is a draft implementation of DQRR on AdaMixer (a draft, but correct and runnable).

In this implementation, fakesetsize is only used for acceleration, so you can ignore it.
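Reading the slicing in the code below, fakesetsize seems to decide only how many query sets go through a stage in one forward pass. A minimal sketch of that chunking, assuming the stacked query tensor has setsize * batchsize rows; `group_slices` is a hypothetical helper, not a function from the repository:

```python
from typing import List, Tuple


def group_slices(setsize: int, fakesetsize: int,
                 batchsize: int) -> List[Tuple[int, int]]:
    """Return (start, end) row ranges, one per forward pass.

    The stacked query tensor has setsize * batchsize rows (setsize query
    sets per image). If setsize > fakesetsize, it is processed in chunks
    of fakesetsize sets; otherwise all sets go through in a single pass.
    """
    if setsize <= fakesetsize:
        return [(0, setsize * batchsize)]
    chunk = fakesetsize * batchsize
    num_group = setsize // fakesetsize  # same as int(setsize / fakesetsize) here
    return [(chunk * g, chunk * (g + 1)) for g in range(num_group)]
```

Because x, img_metas, gt_bboxes and gt_labels are tiled fakesetsize times (x_.repeat(fakesetsize, 1, 1, 1) and list multiplication), row r of every chunk lines up with image r % batchsize. The code then scales each group's loss by fakesetsize, and a single-pass loss by setsize.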

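When stage == self.num_stages, i.e., at the last stage, we treat it differently from the other stages: the query sets produced by the last decoder stage are passed through that same stage (bbox_head[stage-1]) once more.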

During testing, you can use the last stage only.
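To see why the last stage is special and why fakesetsize matters for speed, here is a small sketch that replays only the bookkeeping of forward_train below, counting query sets per image instead of running any model. It assumes every stage reserves exactly one output set per input set, which is what the append calls do; `dqrr_schedule` is a hypothetical name.

```python
from typing import List, Tuple


def dqrr_schedule(num_stages: int) -> List[Tuple[int, int]]:
    """Return (head_index, num_query_sets) for every training iteration.

    Mirrors the reserve lists in forward_train: each stage consumes all
    reserved query sets. Outputs of the last real stage go to a separate
    list, which the extra iteration (stage == num_stages) feeds back into
    the last stage's head.
    """
    reserve = 1       # query_xyzr_list_reserve starts with the initial queries
    reserve_last = 0  # query_xyzr_list_reserve_last starts empty
    schedule = []
    for stage in range(num_stages + 1):
        if stage == num_stages:
            # Extra iteration: the last head (index stage-1) re-processes its own outputs.
            schedule.append((stage - 1, reserve_last))
            continue
        setsize = reserve  # every reserved set is fed to this stage
        schedule.append((stage, setsize))
        # Each input set yields one output set, reserved for later stages.
        if stage == num_stages - 1:
            reserve_last += setsize
        else:
            reserve += setsize
    return schedule
```

For num_stages = 6 it gives [(0, 1), (1, 2), (2, 4), (3, 8), (4, 16), (5, 32), (5, 32)]: the number of sets doubles at each stage, and the extra iteration sends the 32 sets produced by the last stage through head 5 once more.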

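# Module-level imports this method needs (AdaMixer is built on mmdetection 2.x):
import torch
from mmdet.core import bbox_xyxy_to_cxcywh


# Method of the AdaMixer decoder (RoI head) class; `self` provides num_stages,
# bbox_head, bbox_assigner, bbox_sampler, train_cfg, stage_loss_weights and _bbox_forward.
def forward_train(self,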
                  x,
                  query_xyzr,
                  query_content,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None,
                  imgs_whwh=None,
                  gt_masks=None):

    num_imgs = len(img_metas)
    num_queries = query_xyzr.size(1)
    imgs_whwh_keep = imgs_whwh.repeat(1, num_queries, 1)
    all_stage_bbox_results = []
    all_stage_loss = {}

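    # Dense query recollection: query sets produced by every stage are kept and
    # fed, together with all earlier sets, to the next stage.
    query_xyzr_list_reserve = [query_xyzr]
    query_content_list_reserve = [query_content]
    # Outputs of the last stage, fed back into the last stage (recurrence).
    query_xyzr_list_reserve_last = []
    query_content_list_reserve_last = []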

    batchsize = len(img_metas)
    fakesetsize = 2  # 8 will reduce 16 hours; 2 will reduce 9 hours; 4 will reduce 15 hours
    x_keep = [_ for _ in x]
    img_metas_keep = img_metas.copy()
    gt_bboxes_keep = gt_bboxes.copy()
    gt_labels_keep = gt_labels.copy()
    # One extra iteration: stage == self.num_stages re-runs the last stage (index stage-1).
    for stage in range(self.num_stages+1):

        if stage == self.num_stages:  # extra recurrence step at the last stage
            query_xyzr = torch.cat(query_xyzr_list_reserve_last, dim=0)
            query_content = torch.cat(query_content_list_reserve_last, dim=0)
            setsize = int(len(query_content) / batchsize)
            if setsize > fakesetsize:
                single_stage_group_loss = []
                num_group = int(setsize / fakesetsize)

                for groupid in range(num_group):
                    query_xyzr_this_group = query_xyzr[fakesetsize * batchsize * groupid:fakesetsize * batchsize * (
                                groupid + 1)]
                    query_content_this_group = query_content[
                                               fakesetsize * batchsize * groupid:fakesetsize * batchsize * (
                                                           groupid + 1)]
                    bbox_results = self._bbox_forward(stage-1, x, query_xyzr_this_group, query_content_this_group,
                                                      img_metas)
                    # all_stage_bbox_results.append(bbox_results)
                    if gt_bboxes_ignore is None:
                        # TODO support ignore
                        gt_bboxes_ignore = [None for _ in range(num_imgs)]
                    sampling_results = []
                    cls_pred_list = bbox_results['detach_cls_score_list']
                    bboxes_list = bbox_results['detach_bboxes_list']

                    for i in range(num_imgs * fakesetsize):
                        normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                                  imgs_whwh[i])
                        assign_result = self.bbox_assigner[stage-1].assign(
                            normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                            gt_labels[i], img_metas[i])
                        sampling_result = self.bbox_sampler[stage-1].sample(
                            assign_result, bboxes_list[i], gt_bboxes[i])
                        sampling_results.append(sampling_result)
                    bbox_targets = self.bbox_head[stage-1].get_targets(
                        sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage-1],
                        True)

                    cls_score = bbox_results['cls_score']
                    decode_bbox_pred = bbox_results['decode_bbox_pred']

                    single_stage_group_loss.append(self.bbox_head[stage-1].loss(
                        cls_score.view(-1, cls_score.size(-1)),
                        decode_bbox_pred.view(-1, 4),
                        *bbox_targets,
                        imgs_whwh=imgs_whwh)
                    )

                # TODO: weight the group losses, giving the most important group the highest weight
                # TODO: multiply each loss by fakesetsize or not?  Do not forget to modify the setsize branch below
                for groupid, single_stage_single_group_loss in enumerate(single_stage_group_loss):
                    if groupid == 0:
                        for key, value in single_stage_single_group_loss.items():
                            all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                                    self.stage_loss_weights[stage-1] * fakesetsize
                    else:
                        for key, value in single_stage_single_group_loss.items():
                            all_stage_loss[f'stage{stage}_{key}'] += value * \
                                                                     self.stage_loss_weights[stage-1] * fakesetsize
            else:

                bbox_results = self._bbox_forward(stage-1, x, query_xyzr, query_content,
                                                  img_metas)
                all_stage_bbox_results.append(bbox_results)
                if gt_bboxes_ignore is None:
                    # TODO support ignore
                    gt_bboxes_ignore = [None for _ in range(num_imgs)]
                sampling_results = []
                cls_pred_list = bbox_results['detach_cls_score_list']
                bboxes_list = bbox_results['detach_bboxes_list']

                for i in range(num_imgs * setsize):
                    normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                              imgs_whwh[i])
                    assign_result = self.bbox_assigner[stage-1].assign(
                        normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                        gt_labels[i], img_metas[i])
                    sampling_result = self.bbox_sampler[stage-1].sample(
                        assign_result, bboxes_list[i], gt_bboxes[i])
                    sampling_results.append(sampling_result)
                bbox_targets = self.bbox_head[stage-1].get_targets(
                    sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage-1],
                    True)

                cls_score = bbox_results['cls_score']
                decode_bbox_pred = bbox_results['decode_bbox_pred']

                single_stage_group_loss = self.bbox_head[stage-1].loss(
                    cls_score.view(-1, cls_score.size(-1)),
                    decode_bbox_pred.view(-1, 4),
                    *bbox_targets,
                    imgs_whwh=imgs_whwh)

                # TODO: multiply each loss by setsize or not?  Do not forget to modify the fakesetsize branch above
                for key, value in single_stage_group_loss.items():
                    all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                            self.stage_loss_weights[stage-1] * setsize

            return all_stage_loss

        # Recollection: stack all query sets collected so far along the batch dim.
        query_xyzr = torch.cat(query_xyzr_list_reserve, dim=0)
        query_content = torch.cat(query_content_list_reserve, dim=0)
        setsize = int(len(query_content) / batchsize)

        if setsize > fakesetsize:
            single_stage_group_loss = []
            num_group = int(setsize / fakesetsize)

            # Tile features, metas and GTs so they line up with fakesetsize stacked query sets.
            x = [x_.repeat(fakesetsize, 1, 1, 1) for x_ in x_keep]
            img_metas = img_metas_keep * fakesetsize
            gt_bboxes = gt_bboxes_keep * fakesetsize
            gt_labels = gt_labels_keep * fakesetsize
            imgs_whwh = imgs_whwh_keep.repeat(fakesetsize, 1, 1)

            for groupid in range(num_group):
                query_xyzr_this_group = query_xyzr[fakesetsize*batchsize*groupid:fakesetsize*batchsize*(groupid+1)]
                query_content_this_group = query_content[fakesetsize*batchsize*groupid:fakesetsize*batchsize*(groupid+1)]
                bbox_results = self._bbox_forward(stage, x, query_xyzr_this_group, query_content_this_group,
                                                  img_metas)
                # all_stage_bbox_results.append(bbox_results)
                if gt_bboxes_ignore is None:
                    # TODO support ignore
                    gt_bboxes_ignore = [None for _ in range(num_imgs)]
                sampling_results = []
                cls_pred_list = bbox_results['detach_cls_score_list']
                bboxes_list = bbox_results['detach_bboxes_list']

                query_xyzr_new = bbox_results['query_xyzr'].detach()
                query_content_new = bbox_results['query_content']
                # TODO: detach the query content of noisy queries, since they are not used anyway?
                # TODO: only append important query groups, e.g. from the last layer
                if stage == self.num_stages - 1:
                    query_xyzr_list_reserve_last.append(query_xyzr_new)
                    query_content_list_reserve_last.append(query_content_new)
                else:
                    query_xyzr_list_reserve.append(query_xyzr_new)
                    query_content_list_reserve.append(query_content_new)

                for i in range(num_imgs * fakesetsize):
                    normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                              imgs_whwh[i])
                    assign_result = self.bbox_assigner[stage].assign(
                        normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                        gt_labels[i], img_metas[i])
                    sampling_result = self.bbox_sampler[stage].sample(
                        assign_result, bboxes_list[i], gt_bboxes[i])
                    sampling_results.append(sampling_result)
                bbox_targets = self.bbox_head[stage].get_targets(
                    sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
                    True)

                cls_score = bbox_results['cls_score']
                decode_bbox_pred = bbox_results['decode_bbox_pred']

                single_stage_group_loss.append(self.bbox_head[stage].loss(
                    cls_score.view(-1, cls_score.size(-1)),
                    decode_bbox_pred.view(-1, 4),
                    *bbox_targets,
                    imgs_whwh=imgs_whwh)
                )

            # TODO: weight the group losses, giving the most important group the highest weight
            # TODO: multiply each loss by fakesetsize or not?  Do not forget to modify the setsize branch below
            for groupid, single_stage_single_group_loss in enumerate(single_stage_group_loss):
                if groupid == 0:
                    for key, value in single_stage_single_group_loss.items():
                        all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                                self.stage_loss_weights[stage] * fakesetsize
                else:
                    for key, value in single_stage_single_group_loss.items():
                        all_stage_loss[f'stage{stage}_{key}'] += value * \
                                                                 self.stage_loss_weights[stage] * fakesetsize
        else:
            x = [x_.repeat(setsize, 1, 1, 1) for x_ in x_keep]
            img_metas = img_metas_keep * setsize
            gt_bboxes = gt_bboxes_keep * setsize
            gt_labels = gt_labels_keep * setsize
            imgs_whwh = imgs_whwh_keep.repeat(setsize, 1, 1)

            bbox_results = self._bbox_forward(stage, x, query_xyzr, query_content,
                                              img_metas)
            all_stage_bbox_results.append(bbox_results)
            if gt_bboxes_ignore is None:
                # TODO support ignore
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            cls_pred_list = bbox_results['detach_cls_score_list']
            bboxes_list = bbox_results['detach_bboxes_list']

            query_xyzr_new = bbox_results['query_xyzr'].detach()
            query_content_new = bbox_results['query_content']
            # TODO: detach the query content of noisy queries, since they are not used anyway?
            # TODO: only append important query groups, e.g. from the last layer
            if stage == self.num_stages - 1:
                query_xyzr_list_reserve_last.append(query_xyzr_new)
                query_content_list_reserve_last.append(query_content_new)
            else:
                query_xyzr_list_reserve.append(query_xyzr_new)
                query_content_list_reserve.append(query_content_new)

            for i in range(num_imgs * setsize):
                normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                          imgs_whwh[i])
                assign_result = self.bbox_assigner[stage].assign(
                    normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                    gt_labels[i], img_metas[i])
                sampling_result = self.bbox_sampler[stage].sample(
                    assign_result, bboxes_list[i], gt_bboxes[i])
                sampling_results.append(sampling_result)
            bbox_targets = self.bbox_head[stage].get_targets(
                sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
                True)

            cls_score = bbox_results['cls_score']
            decode_bbox_pred = bbox_results['decode_bbox_pred']

            single_stage_group_loss = self.bbox_head[stage].loss(
                cls_score.view(-1, cls_score.size(-1)),
                decode_bbox_pred.view(-1, 4),
                *bbox_targets,
                imgs_whwh=imgs_whwh)

            # TODO: multiply each loss by setsize or not?  Do not forget to modify the fakesetsize branch above
            for key, value in single_stage_group_loss.items():
                all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                        self.stage_loss_weights[stage] * setsize

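The loss bookkeeping repeated in each branch above follows one pattern: per-group loss dicts are summed under a 'stage{stage}_{key}' name after scaling by the stage weight and by fakesetsize (grouped passes) or setsize (single pass). A condensed sketch of that pattern, using plain floats instead of tensors; the helper name and the loss keys in the usage are hypothetical:

```python
from typing import Dict, List


def accumulate_stage_loss(all_stage_loss: Dict[str, float],
                          stage: int,
                          group_losses: List[Dict[str, float]],
                          stage_weight: float,
                          scale: int) -> None:
    """Sum per-group loss dicts into all_stage_loss under 'stage{stage}_{key}'.

    Each value is multiplied by the stage weight and by `scale`
    (fakesetsize for grouped passes, setsize for a single pass).
    """
    for groupid, group_loss in enumerate(group_losses):
        for key, value in group_loss.items():
            name = f'stage{stage}_{key}'
            weighted = value * stage_weight * scale
            if groupid == 0:
                all_stage_loss[name] = weighted   # first group initializes the entry
            else:
                all_stage_loss[name] += weighted  # later groups add to it
```

For example, two groups with loss_cls values 1.0 and 3.0, a stage weight of 1.0 and fakesetsize = 2 give stage2_loss_cls = 8.0.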
Fangyi-Chen commented 2 months ago

I'm not sure, but did you get the post-norm/pre-norm order right when you implemented the recurrence?
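For readers unfamiliar with the distinction, here is a minimal pure-Python sketch of the two residual orderings on a single vector, with a toy sublayer standing in for attention or the FFN (all names are hypothetical). In a post-norm block the output is already normalized; in a pre-norm block the residual stream stays unnormalized, and pre-norm stacks usually apply one final LayerNorm after the last layer. When a layer is looped onto its own output, one thing worth checking is that the repeated call takes its input from the same point relative to these norms as in the normal stack.

```python
import math
from typing import Callable, List

Vec = List[float]


def layer_norm(x: Vec, eps: float = 1e-5) -> Vec:
    """Normalize a vector to zero mean and unit variance (no affine params)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_norm_block(x: Vec, sublayer: Callable[[Vec], Vec]) -> Vec:
    """Post-norm: residual add first, then normalize: LN(x + F(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x: Vec, sublayer: Callable[[Vec], Vec]) -> Vec:
    """Pre-norm: normalize only the sublayer input: x + F(LN(x))."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]
```

With x = [1.0, 2.0, 3.0] and a sublayer that doubles its input, the post-norm output has zero mean, while the pre-norm output keeps a mean of about 2.0, so the two orderings hand very different inputs to a repeated layer.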