nickgkan / butd_detr

Code for the ECCV22 paper "Bottom Up Top Down Detection Transformers for Language Grounding in Images and Point Clouds"

About the implementation of the loss function #16

Hiusam closed this issue 1 year ago

Hiusam commented 1 year ago

Hi, thanks for your great work. I have some questions regarding the implementation of the loss function. Hope you can give me some hints.

  1. in loss_contrastive_align:

        # construct a map such that positive_map[k, i, j] = True
        # iff query i is associated to token j in batch item k
        positive_map = torch.zeros(logits.shape, device=logits.device)
    
        # handle 'not mentioned' # ? these two correspond to the last two tokens, which are " mentioned" and "</s>"?
        inds = tokenized['attention_mask'].sum(1) - 1
        positive_map[torch.arange(len(inds)), :, inds] = 0.5
        positive_map[torch.arange(len(inds)), :, inds - 1] = 0.5 

    I think the last two lines make unmatched queries correspond to "not mentioned", is that right? But inds and inds - 1 are the indices of "</s>" and " mentioned". I think they should be inds - 1 and inds - 2.

    Here is my debugging output:

        > tokenized['attention_mask'][0].sum() - 1 
        tensor(12, device='cuda:0') 
        > self.tokenizer.decode(tokenized['input_ids'][0][12]) 
        '</s>' 
        > self.tokenizer.decode(tokenized['input_ids'][0][11]) 
        ' mentioned' 
  2. in loss_contrastive_align:

        # Token mask for matches <> 'not mentioned' 
        tmask = torch.full(
            (len(logits), logits.shape[-1]),
            self.eos_coef,
            dtype=torch.float32, device=logits.device
        ) # * (B, max_token)
        tmask[torch.arange(len(inds)), inds] = 1.0 

    Why do you set the weight of the last token to 1.0? I think the tokens that have matched queries should be set to 1.0 instead.

  3. in loss_contrastive_align:

        # Loss 1: which tokens should each query match?
        boxes_with_pos = positive_map.any(2)
        pos_term = positive_logits.sum(2)
        neg_term = negative_logits.logsumexp(2)
        nb_pos = positive_map.sum(2) + 1e-6 
        entropy = -torch.log(nb_pos+1e-6) / nb_pos  # entropy of 1/nb_pos
        box_to_token_loss_ = (
            (entropy + pos_term / nb_pos + neg_term)
        ).masked_fill(~boxes_with_pos, 0)
        box_to_token_loss = (box_to_token_loss_ * mask).sum()

    Why do we directly add (entropy + pos_term / nb_pos + neg_term)? Shouldn't we take sum(exp(positive_logits)) to get pos_term and then compute pos_term / neg_term?

  4. in loss_labels_st:

         entropy = torch.log(target_sim + 1e-6) * target_sim 
         loss_ce = (entropy - logits * target_sim).sum(-1)

    What is entropy used for? Its elements seem to be near zero.
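The index arithmetic in question 1 can be checked without a tokenizer. The sketch below is plain Python rather than PyTorch, and the token strings are hypothetical, chosen only so that positions 11 and 12 decode to ' mentioned' and '</s>' as in the debugging output above; `not_mentioned_indices` is an illustrative helper, not part of the repository.

```python
# Hypothetical decoded tokens for one padded prompt ending in "not mentioned".
# Positions 11 and 12 match the debugging output in question 1.
tokens = ["<s>", " the", " chair", " next", " to", " the", " big", " wooden",
          " table", ".", " not", " mentioned", "</s>", "<pad>", "<pad>"]
# 1 for real tokens, 0 for padding (one row of tokenized['attention_mask']).
attention_mask = [1] * 13 + [0] * 2


def not_mentioned_indices(mask):
    """Return the positions picked by inds, inds - 1 and inds - 2 for one row.

    Mirrors `inds = tokenized['attention_mask'].sum(1) - 1`: the count of real
    tokens minus one is the position of the last real token.
    """
    inds = sum(mask) - 1
    return inds, inds - 1, inds - 2


eos, mentioned, not_tok = not_mentioned_indices(attention_mask)
print([tokens[i] for i in (eos, mentioned, not_tok)])
```

With this layout, inds and inds - 1 land on '</s>' and ' mentioned', whereas inds - 1 and inds - 2 would land on ' mentioned' and ' not', which is the pairing the question proposes.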

nickgkan commented 1 year ago

Hi,

  1. It seems you're correct. However, the exact text you use to attract the negative boxes is not really important, as long as a) it's consistent and b) it's not really related to the rest of the sentence (e.g., you would not be able to use the name of a class as your negatives' attraction point).

  2. We found that this choice, in combination with other losses, performs reasonably. It may be suboptimal, as we did not have the time and resources to extensively ablate on such hyperparameters.

  3. This becomes clearer if you look at equation 1 in MDETR. The loss is -log(pos/neg), where pos = exp(pos_logits) and neg = Sum(exp(logits)). This can be written as -log(pos) + log(neg) = -pos_logits + log_sum_exp(logits). Also, positive_logits already holds the negated logits (in a line not shown in what you pasted), which is why pos_term is added rather than subtracted.

  4. This is the negative entropy of the target distribution (the sum of target_sim * log(target_sim)), so that the computed loss becomes the KL divergence. Gradients are not propagated through this term; it's there to ensure that the range of the loss is the same across batches. MDETR does not use this term.
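The two identities in replies 3 and 4 can be checked numerically. This is a plain-Python sketch, not the repository's code. It assumes, as reply 3 states, that positive_logits holds the negated logits on positive tokens, and, for reply 4, that the logits in loss_labels_st are log-probabilities (for example, a log-softmax output), which the pasted snippet does not show. The entropy term in the question 3 snippet depends only on nb_pos, so it shifts the loss by a constant and is omitted here. All function names and input values are hypothetical.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def contrastive_direct(logits, positives):
    """Reply 3, literal form: mean over positive tokens j of -log(pos / neg)."""
    neg = sum(math.exp(s) for s in logits)
    return sum(-math.log(math.exp(logits[j]) / neg) for j in positives) / len(positives)


def contrastive_as_in_snippet(logits, positives):
    """Reply 3, code form: pos_term / nb_pos + neg_term (entropy term omitted)."""
    # positive_logits: negated logits on positive tokens, 0 elsewhere.
    positive_logits = [-s if j in positives else 0.0 for j, s in enumerate(logits)]
    pos_term = sum(positive_logits)
    neg_term = logsumexp(logits)
    nb_pos = len(positives)
    return pos_term / nb_pos + neg_term


def loss_ce_with_entropy(log_probs, target_sim):
    """Reply 4: sum(t * log(t + 1e-6) - log_probs * t), as in loss_labels_st."""
    return sum(math.log(t + 1e-6) * t - lp * t for lp, t in zip(log_probs, target_sim))


def kl_divergence(target, probs):
    """KL(target || probs); entries with zero target mass contribute nothing."""
    return sum(t * math.log(t / p) for t, p in zip(target, probs) if t > 0)


logits = [2.0, -1.0, 0.5, 1.5]
positives = {0, 3}  # hypothetical: the query matches tokens 0 and 3
print(contrastive_direct(logits, positives), contrastive_as_in_snippet(logits, positives))

log_probs = [s - logsumexp(logits) for s in logits]  # log-softmax
probs = [math.exp(lp) for lp in log_probs]
target = [0.5, 0.0, 0.0, 0.5]  # hypothetical soft target
print(loss_ce_with_entropy(log_probs, target), kl_divergence(target, probs))

# When the prediction equals the target, the loss is ~0 instead of H(target).
target2 = [0.1, 0.2, 0.3, 0.4]
print(loss_ce_with_entropy([math.log(t) for t in target2], target2))
```

The first pair agrees to floating-point precision and the second up to the 1e-6 epsilon. The last line shows that, with the entropy term, the loss bottoms out at about zero for any target, which is one reading of the reply's point that the loss range stays the same across batches.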

Hiusam commented 1 year ago

Thank you very much for your reply!