hengyuan-hu / bottom-up-attention-vqa

An efficient PyTorch implementation of the winning entry of the 2017 VQA Challenge.

about forward() #39

Open lilei132 opened 5 years ago

lilei132 commented 5 years ago

```python
def forward(self, v, b, q, labels):
    """Forward

    v: [batch, num_objs, obj_dim]
    b: [batch, num_objs, b_dim]
    q: [batch, seq_length]

    return: logits, not probs
    """
    w_emb = self.w_emb(q)                 # word embeddings: [batch, seq_length, emb_dim]
    q_emb = self.q_emb(w_emb)             # question embedding: [batch, q_dim]

    att = self.v_att(v, q_emb)            # attention weights over objects
    v_emb = (att * v).sum(1)              # attention-weighted sum: [batch, v_dim]

    q_repr = self.q_net(q_emb)            # project question into the joint space
    v_repr = self.v_net(v_emb)            # project image into the joint space
    joint_repr = q_repr * v_repr          # element-wise fusion
    logits = self.classifier(joint_repr)  # one score per candidate answer
    return logits
```

When I read this code I had some doubts: the dataset returns four items per example (features, spatials, question, target). Do v, b, q, and labels correspond one-to-one to these? Also, the value of b is not used at all in forward(), so where does b actually come into play?
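To make the shapes in the snippet concrete, here is a self-contained sketch of just the attention-weighted sum, using made-up sizes (36 objects and 2048-dim features are typical bottom-up values, but any sizes work):

```python
import torch

# Hypothetical sizes, for illustration only.
batch, num_objs, obj_dim = 2, 36, 2048
v = torch.randn(batch, num_objs, obj_dim)                    # object features
att = torch.softmax(torch.randn(batch, num_objs, 1), dim=1)  # one weight per object

# [batch, num_objs, 1] broadcasts against [batch, num_objs, obj_dim];
# summing over the object axis collapses all objects into one vector.
v_emb = (att * v).sum(1)
print(v_emb.shape)  # torch.Size([2, 2048])
```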

wjl520 commented 5 years ago

Hello, about your question on def forward(): you are right that b and labels are not used inside the forward function. b holds the spatial (bounding-box) features, and labels holds the ground-truth answer scores, which are used to compute the loss outside the model.
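To make that concrete, here is a minimal sketch of one training step showing where each item of the batch ends up. The loss below is the plain multi-label BCE commonly used with VQA soft answer scores; the repo's actual loss helper may rescale it, and train_step is a hypothetical name:

```python
import torch.nn.functional as F

def train_step(model, optimizer, v, b, q, a):
    """One hypothetical training step; a holds the soft answer scores (the dataset's 'target')."""
    logits = model(v, b, q, a)  # b and a are passed in but ignored inside forward()
    # The labels finally get used here, outside the model:
    loss = F.binary_cross_entropy_with_logits(logits, a)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```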