fundamentalvision / Deformable-DETR

Deformable DETR: Deformable Transformers for End-to-End Object Detection.
Apache License 2.0

Error Separating Backbone from Deformable DETR #205

Open victoic opened 1 year ago

victoic commented 1 year ago

Hello, hope you're doing well. I have a question that's not directly within the scope of your repo, but I would appreciate any direction on what's going wrong. I'm trying to separate the backbone (ResNet-50) from the detector (Deformable DETR), so that I can later work on some modifications more easily. As a result, my backbone does not use NestedTensor, and its output features are converted to NestedTensor when they are received by the Deformable DETR. A few modifications were made on the Deformable DETR side, shown in the class below.

When passing a `(2, 3, 69, 69)` dummy input `x` (filled with ones), I get an error at `tmp[..., :2] += reference`: `RuntimeError: output with shape [100, 2] doesn't match the broadcast shape [2, 100, 2]`. I do not understand why. My features all have the expected shapes:

- 0: `torch.Size([2, 64, 69, 69])`
- 1: `torch.Size([2, 256, 69, 69])`
- 2: `torch.Size([2, 512, 35, 35])`
- 3: `torch.Size([2, 1024, 18, 18])`
- 4: `torch.Size([2, 2048, 9, 9])`
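
This in-place broadcast failure can be reproduced in isolation with tensors of the shapes named in the error message (the shapes below are inferred from the message only, not dumped from the model):

```python
import torch

# Shapes inferred from the RuntimeError above (an assumption, not a dump from the model):
tmp = torch.zeros(100, 4)          # box deltas without a batch dimension
reference = torch.rand(2, 100, 2)  # reference points with batch size 2

# An in-place add cannot grow the left-hand side to the broadcast shape [2, 100, 2], so this raises:
# RuntimeError: output with shape [100, 2] doesn't match the broadcast shape [2, 100, 2]
tmp[..., :2] += reference
```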

Below is the DeformableDETR class as modified, with docstrings and most comments removed for brevity.

```python
class DeformableDETR(BaseDetector):
    def __init__(self, channels, num_feature_levels, num_classes, num_queries,
                 aux_loss=True, with_box_refine=False, two_stage=False, **kwargs):
        args = kwargs.pop('args')
        hidden_dim = kwargs.pop('hidden_dim')
        position_embedding = kwargs.pop('position_embedding')
        super(DeformableDETR, self).__init__(**kwargs)

        self.position_encoding = build_position_encoding(hidden_dim, position_embedding)

        self.num_queries = num_queries
        transformer = build_deforamble_transformer(args=args)
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.num_feature_levels = num_feature_levels
        if not two_stage:
            self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
        if num_feature_levels > 1:
            num_backbone_outs = len(channels)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = channels[_]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(channels[0], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage = two_stage

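        # Focal-loss-style bias prior for classification; zero-init the last bbox MLP layer; Xavier-init the input projections.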
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

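        # One prediction head per decoder layer (plus one for the encoder output when two_stage is set).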
        num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
        if with_box_refine:
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            # hack implementation for iterative bounding box refinement
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None
        if two_stage:
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

    def forward(self, x: torch.Tensor, features: List[torch.Tensor], *args):
        out = []
        pos = []

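        # Wrap each backbone feature map in a NestedTensor (tensor + padding mask) and compute its positional encoding.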
        for f in features:
            nested_x = nested_tensor_from_tensor_list(f)
            mask = F.interpolate(nested_x.mask[None].float(), size=f.shape[-2:]).to(torch.bool)[0]
            out.append(NestedTensor(f, mask))
            pos.append(self.position_encoding(out[-1]).to(out[-1].tensors.dtype))

        srcs = []
        masks = []
        for l, feat in enumerate(out):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None

        query_embeds = None
        if not self.two_stage:
            query_embeds = self.query_embed.weight
        hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds)

        outputs_classes = []
        outputs_coords = []
        for lvl in range(len(features)):
            print(lvl, features[lvl].shape)
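        # Decode per decoder layer: layer 0 uses the initial reference points, later layers use the refined references from the previous layer.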
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](hs[lvl])
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                tmp[..., :2] += reference # <<< ERROR happens here
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)

        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        if self.two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}
        return out
```
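
For reference, a minimal sketch of how the forward pass is invoked in my test: the real ResNet-50 outputs are replaced here with ones-filled tensors of the shapes listed above, and model construction is omitted since it depends on my config (`model` stands for the DeformableDETR instance):

```python
import torch

x = torch.ones(2, 3, 69, 69)  # dummy input filled with ones, batch size 2

# Stand-ins for the backbone feature maps, matching the shapes printed above.
features = [
    torch.ones(2, 64, 69, 69),
    torch.ones(2, 256, 69, 69),
    torch.ones(2, 512, 35, 35),
    torch.ones(2, 1024, 18, 18),
    torch.ones(2, 2048, 9, 9),
]

out = model(x, features)  # fails at `tmp[..., :2] += reference`
```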