henghuiding / Vision-Language-Transformer

[ICCV2021 & TPAMI2023] Vision-Language Transformer and Query Generation for Referring Segmentation
MIT License
340 stars · 21 forks

Corresponding code for the Query Generation Module #10

Open · KevinGoodman opened this issue 2 years ago

Thanks for sharing the code. However, I'm quite confused by the code for the QGM, since the naming in the code differs a little from the original paper (if I understand it correctly...).

I think the code for that module is the function `lang_tf_enc` in `model/transformer_model.py`:

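```python
# Imports assumed for a standalone snippet (the repo's own import lines are not quoted here).
import keras.layers as L
from keras_multi_head import MultiHeadAttention
from keras_pos_embd import TrigPosEmbedding


def lang_tf_enc(vision_input,
                lang_input,
                head_num=8,
                hidden_dim=256):
    # Add fixed sinusoidal (trigonometric) position embeddings to both modalities.
    decoder_embed_lang = TrigPosEmbedding(
        mode=TrigPosEmbedding.MODE_ADD,
        name='Fusion-Lang-Decoder-Embedding',
    )(lang_input)
    decoder_embed_vis = TrigPosEmbedding(
        mode=TrigPosEmbedding.MODE_ADD,
        name='Fusion-Vis-Decoder-Embedding',
    )(vision_input)
    # Vision features act as queries; language features supply keys and values.
    q_inp = L.Dense(hidden_dim, activation='relu')(decoder_embed_vis)
    k_inp = L.Dense(hidden_dim, activation='relu')(decoder_embed_lang)
    v_inp = L.Dense(hidden_dim, activation='relu')(decoder_embed_lang)
    decoded_layer = MultiHeadAttention(head_num=head_num)(
        [q_inp, k_inp, v_inp])
    # Residual connection back to the (un-embedded) vision input;
    # this requires vision_input to have hidden_dim channels.
    add_layer = L.Add(name='Fusion-Add')([decoded_layer, vision_input])

    return add_layer
```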
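The structure of `lang_tf_enc` can be sketched in NumPy to make the data flow explicit: sinusoidal position embeddings added to both inputs, ReLU dense projections, multi-head attention with vision tokens as queries and language tokens as keys/values, then a residual add back to the raw vision input. This is an illustration under assumptions, not the repo's implementation: the weights are random, features are flattened to (tokens, channels), the sine/cosine formula is the standard Transformer one rather than necessarily what `TrigPosEmbedding` computes, and the learned projections inside `MultiHeadAttention` are omitted. `lang_tf_enc_sketch` and the other helper names are hypothetical.

```python
import numpy as np


def sinusoidal_pos_embedding(seq_len, dim):
    # Standard Transformer sine/cosine encoding, shape (seq_len, dim).
    pos = np.arange(seq_len)[:, None]
    i = np.arange(dim)[None, :]
    angle = pos / np.power(10000.0, (2 * (i // 2)) / dim)
    return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))


def dense_relu(x, w):
    # Bias-free dense layer followed by ReLU.
    return np.maximum(x @ w, 0.0)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(q, k, v, head_num):
    # q: (Nq, D); k, v: (Nk, D); D must be divisible by head_num.
    nq, d = q.shape
    nk = k.shape[0]
    dh = d // head_num
    qh = q.reshape(nq, head_num, dh).transpose(1, 0, 2)  # (H, Nq, dh)
    kh = k.reshape(nk, head_num, dh).transpose(1, 0, 2)  # (H, Nk, dh)
    vh = v.reshape(nk, head_num, dh).transpose(1, 0, 2)  # (H, Nk, dh)
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(dh)    # (H, Nq, Nk)
    out = softmax(scores) @ vh                           # (H, Nq, dh)
    return out.transpose(1, 0, 2).reshape(nq, d)         # concatenate heads


def lang_tf_enc_sketch(vision, lang, head_num=8, hidden_dim=256, seed=0):
    # vision: (N_vis, C_vis) flattened spatial tokens; lang: (N_lang, C_lang) word tokens.
    n_vis, c_vis = vision.shape
    n_lang, c_lang = lang.shape
    if c_vis != hidden_dim:
        raise ValueError("residual add needs vision channels == hidden_dim")
    rng = np.random.default_rng(seed)
    vis_pe = vision + sinusoidal_pos_embedding(n_vis, c_vis)
    lang_pe = lang + sinusoidal_pos_embedding(n_lang, c_lang)
    q = dense_relu(vis_pe, rng.standard_normal((c_vis, hidden_dim)) * 0.02)
    k = dense_relu(lang_pe, rng.standard_normal((c_lang, hidden_dim)) * 0.02)
    v = dense_relu(lang_pe, rng.standard_normal((c_lang, hidden_dim)) * 0.02)
    attended = multi_head_attention(q, k, v, head_num)
    return attended + vision  # residual onto the un-embedded vision input


vis = np.random.default_rng(1).standard_normal((26 * 26, 256))
lang = np.random.default_rng(2).standard_normal((15, 1024))
out = lang_tf_enc_sketch(vis, lang)
print(out.shape)  # (676, 256)
```

The output keeps the vision shape (one row per spatial location), so whatever is passed as `vision_input` is what gets enriched with language information; that is why it matters whether this input is raw backbone features or `Fm_query`.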

As Figure 4 suggests, the vision input to the QGM should be the raw visual features extracted by the vision backbone. Yet the input to this function is `Fm_query`, which has already fused vision and language features (it is built in `make_multitask_braches` of `model/vlt_model.py`):

```python
# K is assumed to be keras.backend. The helpers (simple_fusion, up_proj_cat_proj,
# pool_proj_cat_proj, proj_cat, vlt_querynet, vlt_transformer, vlt_postproc)
# and the module V are defined in the repo and not reproduced here.
import keras.backend as K


def make_multitask_braches(Fv, fq, fq_word, config):
    # fq: bs, 1024
    # fq_word: bs, 15, 1024
    Fm = simple_fusion(Fv[0], fq, config.jemb_dim)  # 13, 13, 1024

    Fm_mid_query = up_proj_cat_proj(Fm, Fv[1], K.int_shape(Fv[1])[-1], K.int_shape(Fm)[-1]//2)  # 26, 26, 512
    Fm_query = pool_proj_cat_proj(Fm_mid_query, Fv[2], K.int_shape(Fv[2])[-1], K.int_shape(Fm)[-1]//2)  # 26, 26, 512

    Fm_mid_tf = proj_cat(Fm_query, Fm_mid_query, K.int_shape(Fm)[-1]//2)  # 26, 26, 1024
    F_tf = up_proj_cat_proj(Fm, Fm_mid_tf, K.int_shape(Fm)[-1] // 2)

    F_tf = V.DarknetConv2D_BN_Leaky(config.hidden_dim, (1, 1))(F_tf)

    # Fm_query:  bs, Hm, Wm, C  (None, 26, 26, 512)
    # F_tf:  bs, Hc, Wc, C  (None, 26, 26, 512)
    query_out = vlt_querynet(Fm_query, config)
    mask_out = vlt_transformer(F_tf, fq_word, query_out, config)
    mask_out = vlt_postproc(mask_out, Fm_query, config)

    return mask_out
```
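For intuition on why `Fm_query` already carries language information: `Fm` mixes the sentence feature `fq` into the backbone map `Fv[0]`, and `Fm_query` is derived from `Fm` via `Fm_mid_query`. The sketch below uses a hypothetical tile-and-multiply fusion (`tile_and_fuse` is an illustrative name, not a repo function) as a stand-in for `simple_fusion`, whose actual internals are not quoted here. It only shows that, for the same image, changing the sentence vector changes the fused map.

```python
import numpy as np


def tile_and_fuse(vision_map, sentence_vec, w_vis, w_lang):
    """Hypothetical stand-in for simple_fusion (not the repo's code).

    vision_map: (H, W, Cv) backbone features; sentence_vec: (Cl,) sentence feature.
    """
    v = np.tanh(vision_map @ w_vis)      # project vision to a joint dim D: (H, W, D)
    s = np.tanh(sentence_vec @ w_lang)   # project sentence to D: (D,)
    return v * s[None, None, :]          # same sentence vector applied at every (h, w)


rng = np.random.default_rng(0)
Fv0 = rng.standard_normal((13, 13, 1024))  # shape taken from the 13x13x1024 comment
w_vis = rng.standard_normal((1024, 1024)) * 0.02
w_lang = rng.standard_normal((1024, 1024)) * 0.02
fq_a = rng.standard_normal(1024)
fq_b = rng.standard_normal(1024)

Fm_a = tile_and_fuse(Fv0, fq_a, w_vis, w_lang)
Fm_b = tile_and_fuse(Fv0, fq_b, w_vis, w_lang)
print(Fm_a.shape)  # (13, 13, 1024)
# Same image, different sentence: fraction of entries that changed.
print(np.mean(~np.isclose(Fm_a, Fm_b)))
```

Whatever the real fusion does, any map computed from `Fm` inherits this dependence on `fq`.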

Could you tell me whether I've misunderstood something? Thanks for your patience.