airsplay / lxmert

PyTorch code for EMNLP 2019 paper "LXMERT: Learning Cross-Modality Encoder Representations from Transformers".

Why is object classification loss multiplied with the Faster R-CNN confidence score? #19

Closed by j-min 4 years ago

j-min commented 4 years ago

During training, mask_conf is multiplied with the feature regression and object classification losses, which are defined here. It is reasonable to mask the feature regression loss on masked regions, but I don't understand the reason for multiplying the Faster R-CNN confidence score (the top object probability) into the object classification loss (which is a cross-entropy loss). Is this a form of knowledge distillation? It is not mentioned in the EMNLP paper.
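For readers skimming the thread, here is a minimal sketch of the weighting being asked about. The tensor names, shapes, and the exact regression loss are illustrative assumptions, not the repo's exact code; the key point is that both per-region losses are scaled by `mask_conf` before averaging:

```python
import torch
import torch.nn.functional as F

def masked_visual_losses(feat_pred, feat_target, label_logits, label_target, mask_conf):
    """
    feat_pred, feat_target: (B, N, D) predicted vs. Faster R-CNN RoI features
    label_logits:           (B, N, C) predicted object-class logits
    label_target:           (B, N)    top detected label per region (long)
    mask_conf:              (B, N)    0 for unmasked regions; for masked regions,
                                      the Faster R-CNN confidence of the top label
    """
    # Per-region feature regression loss (the loss type here is illustrative),
    # weighted by the detector's confidence in that region.
    feat_l = F.smooth_l1_loss(feat_pred, feat_target, reduction="none").mean(-1)
    feat_loss = (feat_l * mask_conf).mean()

    # Per-region cross-entropy against the top detected label, weighted the same way:
    # low-confidence detections contribute less to the gradient.
    ce = F.cross_entropy(
        label_logits.flatten(0, 1), label_target.flatten(), reduction="none"
    ).view_as(mask_conf)
    obj_loss = (ce * mask_conf).mean()
    return feat_loss, obj_loss
```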

airsplay commented 4 years ago

Rather than arguing for some "looks-like" reason, I must say that it was a practical consideration when I wrote it :->.

It is designed to stop over-fitting. I observed that the pre-training process easily over-fits the image features/labels: the training loss keeps decreasing while the validation loss (of object features/labels) increases after 3 epochs. I thus multiply by this confidence: it's OK to overfit, but please overfit something more correct! A side effect is that the RoI-feature regression loss is no longer overfitted once it is multiplied by this confidence.

I personally think it would be better to use the KL divergence between the detected labels' confidences and the predicted probabilities (i.e., distillation) as the loss. For now, the code takes the label with the largest confidence score. I currently do not have an answer supported by experiments.
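A hedged sketch of what that distillation variant could look like, assuming the detector's full per-class distribution `det_probs` were kept (the released code only stores the top label and its confidence); names and shapes are again illustrative:

```python
import torch.nn.functional as F

def distill_obj_loss(label_logits, det_probs, region_mask):
    """
    label_logits: (B, N, C) model's predicted object-class logits
    det_probs:    (B, N, C) Faster R-CNN class probabilities (soft targets;
                            an assumed input, not available in the released features)
    region_mask:  (B, N)    1 for masked regions, 0 otherwise
    """
    log_pred = F.log_softmax(label_logits, dim=-1)
    # Per-region KL(det_probs || pred); kl_div expects log-probs as input.
    kl = F.kl_div(log_pred, det_probs, reduction="none").sum(-1)  # (B, N)
    # Average over masked regions only.
    return (kl * region_mask).sum() / region_mask.sum().clamp(min=1)
```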

j-min commented 4 years ago

Thanks for the clarification! I'm also going to EMNLP. Hopefully we can talk some more about this work in person soon :)

airsplay commented 4 years ago

Happy to talk; see you in Hong Kong :)

By the way, I just thought of a counter-intuitive finding w.r.t. the visual pre-training losses.

When I realized that the pre-training was overfitting the visual losses, the first thing I did was to increase the mask rate of objects. Intuitively, a higher mask rate makes the vision tasks harder, so the over-fitting should be somewhat relieved.

However, the val loss started to increase (the bad direction) after 1 epoch (instead of 3 epochs) when the mask rate was increased...

A possible explanation (provided by Jie Lei) is that a higher mask rate increases the amount of "supervision" (more labels in detected-label classification and feature regression) per batch... I accept this explanation, but I think the pre-training of the visual branch might need a fundamental improvement (which hasn't happened yet).
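To make that explanation concrete, a toy sketch (the batch size, region count, and mask rates below are illustrative, not taken from the repo) showing that doubling the object mask rate roughly doubles the number of regions that receive a visual-loss target per batch:

```python
import torch

def mask_regions(num_regions, mask_rate):
    # Each masked region contributes one feature-regression target
    # and one detected-label classification target.
    return torch.rand(num_regions) < mask_rate

batch_size, num_regions = 32, 36  # e.g., 36 RoIs per image
for rate in (0.15, 0.30):
    mask = torch.stack([mask_regions(num_regions, rate) for _ in range(batch_size)])
    # ~173 vs. ~346 supervised regions per batch in expectation
    print(rate, int(mask.sum()))
```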