taoxugit / AttnGAN


the meaning of the mask in "attn.data.masked_fill_(mask.data, -float('inf'))" in the forward function of the class GlobalAttentionGeneral #93

Open fyw1999 opened 2 years ago

fyw1999 commented 2 years ago

In the forward function of the class GlobalAttentionGeneral, there is some code that I think may be wrong. Assume that batch_size is 20, words_num is 18, and embedding_dim is 256. In the second stage of generation, the arguments input and context then have shapes (20, 48, 64, 64) and (20, 256, 18), so after `attn = attn.view(batch_size * queryL, sourceL)` the shape of attn is (81920, 18). Meanwhile, the shape of mask equals the shape of captions, which is (20, 18). If a value in captions is 0, there is no word at that position of the sentence, and the same position in mask is 1. Based on this, I think the purpose of `attn.data.masked_fill_(mask.data.bool(), -float('inf'))` is to set a value in attn to minus infinity wherever there is no word at the corresponding position of captions.

However, although mask has the same shape as attn after `mask = self.mask.repeat(queryL, 1)`, the positions in mask and attn no longer correspond. mask goes from (20, 18) to (81920, 18) by tiling the whole batch queryL times along the rows, while attn goes from (20, 4096, 18) to (81920, 18) by flattening its first two dimensions. In this layout, mask[1][0] indicates whether there is a word at the first position of the second sentence in the batch, but attn[1][0] is the dot product of the second pixel of the first image with the first word of the first sentence. So the same position in mask and attn means different things (see the sketch after the quoted code below). Can anyone answer my question? Thank you very much.

```python
class GlobalAttentionGeneral(nn.Module):
    def __init__(self, idf, cdf):
        super(GlobalAttentionGeneral, self).__init__()
        self.conv_context = conv1x1(cdf, idf)
        self.sm = nn.Softmax()
        self.mask = None

    def applyMask(self, mask):
        self.mask = mask  # batch x sourceL

    def forward(self, input, context):
        """
            input: batch x idf x ih x iw (queryL=ihxiw)
            context: batch x cdf x sourceL
        """
        ih, iw = input.size(2), input.size(3)
        queryL = ih * iw
        batch_size, sourceL = context.size(0), context.size(2)

        # --> batch x queryL x idf
        target = input.view(batch_size, -1, queryL)
        targetT = torch.transpose(target, 1, 2).contiguous()
        # batch x cdf x sourceL --> batch x cdf x sourceL x 1
        sourceT = context.unsqueeze(3)
        # --> batch x idf x sourceL
        sourceT = self.conv_context(sourceT).squeeze(3)

        # Get attention
        # (batch x queryL x idf)(batch x idf x sourceL)
        # -->batch x queryL x sourceL
        attn = torch.bmm(targetT, sourceT)
        # --> batch*queryL x sourceL
        attn = attn.view(batch_size*queryL, sourceL)
        if self.mask is not None:
            # batch_size x sourceL --> batch_size*queryL x sourceL
            mask = self.mask.repeat(queryL, 1)
            attn.data.masked_fill_(mask.data, -float('inf'))
        attn = self.sm(attn)  # Eq. (2)
        # --> batch x queryL x sourceL
        attn = attn.view(batch_size, queryL, sourceL)
        # --> batch x sourceL x queryL
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x sourceL)(batch x sourceL x queryL)
        # --> batch x idf x queryL
        weightedContext = torch.bmm(sourceT, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)

        return weightedContext, attn
```
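
To make the mismatch concrete, here is a small self-contained sketch (not code from the repository; the tiny tensor sizes and the `repeat_interleave` alternative are my own assumptions) comparing the row layout of `mask.repeat(queryL, 1)` with the row layout of `attn.view(batch_size * queryL, sourceL)`:

```python
import torch

batch_size, queryL, sourceL = 3, 4, 5   # tiny stand-ins for 20, 4096, 18

# attn is (batch, queryL, sourceL); flattening the first two dims keeps all
# queryL rows of image 0 first, then all rows of image 1, and so on.
attn = torch.arange(batch_size * queryL * sourceL, dtype=torch.float32)
attn = attn.view(batch_size, queryL, sourceL)
attn_flat = attn.view(batch_size * queryL, sourceL)

# mask: one row per sentence, True marks a padding position.
mask = torch.zeros(batch_size, sourceL, dtype=torch.bool)
mask[0, 3:] = True   # first sentence has 3 words
mask[1, 1:] = True   # second sentence has 1 word
mask[2, 4:] = True   # third sentence has 4 words

# What the current code does: tile the whole batch queryL times, so row r of
# the repeated mask belongs to sentence r % batch_size ...
mask_repeat = mask.repeat(queryL, 1)
# ... while row r of attn_flat belongs to image r // queryL.
print(mask_repeat[1])   # mask of sentence 1
print(attn_flat[1])     # scores of image 0, pixel 1  -> rows disagree

# One possible alignment (an assumption on my part, not the authors' fix):
# repeat each sentence's mask queryL times in a row-major way instead.
mask_aligned = mask.repeat_interleave(queryL, dim=0)
print(mask_aligned[1])  # mask of sentence 0 -> matches attn_flat[1]
```

With `repeat_interleave`, row r of the expanded mask and row r of attn_flat both refer to image/sentence r // queryL, which seems to be what the masking is intended to do.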
baolp commented 2 years ago

I have the same question.

YibinLiu666 commented 1 year ago

I have the same question. The mask is used to eliminate the influence of images that share the same label; otherwise, contrastive learning would push the text away from images with the same label.
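
The mask described in this comment sounds like the label-based mask used in the DAMSM word/sentence matching losses rather than the padding mask discussed above. A minimal sketch of that kind of same-label mask (illustrative only; the function name and shapes are my own, not the repository's exact code):

```python
import numpy as np
import torch

def build_same_label_mask(class_ids):
    """For each sample i, mark the other samples that share its label, so
    they can be excluded from the contrastive matching score.
    Illustrative sketch, not the repository's implementation."""
    class_ids = np.asarray(class_ids)
    batch_size = class_ids.shape[0]
    masks = []
    for i in range(batch_size):
        mask = (class_ids == class_ids[i]).astype(np.uint8)
        mask[i] = 0  # keep the sample's own text-image pair
        masks.append(mask)
    return torch.from_numpy(np.stack(masks)).bool()  # batch x batch

# Example: samples 0 and 2 share a label, so each one ignores the other.
print(build_same_label_mask([7, 3, 7, 5]))
```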