airsplay / lxmert

PyTorch code for EMNLP 2019 paper "LXMERT: Learning Cross-Modality Encoder Representations from Transformers".
MIT License

Why are the object and attribute losses not masked? #41

Open forjiuzhou opened 4 years ago

forjiuzhou commented 4 years ago

In lxmert_pretrain.py, obj_labels = { 'obj': (obj_labels, obj_confs), 'attr': (attr_labels, attr_confs), 'feat': (feat, feat_mask) }. It seems that the obj and attr losses are computed on all objects rather than being masked by feat_mask. It seems odd to predict an object's class when its feature has not been masked. Am I correct? This was also mentioned in #19. Thanks in advance.
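
For reference, here is a minimal self-contained sketch of the tensors being discussed (shapes, class counts, and the 15% masking rate are illustrative assumptions, not taken from the repo):

import torch

B, O, D = 2, 36, 2048                            # batch size, boxes per image, RoI feature dim
obj_ids    = torch.randint(0, 1600, (B, O))      # Faster R-CNN object classes
attr_ids   = torch.randint(0, 400, (B, O))       # attribute classes
obj_confs  = torch.rand(B, O)                    # detector confidence scores in [0, 1]
attr_confs = torch.rand(B, O)
feat       = torch.randn(B, O, D)                # RoI features used as regression targets
feat_mask  = (torch.rand(B, O) < 0.15).float()   # 1.0 where the input feature was masked

obj_labels = {
    'obj':  (obj_ids, obj_confs),
    'attr': (attr_ids, attr_confs),
    'feat': (feat, feat_mask),
}
# The concern: the obj/attr classification losses should be multiplied by
# feat_mask (as the feature-regression loss is), so that classes are only
# predicted for objects whose input features were actually masked out.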

airsplay commented 4 years ago

Thanks. The loss is masked here: https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/lxrt/modeling.py#L972

forjiuzhou commented 4 years ago

> Thanks. The loss is masked here:
> https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/lxrt/modeling.py#L972

Thanks for the reply, but it still confuses me. In the code I mentioned above, 'mask_conf' is actually 'obj_confs', which is the predicted score from Faster R-CNN, not the mask made in the 'convert_example_to_features' function. I put the relevant code here.

These lines show that, in the obj and attr configs, the second element of 'obj_labels' is 'obj_confs': https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/pretrain/lxmert_pretrain.py#L174-L175 https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/pretrain/lxmert_pretrain.py#L177-L178 https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/pretrain/lxmert_pretrain.py#L200-L204

And these lines show that 'mask_conf' is that second element of 'obj_labels': https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/lxrt/modeling.py#L962 https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/lxrt/modeling.py#L972

Am I missing something here?
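
To make the distinction concrete, here is a paraphrased sketch (not the repo's exact code) of the two weighting schemes, with dummy tensors:

import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')

def weighted_obj_loss(logits, labels, weight):
    # logits: (B, O, C), labels: (B, O), weight: (B, O)
    per_obj = ce(logits.view(-1, logits.size(-1)), labels.view(-1))
    return (per_obj * weight.view(-1)).mean()

B, O, C = 2, 36, 1600                           # illustrative shapes
obj_logits = torch.randn(B, O, C)
obj_ids    = torch.randint(0, C, (B, O))
obj_confs  = torch.rand(B, O)                   # Faster R-CNN detection scores
feat_mask  = (torch.rand(B, O) < 0.15).float()  # 1.0 at masked positions

# Behaviour described above: every object contributes, weighted by its
# detection confidence -- a soft weighting, not a mask.
loss_current  = weighted_obj_loss(obj_logits, obj_ids, obj_confs)

# Intended behaviour: only objects whose input features were masked contribute.
loss_intended = weighted_obj_loss(obj_logits, obj_ids, feat_mask)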

airsplay commented 4 years ago

Thanks a lot!! I think it is a real bug in my code that I did not notice before. It might also be the reason why the obj-loss and attr-loss overfit while the feat-loss does not.

I am going to add two lines

obj_confs *= feat_mask
attr_confs *= feat_mask

after the masking of the visual objects here https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/pretrain/lxmert_pretrain.py#L178 and retrain the pre-training model to see whether it changes the results on downstream tasks.
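
For context, a small sketch of where the two lines would sit; mask_feat here is an illustrative stand-in for the feature-masking step, and all tensors are dummies:

import torch

def mask_feat(feat, mask_rate=0.15):
    # Stand-in for the masking of visual objects: zero a random subset of RoI
    # features and return a 0/1 mask marking which slots were masked.
    mask = (torch.rand(feat.shape[:2]) < mask_rate).float()
    return feat * (1.0 - mask).unsqueeze(-1), mask

B, O, D = 2, 36, 2048
orig_feat  = torch.randn(B, O, D)
obj_confs  = torch.rand(B, O)
attr_confs = torch.rand(B, O)

feat, feat_mask = mask_feat(orig_feat)

# Proposed fix: zero the confidences of unmasked objects, so that the
# downstream (per-object loss * confidence) weighting only counts masked ones.
obj_confs  = obj_confs * feat_mask
attr_confs = attr_confs * feat_mask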

Do you think this modification would completely fix the masking issue?

Best, Hao

forjiuzhou commented 4 years ago

That should fix it.

As for the 'obj_confs' part, I would be interested to see the model pre-trained directly with the mask, since weighting by 'obj_confs' is essentially a distillation approach.

It is also interesting that the previous LXMERT effectively failed at visual MLM training yet still achieved remarkable results on downstream tasks. Could that imply that visual MLM training does not contribute significantly to the overall training process? I am very interested in this modified version of LXMERT. Looking forward to the results!

Thanks for your great work! Best regards

airsplay commented 4 years ago

Agreed. I am thinking of using

obj_confs, attr_confs = feat_mask, feat_mask

instead, to provide zero-one masking.
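
Side by side, the two variants look like this (a sketch; variable names as used in this thread, dummy tensors):

import torch

O = 36
obj_confs  = torch.rand(O)                      # detector confidences
attr_confs = torch.rand(O)
feat_mask  = (torch.rand(O) < 0.15).float()     # 1.0 at masked positions

# Variant 1: keep the confidences, zeroed outside the mask
# (soft, distillation-style targets restricted to masked objects).
v1_obj, v1_attr = obj_confs * feat_mask, attr_confs * feat_mask

# Variant 2: use the 0/1 mask itself as the weight, so every masked object
# counts equally -- the standard "predict only the masked positions" setup.
v2_obj, v2_attr = feat_mask, feat_mask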

The final effect of the previous "bug" is a shift in the ratio between the different losses, so mathematically it is actually not that bad. Let me explain.

The original MLM loss (in the BERT paper) can be decomposed into three parts:

LOSS = 0.15 (masking_rate) * (0.8 * CONTEXT_LOSS + 0.1 * CORRECTION_LOSS + 0.1 * RECONSTRUCTION_LOSS)

where:
- the CONTEXT_LOSS predicts the missing word from the context: https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/pretrain/lxmert_pretrain.py#L89
- the CORRECTION_LOSS tries to correct the wrong (randomly replaced) words: https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/pretrain/lxmert_pretrain.py#L93
- the RECONSTRUCTION_LOSS faithfully reconstructs the input: https://github.com/airsplay/lxmert/blob/827087bbf7b8eb4cb2b648aeb1875ded79bf4ff0/src/pretrain/lxmert_pretrain.py#L95

A detailed explanation of this implementation can also be found in Sec. 3.1 of the BERT paper.

The "bug" actually change the ratio of the loss with

LOSS = 0.15 (masking_rate) * (0.8 * CONTEXT_LOSS + 0.1 * CORRECTION_LOSS + 0.1 * RECONSTRUCTION_LOSS) 
       + 0.85 (non-masking-rate) RECONSTRUCTION_LOSS

which therefore puts a high weight on the RECONSTRUCTION_LOSS and is easy to overfit. But it does not seem to completely cancel the effect of the MLM loss, which is supposed to come mainly from the CONTEXT_LOSS.
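
To see the shift in ratios numerically (plain arithmetic, assuming BERT's 15% masking with the 80/10/10 split):

masking_rate = 0.15

# Without the bug: loss only on the 15% selected positions, split 80/10/10.
w_context        = masking_rate * 0.8   # 0.12
w_correction     = masking_rate * 0.1   # 0.015
w_reconstruction = masking_rate * 0.1   # 0.015

# With the bug: the other 85% of positions (unmasked and unchanged) also
# contribute, and for them the task is pure reconstruction of the input.
w_reconstruction_bug = w_reconstruction + (1 - masking_rate)  # 0.865

print(w_context, w_correction, w_reconstruction, w_reconstruction_bug)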

I hope that it is correct U_U!

Best

airsplay commented 4 years ago

An update on this issue:

I finally got enough computational resources (2 TitanV x 10 days) to reproduce the pre-training results with obj-mask. Fine-tuning results based on this new pre-training are here:

VQA:   69.60
GQA:   59.49
NLVR2: 74.18

So the current conjecture is that the two loss formulations (with or without computing the loss on non-masked objects) do not significantly affect pre-training. A mathematical explanation is provided in the previous post, and here is another piece of evidence:

I tried pre-training the pure-language BERT model while computing losses for all tokens, whereas the original BERT model only computes losses for the masked tokens. Somewhat surprisingly, the two results are almost the same on the GLUE benchmark, i.e., we observe the same phenomenon with BERT.

Overall, the current evidence suggests that these two losses do not make much difference. However, predicting only the masked tokens sounds more correct. Moreover, a careful implementation (gather + FC + loss) might speed up training, depending on device support (e.g., a GPU might support this gather operator efficiently but a TPU would not).
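
For illustration, a minimal PyTorch sketch of the gather + FC + loss pattern (not the repo's code; shapes and names are made up): the output projection and loss are computed only on the masked positions.

import torch
import torch.nn as nn

def mlm_loss_on_masked_only(hidden, labels, mask_positions, lm_head):
    # hidden:         (B, T, H) transformer outputs
    # labels:         (B, T)    target token ids
    # mask_positions: (B, T)    bool, True where a prediction is required
    # lm_head:        nn.Linear(H, vocab_size)
    idx = mask_positions.reshape(-1).nonzero(as_tuple=True)[0]  # gather the masked positions
    if idx.numel() == 0:                                        # no masked tokens in this batch
        return hidden.sum() * 0.0
    gathered = hidden.reshape(-1, hidden.size(-1))[idx]         # (num_masked, H)
    logits = lm_head(gathered)                                  # FC only on masked positions
    targets = labels.reshape(-1)[idx]
    return nn.functional.cross_entropy(logits, targets)

# Usage with dummy tensors (illustrative shapes):
B, T, H, V = 2, 20, 768, 30522
hidden = torch.randn(B, T, H)
labels = torch.randint(0, V, (B, T))
mask_positions = torch.rand(B, T) < 0.15
lm_head = nn.Linear(H, V)
loss = mlm_loss_on_masked_only(hidden, labels, mask_positions, lm_head)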

forjiuzhou commented 4 years ago

Could you share the log file of the new training run? I'd like to see how the losses change on the training and validation sets.

airsplay commented 4 years ago

Sure. Here is the log for the first 8 epochs. The obj/attr losses are still overfitted.

The training loss for Epoch 0 is 9.6472
The losses are Mask_LM: 2.9173 Matched: 0.4321 Obj: 1.5950 Attr: 1.2000 Feat: 0.3019 QA: 3.2009
Overall Accu 0.2024, gqa Accu 0.2581, visual7w Accu 0.1256, vqa Accu 0.2816,
The valid loss is 6.6441
The losses are Mask_LM: 1.8681 Matched: 0.2997 Obj: 0.8907 Attr: 0.8121 Feat: 0.2492 QA: 2.5245
Overall Accu 0.2393, gqa Accu 0.2895, visual7w Accu 0.1526, vqa Accu 0.3015,
100%|█████████████████████████████████████████████████████████████████| 35027/35027 [9:43:28<00:00,  1.00it/s]
The training loss for Epoch 1 is 6.0922
The losses are Mask_LM: 1.7541 Matched: 0.2948 Obj: 0.7496 Attr: 0.7089 Feat: 0.2444 QA: 2.3405
Overall Accu 0.2456, gqa Accu 0.3072, visual7w Accu 0.1657, vqa Accu 0.3214,
The valid loss is 5.9648
The losses are Mask_LM: 1.6523 Matched: 0.2734 Obj: 0.7969 Attr: 0.7358 Feat: 0.2366 QA: 2.2698
Overall Accu 0.2627, gqa Accu 0.3176, visual7w Accu 0.1691, vqa Accu 0.3293,
100%|█████████████████████████████████████████████████████████████████| 35027/35027 [9:38:00<00:00,  1.01it/s]
The training loss for Epoch 2 is 5.4681
The losses are Mask_LM: 1.6191 Matched: 0.2731 Obj: 0.6100 Attr: 0.6041 Feat: 0.2371 QA: 2.1248
Overall Accu 0.2628, gqa Accu 0.3329, visual7w Accu 0.1804, vqa Accu 0.3305,
The valid loss is 5.7303
The losses are Mask_LM: 1.5674 Matched: 0.2594 Obj: 0.7865 Attr: 0.7381 Feat: 0.2331 QA: 2.1457
Overall Accu 0.2636, gqa Accu 0.3220, visual7w Accu 0.1807, vqa Accu 0.3142,
100%|█████████████████████████████████████████████████████████████████| 35027/35027 [9:47:50<00:00,  1.01s/it]
The training loss for Epoch 3 is 5.1096
The losses are Mask_LM: 1.5475 Matched: 0.2613 Obj: 0.5232 Attr: 0.5335 Feat: 0.2346 QA: 2.0094
Overall Accu 0.2707, gqa Accu 0.3437, visual7w Accu 0.1885, vqa Accu 0.3334,
The valid loss is 5.6082
The losses are Mask_LM: 1.4951 Matched: 0.2505 Obj: 0.7983 Attr: 0.7489 Feat: 0.2318 QA: 2.0836
Overall Accu 0.2779, gqa Accu 0.3420, visual7w Accu 0.1827, vqa Accu 0.3384,
100%|████████████████████████████████████████████████████████████████| 35027/35027 [10:07:12<00:00,  1.04s/it]
The training loss for Epoch 4 is 4.8619
The losses are Mask_LM: 1.4989 Matched: 0.2530 Obj: 0.4587 Attr: 0.4763 Feat: 0.2335 QA: 1.9416
Overall Accu 0.2775, gqa Accu 0.3527, visual7w Accu 0.1943, vqa Accu 0.3384,
The valid loss is 5.6213
The losses are Mask_LM: 1.4687 Matched: 0.2456 Obj: 0.8266 Attr: 0.7773 Feat: 0.2313 QA: 2.0718
Overall Accu 0.2871, gqa Accu 0.3569, visual7w Accu 0.1862, vqa Accu 0.3498,
100%|████████████████████████████████████████████████████████████████| 35027/35027 [10:14:09<00:00,  1.05s/it]
The training loss for Epoch 5 is 4.6646
The losses are Mask_LM: 1.4607 Matched: 0.2466 Obj: 0.4087 Attr: 0.4281 Feat: 0.2327 QA: 1.8878
Overall Accu 0.2849, gqa Accu 0.3628, visual7w Accu 0.1994, vqa Accu 0.3466,
The valid loss is 5.6009
The losses are Mask_LM: 1.4296 Matched: 0.2419 Obj: 0.8473 Attr: 0.8040 Feat: 0.2307 QA: 2.0474
Overall Accu 0.2858, gqa Accu 0.3521, visual7w Accu 0.1911, vqa Accu 0.3441,
100%|████████████████████████████████████████████████████████████████| 35027/35027 [10:14:28<00:00,  1.05s/it]
The training loss for Epoch 6 is 4.5068
The losses are Mask_LM: 1.4322 Matched: 0.2415 Obj: 0.3687 Attr: 0.3881 Feat: 0.2321 QA: 1.8442
Overall Accu 0.2890, gqa Accu 0.3691, visual7w Accu 0.2023, vqa Accu 0.3495,
The valid loss is 5.6391
The losses are Mask_LM: 1.4113 Matched: 0.2412 Obj: 0.8661 Attr: 0.8254 Feat: 0.2307 QA: 2.0644
Overall Accu 0.2849, gqa Accu 0.3522, visual7w Accu 0.1905, vqa Accu 0.3417,
100%|█████████████████████████████████████████████████████████████████| 35027/35027 [9:41:02<00:00,  1.00it/s]
The training loss for Epoch 7 is 4.3599
The losses are Mask_LM: 1.4047 Matched: 0.2370 Obj: 0.3360 Attr: 0.3530 Feat: 0.2315 QA: 1.7976
Overall Accu 0.2939, gqa Accu 0.3760, visual7w Accu 0.2059, vqa Accu 0.3538,
The valid loss is 5.6127
The losses are Mask_LM: 1.3777 Matched: 0.2368 Obj: 0.8786 Attr: 0.8531 Feat: 0.2296 QA: 2.0369
Overall Accu 0.2851, gqa Accu 0.3516, visual7w Accu 0.1913, vqa Accu 0.3418,
100%|█████████████████████████████████████████████████████████████████| 35027/35027 [9:39:07<00:00,  1.01it/s]
zhmd commented 3 years ago

Thanks for the new results! It seems two fixes are discussed in this thread: one is

obj_confs *= feat_mask
attr_confs *= feat_mask

and the other is

obj_confs, attr_confs = feat_mask, feat_mask

Could you clarify which scenario the following results correspond to?

VQA: 69.60, GQA: 59.49, NLVR2: 74.18

Many thanks!