ashkamath / mdetr

Apache License 2.0

Contrastive loss implementation discrepancy between the paper and codebase #8

Closed: Gyat closed this issue 3 years ago.

Gyat commented 3 years ago

Hello,

This is about the losses described in the paper and implemented in the codebase. I need your help understanding the following:

  1. Page 4 of the paper states: "the contrastive alignment loss enforces alignment between the embedded representations of the object at the output of the decoder, and the text representation at the output of the cross encoder." However, transformer.py uses the following snippet for the loss calculations:

     ```python
     "text_pooled_op": encoded_text.pooler_output if self.CLS is not None else None,
     "img_pooled_op": img_memory[0] if self.CLS is not None else None, # Return the CLS token
     ```

This suggests that the text embedding is taken from the pooled classification-token output of the BERT-based text backbone, while the image embedding is taken from the CLS token at the output of the transformer encoder. Is this genuinely a discrepancy? If not, could you point me to the snippet where the decoder output is used in these loss calculations?

  2. Also, is my understanding correct that the 'Soft token prediction' loss from the paper is called 'contrastive_align_loss' in the codebase, and the 'Contrastive alignment' loss from the paper is called 'contrastive_loss'?

Thank you.

ashkamath commented 3 years ago

Hi, it looks like you're confusing the contrastive_align_loss with the contrastive_loss. In our paper and published results, we do not use the contrastive loss (which is akin to the image-text matching losses used in other vision+language pre-training papers). We left it in the code only for completeness: it is something we tried at some point, and we thought it could be useful to other users of our codebase who want to experiment with it. For the two losses that we do use, see the following:
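For readers following the snippet in the question: the pooled text output and the image CLS token feed a batch-level loss, not the per-object alignment described in the paper. Below is a minimal numpy sketch of one common form of such a loss, a symmetric InfoNCE with in-batch negatives (image i is paired with caption i). This is an illustrative assumption, not MDETR's contrastive_loss code; the function name, the temperature default, and the absence of projection heads are hypothetical.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def global_contrastive_loss(img_pooled, text_pooled, temperature=0.07):
    """Hypothetical sketch: symmetric InfoNCE between pooled image/text embeddings.

    img_pooled, text_pooled: (B, d) arrays; row i of each comes from the same pair.
    The temperature default is an assumption, not taken from MDETR.
    """
    # L2-normalize so the dot product is a cosine similarity.
    img = img_pooled / np.linalg.norm(img_pooled, axis=-1, keepdims=True)
    txt = text_pooled / np.linalg.norm(text_pooled, axis=-1, keepdims=True)
    logits = img @ txt.T / temperature  # (B, B) similarity matrix
    idx = np.arange(logits.shape[0])
    # Image -> text: each row should pick its own caption (the diagonal).
    loss_i2t = -log_softmax(logits, axis=1)[idx, idx].mean()
    # Text -> image: each column should pick its own image.
    loss_t2i = -log_softmax(logits, axis=0)[idx, idx].mean()
    return (loss_i2t + loss_t2i) / 2
```

Matched pairs give a near-zero loss, while mismatched pairs are penalized because the correct partner is never the highest-scoring entry. Because every sample is compared against the whole batch, this loss scores image-caption agreement globally and says nothing about which object corresponds to which phrase.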

  1. Contrastive align loss, which is calculated between the object embeddings predicted by the decoder and the embedded representations of the text at the output of the cross encoder. Relevant lines in the code: https://github.com/ashkamath/mdetr/blob/fdee8c50d7bcf2ad09cc0d6b783a8333720e4048/models/mdetr.py#L81 , https://github.com/ashkamath/mdetr/blob/fdee8c50d7bcf2ad09cc0d6b783a8333720e4048/models/mdetr.py#L203, https://github.com/ashkamath/mdetr/blob/fdee8c50d7bcf2ad09cc0d6b783a8333720e4048/models/mdetr.py#L496

  2. Contrastive alignment corresponds to loss_contrastive_align, discussed above. Soft token prediction is loss_labels: https://github.com/ashkamath/mdetr/blob/fdee8c50d7bcf2ad09cc0d6b783a8333720e4048/models/mdetr.py#L464
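To make the two losses above concrete, here is a minimal numpy sketch that follows the paper's description, assuming: object embeddings come from the decoder, token embeddings come from the text part of the cross-encoder output, and a boolean positive_map marks which tokens each matched object query refers to (an all-False row means the query is unmatched). The function names, the 0.07 temperature default, and the normalization choices are assumptions for illustration. The actual loss_contrastive_align and loss_labels in models/mdetr.py (linked above) also handle projection heads, Hungarian matching, batching, and no-object weighting, which this sketch omits.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def contrastive_align_sketch(obj_emb, tok_emb, positive_map, temperature=0.07):
    """Hypothetical sketch of contrastive alignment.

    obj_emb: (N, d) decoder outputs; tok_emb: (T, d) cross-encoder text outputs;
    positive_map: (N, T) bool, True where object i refers to token j.
    """
    o = obj_emb / np.linalg.norm(obj_emb, axis=-1, keepdims=True)
    t = tok_emb / np.linalg.norm(tok_emb, axis=-1, keepdims=True)
    sim = o @ t.T / temperature  # (N, T) object-token similarities
    pos = positive_map.astype(float)

    # Object -> tokens: softmax over tokens, averaged over each object's positives.
    n_tok = pos.sum(axis=1)
    has_o = n_tok > 0
    loss_o = -(log_softmax(sim, axis=1) * pos).sum(axis=1)[has_o] / n_tok[has_o]

    # Token -> objects: softmax over objects, averaged over each token's positives.
    n_obj = pos.sum(axis=0)
    has_t = n_obj > 0
    loss_t = -(log_softmax(sim, axis=0) * pos).sum(axis=0)[has_t] / n_obj[has_t]

    return (loss_o.sum() + loss_t.sum()) / 2

def soft_token_sketch(logits, positive_map):
    """Hypothetical sketch of soft token prediction.

    logits: (N, T + 1) per-query scores over token positions plus a final
    "no object" slot; positive_map: (N, T) bool span for matched queries.
    """
    n, t_len = positive_map.shape
    pos = positive_map.astype(float)
    counts = pos.sum(axis=1)
    matched = counts > 0
    target = np.zeros((n, t_len + 1))
    # Matched queries: uniform target distribution over the tokens of their span.
    target[matched, :t_len] = pos[matched] / counts[matched][:, None]
    # Unmatched queries: all probability mass on the "no object" slot.
    target[~matched, t_len] = 1.0
    # Cross-entropy between the predicted softmax and the soft target.
    return -(log_softmax(logits, axis=-1) * target).sum(axis=-1).mean()
```

The contrastive term works directly on embedding similarities between decoder objects and cross-encoder tokens. The soft token term is a classification over token positions. For a span of k tokens, the soft token loss is therefore at least log(k), reached when the prediction is uniform over exactly that span.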

Hope this makes it more clear! :)