lyuwenyu / RT-DETR

[CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥
Apache License 2.0

Porting RT-DETR to other models #166

Open hanzifan opened 8 months ago

hanzifan commented 8 months ago

I tried replacing the Deformable-DETR in MOTR with RT-DETR, but the loss fails to converge to a low value. Here is the code:

        # Backbone & Encoder
        fpn_features, pos = self.backbone(samples.tensors.squeeze(0), samples.mask)
        features = [feature.tensors for feature in fpn_features]
        # for feature in features:
        #     print(feature.shape)
        srcs = self.encoder(features[1:])

        # Decoder
        # input projection and embedding
        (memory, spatial_shapes, level_start_index) = self._get_encoder_input(srcs)

        # prepare denoising training
        denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self._get_decoder_input(track_instances, memory, spatial_shapes, denoising_class, denoising_bbox_unact)
        track_instances.ref_points = init_ref_points_unact.squeeze(0)

        # decoder
        outputs_coord, outputs_class, hs, inter_references = self.decoder(target,
                                                                          init_ref_points_unact,
                                                                          memory,
                                                                          spatial_shapes,
                                                                          level_start_index,
                                                                          self.bbox_embed,
                                                                          self.class_embed,
                                                                          self.query_pos_head,
                                                                          attn_mask=attn_mask)

        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        out['hs'] = hs[-1]
        return out
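
(For reference, _set_aux_loss in DETR-style code is usually just the small helper below, which packages each intermediate decoder layer's predictions for auxiliary supervision; this is a generic sketch of that convention, not the exact code from this port.)

    def _set_aux_loss(self, outputs_class, outputs_coord):
        # One dict per intermediate decoder layer, so the criterion can
        # supervise every layer rather than only the final one.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]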

To align with MOTR, I also modified _get_decoder_input:

def _get_decoder_input(self,
                           track_instances,
                           memory,
                           spatial_shapes,
                           denoising_class=None,
                           denoising_bbox_unact=None):
        bs, _, _ = memory.shape
        # prepare input for decoder
        if self.training or self.eval_spatial_size is None:
            anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device)
        else:
            anchors, valid_mask = self.anchors.to(memory.device), self.valid_mask.to(memory.device)

        # memory = torch.where(valid_mask, memory, 0)
        memory = valid_mask.to(memory.dtype) * memory  # TODO fix type error for onnx export 

        output_memory = self.enc_output(memory)

        enc_outputs_class = self.enc_score_head(output_memory)
        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors

        _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
        # _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, target.shape[1], dim=1)

        # extract region features
        if self.learnt_init_query:
            target = track_instances.query_pos.unsqueeze(0)
        else:
            target = output_memory.gather(dim=1, \
                index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
            if len(track_instances[self.num_queries:]) > 0:
                target = torch.cat([target, track_instances[self.num_queries:].query_pos.unsqueeze(0)], dim=1)
            target = target.detach()

        reference_points_unact = enc_outputs_coord_unact.gather(dim=1, \
            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1]))
        if len(track_instances[self.num_queries:]) > 0:
            reference_points_unact = torch.cat([reference_points_unact, track_instances[self.num_queries:].ref_points.unsqueeze(0)], dim=1)

        enc_topk_bboxes = F.sigmoid(reference_points_unact)
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.cat(
                [denoising_bbox_unact, reference_points_unact], 1)

        enc_topk_logits = enc_outputs_class.gather(dim=1, \
            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]))

        if denoising_class is not None:
            target = torch.cat([denoising_class, target], 1)

        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
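
(For context, the _generate_anchors call at the top of this function builds, per feature level, one normalized (cx, cy, w, h) anchor per cell plus a validity mask. The sketch below paraphrases the upstream RT-DETR PyTorch logic from memory; treat the grid_size=0.05 and eps=1e-2 defaults as assumptions.)

    import torch

    def _generate_anchors(self, spatial_shapes, grid_size=0.05, eps=1e-2,
                          dtype=torch.float32, device='cpu'):
        # One normalized (cx, cy, w, h) anchor per feature-map cell, with w/h
        # doubling per pyramid level, mapped to logit (inverse-sigmoid) space
        # so the enc_bbox_head output can simply be added to it.
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(torch.arange(h, dtype=dtype),
                                            torch.arange(w, dtype=dtype), indexing='ij')
            grid_xy = torch.stack([grid_x, grid_y], dim=-1)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype)
            wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
        anchors = torch.cat(anchors, dim=1).to(device)
        # Mask out anchors whose coordinates fall too close to the image border.
        valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)
        anchors = torch.log(anchors / (1 - anchors))
        anchors = anchors.masked_fill(~valid_mask, float('inf'))
        return anchors, valid_mask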

Focal loss is used when computing the loss. This is what my loss looks like: [loss-curve screenshot]
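
(For reference, a minimal sketch of the sigmoid focal loss commonly used for the classification branch in DETR-style detectors is shown below; alpha=0.25, gamma=2.0, and the final normalization follow the usual Deformable-DETR defaults and are assumptions here, not necessarily what was used in this experiment.)

    import torch.nn.functional as F

    def sigmoid_focal_loss(logits, targets, num_boxes, alpha=0.25, gamma=2.0):
        # logits, targets: [bs, num_queries, num_classes]; targets are one-hot.
        prob = logits.sigmoid()
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = prob * targets + (1 - prob) * (1 - targets)
        loss = ce * ((1 - p_t) ** gamma)
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        # Normalize by the number of matched boxes, as in Deformable-DETR.
        return loss.mean(1).sum() / num_boxes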

hanzifan commented 8 months ago

I did not use the IoU-aware query selector; I used conventional learnable queries instead.
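
(Illustrative contrast between the two initialization strategies mentioned above; the names, shapes, and defaults below are hypothetical and not taken from either codebase.)

    import torch.nn as nn

    # (a) Conventional learnable queries: image-independent content embeddings.
    num_queries, hidden_dim = 300, 256
    tgt_embed = nn.Embedding(num_queries, hidden_dim)
    target = tgt_embed.weight.unsqueeze(0)              # [1, num_queries, hidden_dim]

    # (b) RT-DETR-style top-k query selection: initialize queries from the
    #     highest-scoring encoder tokens; with IoU-aware selection the score
    #     head is additionally trained so a high score implies a high IoU.
    def select_queries(output_memory, enc_logits, enc_boxes_unact, k=300):
        scores = enc_logits.max(-1).values                          # [bs, tokens]
        idx = scores.topk(k, dim=1).indices.unsqueeze(-1)           # [bs, k, 1]
        target = output_memory.gather(1, idx.expand(-1, -1, output_memory.shape[-1]))
        ref_unact = enc_boxes_unact.gather(1, idx.expand(-1, -1, 4))
        return target.detach(), ref_unact.detach()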

lebron-2016 commented 5 months ago

> I tried replacing the Deformable-DETR in MOTR with RT-DETR, but the loss fails to converge to a low value. […]

Hello, have you solved this yet?

lebron-2016 commented 5 months ago

> I did not use the IoU-aware query selector; I used conventional learnable queries instead.

Hello, I ran into a similar problem: convergence is much slower than with Deformable DETR and the final loss does not drop to a low value. Have you solved it?