Closed — Hiusam closed this issue 1 year ago
Hi,
It seems you're correct. However, the exact text you use to attract the negative boxes is not really important, as long as a) it's consistent and b) it's not really related to the rest of the sentence (e.g., you would not be able to use the name of a class as your negatives' attraction point).
We found that this choice, in combination with other losses, performs reasonably. It may be suboptimal, as we did not have the time and resources to extensively ablate on such hyperparameters.
This will be clearer if you look at equation 1 in MDETR. The loss is -log(pos / neg), where pos = exp(pos_logits) and neg = sum(exp(logits)). It can therefore be rewritten as -log(pos) + log(neg) = -pos_logits + log_sum_exp(logits). Note that pos_logits are already negated logits (in a line not shown in what you pasted).
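To make the rearrangement concrete, here is a minimal numerical sketch (not the repo's code; the logit values are made up) checking that the two forms of the contrastive term agree:

```python
# Check that -log(exp(pos_logit) / sum(exp(logits)))
# equals -pos_logit + log_sum_exp(logits).
import math

logits = [2.0, -1.0, 0.5, 3.0]   # similarity logits for one query vs. all tokens
pos_logit = logits[0]            # assume index 0 is the positive token

# Direct form: -log(pos / neg)
pos = math.exp(pos_logit)
neg = sum(math.exp(l) for l in logits)
direct = -math.log(pos / neg)

# Rearranged form: -pos_logit + log_sum_exp(logits)
log_sum_exp = math.log(sum(math.exp(l) for l in logits))
rearranged = -pos_logit + log_sum_exp

assert abs(direct - rearranged) < 1e-9
```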
This is the entropy of the target distribution, so that the computed loss becomes the KL divergence. Gradients are not propagated through this term; it is there only to ensure that the range of the loss is consistent across batches. MDETR does not use this term.
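A small sketch of the identity behind this (the distributions are made up, not taken from the code): cross-entropy between a soft target t and a prediction p decomposes as H(t) + KL(t || p), so adding the target's entropy term shifts the loss from cross-entropy to pure KL, whose minimum is 0 regardless of how peaked the target is.

```python
# Verify the decomposition CE(t, p) = H(t) + KL(t || p)
# for a soft target t and prediction p.
import math

t = [0.7, 0.2, 0.1]   # soft target distribution
p = [0.5, 0.3, 0.2]   # predicted distribution

cross_entropy = -sum(ti * math.log(pi) for ti, pi in zip(t, p))
entropy = -sum(ti * math.log(ti) for ti in t)
kl = sum(ti * math.log(ti / pi) for ti, pi in zip(t, p))

# Subtracting the target entropy from the cross-entropy leaves exactly the KL term.
assert abs((cross_entropy - entropy) - kl) < 1e-9
```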
Thank you very much for your reply!
Hi, thanks for your great work. I have some questions regarding the implementation of the loss function. Hope you can give me some hints.
1. In `loss_contrastive_align`: I think the last two lines set the unmatched queries to correspond to "not mentioned", is that right? But `inds` and `inds - 1` are the indexes of "</s>" and "mentioned". I think they should be `inds - 1` and `inds - 2`. Here is my debugging output:
2. In `loss_contrastive_align`: Why do we set the weight of the last token to 1.0? I think we should set the tokens with matched queries to 1.0.
3. In `loss_contrastive_align`: Why do we directly add `(entropy + pos_term / nb_pos + neg_term)`? Shouldn't we take `sum(exp(positive_logits))` to get `pos_term` and then compute `pos_term / neg_term`?
4. In `loss_labels_st`: What is `entropy` used for? Its elements seem to be near zero.