IDEA-Research / DAB-DETR

[ICLR 2022] Official implementation of the paper "DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR"

Some questions about reproducing DAB-DETR #2

Closed. Zx55 closed this issue 2 years ago

Zx55 commented 2 years ago

Hi, I'd like to reproduce DAB-DETR, and I have two questions about some technical details.

i. How do you initialize the learnable anchor boxes (the results in Table 2)? Why not use the setting in Table 8 (random initialization, fixing them in the first decoder layer) as the default?

ii. I am confused about the modulated positional attention in Section 4.4. Is it an improvement on the "conditional cross attention" in Conditional DETR (which splits the cross attention into two parts, content and spatial dot-products)? Does the proposed modulated positional attention add the reference width w_ref into the spatial dot-products?

SlongLiu commented 2 years ago

Thanks for your interest.

i. We randomly initialize them and set all parameters learnable by default; those are the results shown in our main table. Indeed, the results in Table 8 are better than our default setting. However, we use the all-learnable setting by default for a fair comparison with previous works, to verify our conclusion that "dynamic anchor boxes are better queries for DETR".

ii. As we acknowledge in the paper, our method is inspired by Conditional DETR, so yes, it is an improvement on the "conditional cross attention". The ablation study demonstrates the effectiveness of our modulated attention.
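
To make these answers more concrete, here are two small sketches. First, one possible way to realize "randomly initialized, fully learnable anchor boxes"; the names are illustrative and may not match the released code:

import torch.nn as nn

num_queries = 300  # illustrative value

# Hypothetical sketch: each query owns a learnable 4-d anchor (x, y, w, h).
# nn.Embedding provides random initialization, and the anchors are updated by
# the optimizer like any other parameter; they are later squashed to [0, 1]
# (e.g. via sigmoid) before being used as boxes.
anchor_embed = nn.Embedding(num_queries, 4)

Second, a rough sketch of the modulation described in Section 4.4: the x/y halves of Conditional DETR's spatial query are additionally rescaled by the predicted reference size over the current anchor size. Again, names and shapes are illustrative, not the official implementation:

import torch

def modulated_pos_query(pos_q_x, pos_q_y, ref_wh, anchor_wh):
    # pos_q_x / pos_q_y: x- and y-halves of the conditional spatial query,
    #   each of shape [num_queries, batch, d_model // 2].
    # ref_wh: reference width/height predicted from the content query, [num_queries, batch, 2].
    # anchor_wh: width/height of the current anchor box, [num_queries, batch, 2].
    scale = ref_wh / anchor_wh                    # (w_ref / w, h_ref / h)
    pos_q_x = pos_q_x * scale[..., 0:1]           # modulate the x half by w_ref / w
    pos_q_y = pos_q_y * scale[..., 1:2]           # modulate the y half by h_ref / h
    return torch.cat([pos_q_x, pos_q_y], dim=-1)  # spatial part of the cross-attention query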

Zx55 commented 2 years ago

Thanks for the reply. I'll try it.

Zx55 commented 2 years ago

Hi, I have some questions that came up while trying to reproduce DAB-DETR.

i. Are the parameters of MLP_csq and the MLP in Equation 7 all shared across decoder layers?

ii. What are MLP_csq, the MLP in Equation 7, and the MLP used to update the anchor boxes composed of? Two submodules, each a linear layer + ReLU activation?

iii. Conditional DETR applies separate linear projections to the content query/key and the spatial query/key in the SA and CA modules, respectively. Does DAB-DETR follow this design? Also, Conditional DETR adds the decoder embeddings to the content query in the first decoder layer. Does DAB-DETR add anchor_sine_encoding to content_query in the first decoder layer instead?

iv. Does DAB-DETR use multi-pattern embeddings for both the decoder embeddings and the learnable anchor boxes?

SlongLiu commented 2 years ago

Hi,

i. Yes, they are shared across layers.

ii. Yes.

iii. Yes, since the first content queries are all-zero vectors. However, it might be better to make the content queries learnable as well, in which case you don't need to add the extra anchor_sine_encoding to content_query in the first decoder layer.

iv. The multi-pattern embeddings are used for the decoder embeddings only.
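
For reference, one common reading of "a linear layer + ReLU activation" submodules is a small DETR-style MLP like the sketch below; this is illustrative only, and the exact composition should be checked against the released code:

import torch.nn as nn

class TwoLayerMLP(nn.Module):
    # Illustrative only: Linear -> ReLU -> Linear, with parameters shared
    # across decoder layers as stated in answer (i).
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)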

Zx55 commented 2 years ago

Hi, does the bbox_head in DAB-DETR predict an offset relative to the learnable reference anchor boxes, i.e. reference + offset? Or does it predict the box coordinates directly? This is not clear in the paper.

SlongLiu commented 2 years ago

The bbox_head predicts the offset. We use normalized coordinates, so each coordinate is a float in [0, 1]. We add the reference and the offset in inverse-sigmoid space and then project the result back to [0, 1] with a sigmoid, as in Deformable DETR.
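
A minimal sketch of that update rule (illustrative names, following the Deformable DETR convention described above):

import torch

def inverse_sigmoid(x, eps=1e-5):
    # Numerically stable logit: maps [0, 1] back to unbounded space.
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

def update_anchor(anchor, offset):
    # anchor: normalized (x, y, w, h) in [0, 1]; offset: raw bbox_head output.
    return (inverse_sigmoid(anchor) + offset).sigmoid()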

Zx55 commented 2 years ago

Thanks for the reply. I tried reference + offset yesterday, but it still doesn't work (mAP = 0 in testing). It seems I have misunderstood some core idea in the paper. Here is simplified PyTorch-like code; could you help me check what the problem is?

class Transformer:
    def __init__(self, ...):
        encoder_layer = EncoderLayer(...)
        self.encoder = Encoder(...)
        decoder_layer = DecoderLayer(...)
        self.decoder = Decoder(...)

    def forward(self, src, query, pos, anchor):
        src = src.flatten(2).permute(2, 0, 1)  # [HW, N, C]
        pos = pos.flatten(2).permute(2, 0, 1)
        query = query.unsqueeze(1).repeat(1, bs, 1)  # [num_queries, N, C]
        anchor = anchor.sigmoid().unsqueeze(1).repeat(1, bs, 1)  # [num_queries, N, 4]

        tgt = zeros_like(query)
        mem = self.encoder(src, pos)
        hs, reference = self.decoder(tgt, mem, pos, query, anchor)
        return hs, reference

class Decoder:
    def __init__(self, layer, num_layers, d_model):
        self.layers = get_clones(layer, num_layers)

        # quaternion (anchor sine encoding) -> sa_qpos/sa_kpos
        self.quaternion_head = MLP(d_model * 2, d_model)  
        self.csq_head = MLP(d_model, d_model)
        self.reference_size_head = MLP(d_model, 2)  # compute w_ref/h_ref
        self.anchor_update_head = MLP(d_model, 4)  # update learnable anchor

    def forward(self, tgt, mem, pos, query, anchor):
        output = tgt
        intermediate, intermediate_anchor = [], []

        for layer_id, layer in enumerate(self.layers):
            # compute PE(x, y, w, h), [num_queries, N, d_model * 2]
            anchor_sine_encoding = gen_sineembed_for_quaternion(anchor, temperature=20)  
            anchor_sa_pos = self.quaternion_head(anchor_sine_encoding)

            if layer_id == 0:  # use query as C_q (content_query) in the first decoder layer
                content_query = query
            else:
                content_query = output
            # PE(x_ref, y_ref) = csq(C_q) * PE(x, y)
            ca_qpos = self.csq_head(content_query) * anchor_sine_encoding[..., :d_model]  
            # w_ref/w & h_ref/h, [num_queries, N, 2]
            modulated_size = self.reference_size_head(content_query).sigmoid() / anchor[..., 2:]  
            # PE(x_ref) * w_ref / w & PE(y_ref) * h_ref / h
            modulated_ca_qpos = ca_qpos.view(num_queries, bs, 2, d_model // 2) * modulated_size.unsqueeze(-1)  
            output = layer(tgt, mem, pos, anchor_sa_pos, modulated_ca_qpos, is_first=(layer_id == 0))

            anchor_delta = self.anchor_update_head(output).sigmoid()  # [num_queries, bs, 4]
            anchor = anchor + anchor_delta

            ...  # update intermediate

        return stack(intermediate), stack(intermediate_anchor)

class DecoderLayer:
    def __init__(self, ...):
        ...

    def forward(self, tgt, mem, pos, anchor_sa_pos, ca_qpos, is_first):
        # self-attention, following Conditional DETR
        q_content = self.sa_qcontent_proj(tgt)
        q_pos = self.sa_qpos_proj(anchor_sa_pos)
        k_content = self.sa_kcontent_proj(tgt)
        k_pos = self.sa_kpos_proj(anchor_sa_pos)
        v = self.sa_v_proj(tgt)

        q = q_content + q_pos
        k = k_content + k_pos
        tgt = self.self_attn(q, k, value=v)[0]
        ...  # dropout and norm

        # cross-attention
        q_content = self.ca_qcontent_proj(tgt)
        k_content = self.ca_kcontent_proj(mem)
        k_pos = self.ca_kpos_proj(pos)
        v = self.ca_v_proj(mem)

        if is_first:
            # add q_pos to first decoder layer, following Conditional DETR
            q_pos = self.ca_qpos_proj(anchor_sa_pos)  
            q = q_content + q_pos
            k = k_content + k_pos
        else:
            q = q_content
            k = k_content

        q = q.view(..., nhead, d_model // nhead)
        ca_qpos = self.ca_qpos_proj(ca_qpos).view(..., nhead, d_model // nhead)
        q = cat([q, ca_qpos], dim=3).view(..., d_model * 2)  # concat ca_qcontent and ca_qpos
        ...  # concat k

        tgt = self.cross_attn(q, k, value=v)[0]
        ...  # ffn, dropout and norm
        return tgt

SlongLiu commented 2 years ago

Hello, our code is available now; it provides more details about our paper.