zhoubenjia / GFSLT-VLP


Loading mBART weight problem #11

Open · EyjafjalIa opened this issue 7 months ago

EyjafjalIa commented 7 months ago

Hello! Thank you for open-sourcing the code! When I follow your fine-tuning instructions, I get some missing keys for the visual encoder in the log file below. It seems that I am failing to load mBART as the decoder, even though I downloaded the weights from the Baidu Netdisk link you provided and placed them in pretrained_models.

***********************************
Load parameters for Visual Encoder...
***********************************
Missing keys: 
mbart.final_logits_bias
mbart.model.shared.weight
mbart.model.decoder.layers.0.self_attn.k_proj.weight
mbart.model.decoder.layers.0.self_attn.k_proj.bias
mbart.model.decoder.layers.0.self_attn.v_proj.weight
mbart.model.decoder.layers.0.self_attn.v_proj.bias
mbart.model.decoder.layers.0.self_attn.q_proj.weight
mbart.model.decoder.layers.0.self_attn.q_proj.bias
mbart.model.decoder.layers.0.self_attn.out_proj.weight
mbart.model.decoder.layers.0.self_attn.out_proj.bias
mbart.model.decoder.layers.0.self_attn_layer_norm.weight
mbart.model.decoder.layers.0.self_attn_layer_norm.bias
mbart.model.decoder.layers.0.encoder_attn.k_proj.weight
mbart.model.decoder.layers.0.encoder_attn.k_proj.bias
mbart.model.decoder.layers.0.encoder_attn.v_proj.weight
mbart.model.decoder.layers.0.encoder_attn.v_proj.bias
mbart.model.decoder.layers.0.encoder_attn.q_proj.weight
mbart.model.decoder.layers.0.encoder_attn.q_proj.bias
mbart.model.decoder.layers.0.encoder_attn.out_proj.weight
mbart.model.decoder.layers.0.encoder_attn.out_proj.bias
mbart.model.decoder.layers.0.encoder_attn_layer_norm.weight
mbart.model.decoder.layers.0.encoder_attn_layer_norm.bias
mbart.model.decoder.layers.0.fc1.weight
mbart.model.decoder.layers.0.fc1.bias
mbart.model.decoder.layers.0.fc2.weight
mbart.model.decoder.layers.0.fc2.bias
mbart.model.decoder.layers.0.final_layer_norm.weight
mbart.model.decoder.layers.0.final_layer_norm.bias
mbart.model.decoder.layers.1.self_attn.k_proj.weight
mbart.model.decoder.layers.1.self_attn.k_proj.bias
mbart.model.decoder.layers.1.self_attn.v_proj.weight
mbart.model.decoder.layers.1.self_attn.v_proj.bias
mbart.model.decoder.layers.1.self_attn.q_proj.weight
mbart.model.decoder.layers.1.self_attn.q_proj.bias
mbart.model.decoder.layers.1.self_attn.out_proj.weight
mbart.model.decoder.layers.1.self_attn.out_proj.bias
mbart.model.decoder.layers.1.self_attn_layer_norm.weight
mbart.model.decoder.layers.1.self_attn_layer_norm.bias
mbart.model.decoder.layers.1.encoder_attn.k_proj.weight
mbart.model.decoder.layers.1.encoder_attn.k_proj.bias
mbart.model.decoder.layers.1.encoder_attn.v_proj.weight
mbart.model.decoder.layers.1.encoder_attn.v_proj.bias
mbart.model.decoder.layers.1.encoder_attn.q_proj.weight
mbart.model.decoder.layers.1.encoder_attn.q_proj.bias
mbart.model.decoder.layers.1.encoder_attn.out_proj.weight
mbart.model.decoder.layers.1.encoder_attn.out_proj.bias
mbart.model.decoder.layers.1.encoder_attn_layer_norm.weight
mbart.model.decoder.layers.1.encoder_attn_layer_norm.bias
mbart.model.decoder.layers.1.fc1.weight
mbart.model.decoder.layers.1.fc1.bias
mbart.model.decoder.layers.1.fc2.weight
mbart.model.decoder.layers.1.fc2.bias
mbart.model.decoder.layers.1.final_layer_norm.weight
mbart.model.decoder.layers.1.final_layer_norm.bias
mbart.model.decoder.layers.2.self_attn.k_proj.weight
mbart.model.decoder.layers.2.self_attn.k_proj.bias
mbart.model.decoder.layers.2.self_attn.v_proj.weight
mbart.model.decoder.layers.2.self_attn.v_proj.bias
mbart.model.decoder.layers.2.self_attn.q_proj.weight
mbart.model.decoder.layers.2.self_attn.q_proj.bias
mbart.model.decoder.layers.2.self_attn.out_proj.weight
mbart.model.decoder.layers.2.self_attn.out_proj.bias
mbart.model.decoder.layers.2.self_attn_layer_norm.weight
mbart.model.decoder.layers.2.self_attn_layer_norm.bias
mbart.model.decoder.layers.2.encoder_attn.k_proj.weight
mbart.model.decoder.layers.2.encoder_attn.k_proj.bias
mbart.model.decoder.layers.2.encoder_attn.v_proj.weight
mbart.model.decoder.layers.2.encoder_attn.v_proj.bias
mbart.model.decoder.layers.2.encoder_attn.q_proj.weight
mbart.model.decoder.layers.2.encoder_attn.q_proj.bias
mbart.model.decoder.layers.2.encoder_attn.out_proj.weight
mbart.model.decoder.layers.2.encoder_attn.out_proj.bias
mbart.model.decoder.layers.2.encoder_attn_layer_norm.weight
mbart.model.decoder.layers.2.encoder_attn_layer_norm.bias
mbart.model.decoder.layers.2.fc1.weight
mbart.model.decoder.layers.2.fc1.bias
mbart.model.decoder.layers.2.fc2.weight
mbart.model.decoder.layers.2.fc2.bias
mbart.model.decoder.layers.2.final_layer_norm.weight
mbart.model.decoder.layers.2.final_layer_norm.bias
mbart.model.decoder.layernorm_embedding.weight
mbart.model.decoder.layernorm_embedding.bias
mbart.model.decoder.layer_norm.weight
mbart.model.decoder.layer_norm.bias
mbart.lm_head.weight
sign_emb.src_emb.weight
sign_emb.src_emb.bias
sign_emb.bn_ac.0.weight
sign_emb.bn_ac.0.bias
sign_emb.bn_ac.0.running_mean
sign_emb.bn_ac.0.running_var
Unexpected keys: 

number of params: 115.412544M
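
For context, a "Missing keys / Unexpected keys" report like the one above is what PyTorch prints when a state dict is loaded with strict=False: nothing fails, but every parameter the checkpoint does not cover gets listed. A minimal sketch with a toy model (purely illustrative, not the repo's actual classes):

import torch
import torch.nn as nn

# Toy stand-in for the full SLT model; only the loading mechanics matter here.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# A partial checkpoint that only covers the first layer.
partial_state = {'0.weight': torch.zeros(4, 4), '0.bias': torch.zeros(4)}

# strict=False lets the load succeed, but reports everything that was not found
# in the checkpoint -- the same kind of list shown in the log above.
result = model.load_state_dict(partial_state, strict=False)
print('Missing keys:', result.missing_keys)        # ['1.weight', '1.bias']
print('Unexpected keys:', result.unexpected_keys)  # []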
zhoubenjia commented 6 months ago

[Quoted: the original report and missing-key log from the opening comment.]

Hi, thanks for your attention. In stage 2 (train_slt.py), it is imperative to load the pre-trained model obtained from stage 1 (train_vlp.py). It seems that you did not run the pre-training stage.
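
For anyone hitting the same error: the hand-off between the two stages is just a checkpoint on disk. A rough sketch of what stage 2 is expected to do, assuming the stage-1 checkpoint stores its weights under a 'model' key (an assumption about the layout, not the exact code in train_slt.py):

import torch
import torch.nn as nn

# Hypothetical stand-in for the stage-2 translation model built in train_slt.py.
slt_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

# Stage 1 (train_vlp.py) writes checkpoints under out/vlp/; stage 2 should start
# from one of them instead of from random initialization.
ckpt = torch.load('out/vlp/best_checkpoint.pth', map_location='cpu')
state_dict = ckpt.get('model', ckpt)  # assumed checkpoint layout

# Non-strict load: whatever stage 1 trained is restored; anything it never
# touched ends up in missing_keys, as in the log above.
ret = slt_model.load_state_dict(state_dict, strict=False)
print('missing keys:', len(ret.missing_keys))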

EyjafjalIa commented 6 months ago

I have run train_vlp.py and also train_vlp_v2.py. Below is my file structure:

out/
├── Gloss-Free/
│   ├── best_checkpoint.pth
│   ├── checkpoint.pth
│   ├── log.txt
│   ├── tmp_pres.txt
│   └── tmp_refs.txt
├── vlp/
│   ├── best_checkpoint.pth
│   ├── checkpoint.pth
│   └── log.txt
└── vlp_v2/
    ├── best_checkpoint.pth
    ├── checkpoint.pth
    └── log.txt

The checkpoint.pth in the vlp folder is 1.6 GB. It seems that the missing-keys report comes from the checkpoint load at line 231 of train_slt.py. After loading checkpoint.pth, I printed all of the dict keys and did not find any key containing "text_encoder".

But when I change the weight file to vlp_v2/checkpoint.pth, the weights load successfully.

I read train_vlp.py and train_vlp_v2.py, and I found a Text_Decoder at line 235 of train_vlp_v2.py but no Text_Decoder in train_vlp.py. So I think the missing keys when fine-tuning from the vlp checkpoint are normal. Am I right? I'm looking forward to your reply!
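
A quick way to do the key inspection described in this comment is to group a checkpoint's keys by their top-level prefix; that shows at a glance whether 'mbart', 'text_encoder', or only the visual-side modules were saved. A small sketch (the 'model' key is again an assumed detail of how the checkpoints under out/ are stored):

from collections import Counter
import torch

ckpt = torch.load('out/vlp/checkpoint.pth', map_location='cpu')
state_dict = ckpt.get('model', ckpt)  # assumed checkpoint layout

# Count tensors per top-level module, e.g. 'mbart', 'text_encoder', 'sign_emb'.
prefixes = Counter(key.split('.')[0] for key in state_dict)
for name, count in prefixes.most_common():
    print(f'{name}: {count} tensors')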

zhoubenjia commented 6 months ago

Yes, you are right!

EyjafjalIa commented 6 months ago

I found another problem. I loaded vlp_v2/checkpoint.pth for fine-tuning, and after 100 epochs the BLEU-4 score is only 5.03. I have also fine-tuned from vlp/checkpoint.pth, and its final BLEU-4 score is close to the result in your paper. Why is the result poor even though the mBART parameters are loaded? This is my fine-tuning log of vlp_v2 at epoch 100:

Averaged stats: lr: 0.005079  loss: 2.8461 (2.7917)  lr_mbart: 0.0051 (0.0051)
Test:  [  0/104]  eta: 0:02:49  loss: 3.0136 (3.0136)  time: 1.6309  data: 1.4326  max mem: 8704
Test:  [ 10/104]  eta: 0:00:36  loss: 3.7486 (3.6883)  time: 0.3870  data: 0.1319  max mem: 8704
Test:  [ 20/104]  eta: 0:00:29  loss: 4.0569 (4.0858)  time: 0.2855  data: 0.0010  max mem: 8704
Test:  [ 30/104]  eta: 0:00:24  loss: 3.9377 (4.0084)  time: 0.2949  data: 0.0002  max mem: 8704
Test:  [ 40/104]  eta: 0:00:20  loss: 3.5864 (3.9765)  time: 0.2807  data: 0.0002  max mem: 8704
Test:  [ 50/104]  eta: 0:00:16  loss: 4.0049 (4.0007)  time: 0.2758  data: 0.0002  max mem: 8704
Test:  [ 60/104]  eta: 0:00:13  loss: 4.0909 (3.9866)  time: 0.2749  data: 0.0002  max mem: 8704
Test:  [ 70/104]  eta: 0:00:10  loss: 3.6714 (3.9354)  time: 0.2720  data: 0.0002  max mem: 8704
Test:  [ 80/104]  eta: 0:00:07  loss: 3.8170 (3.9715)  time: 0.2735  data: 0.0002  max mem: 8704
Test:  [ 90/104]  eta: 0:00:04  loss: 3.8170 (3.9626)  time: 0.2798  data: 0.0002  max mem: 8704
Test:  [100/104]  eta: 0:00:01  loss: 3.4312 (3.8635)  time: 0.2454  data: 0.0001  max mem: 8704
Test:  [103/104]  eta: 0:00:00  loss: 3.4312 (3.8457)  time: 0.2295  data: 0.0001  max mem: 8704
Test: Total time: 0:00:29 (0.2853 s / it)
* BELU-4 3.872 loss 3.897
BELU-4 of the network on the 104 dev videos: 3.87
Max BELU-4: 5.03%

This is my fine-tuning log of vlp at epoch 100:

Averaged stats: lr: 0.005079  loss: 2.4409 (2.3801)  lr_mbart: 0.0051 (0.0051)
Test:  [  0/104]  eta: 0:02:49  loss: 2.8193 (2.8193)  time: 1.6304  data: 1.4442  max mem: 8704
Test:  [ 10/104]  eta: 0:00:36  loss: 3.4606 (3.4938)  time: 0.3874  data: 0.1315  max mem: 8704
Test:  [ 20/104]  eta: 0:00:28  loss: 3.6809 (3.7674)  time: 0.2727  data: 0.0002  max mem: 8704
Test:  [ 30/104]  eta: 0:00:22  loss: 3.5214 (3.7274)  time: 0.2670  data: 0.0002  max mem: 8704
Test:  [ 40/104]  eta: 0:00:19  loss: 3.5793 (3.7159)  time: 0.2625  data: 0.0002  max mem: 8704
Test:  [ 50/104]  eta: 0:00:15  loss: 3.7375 (3.7173)  time: 0.2601  data: 0.0002  max mem: 8704
Test:  [ 60/104]  eta: 0:00:12  loss: 3.6379 (3.7104)  time: 0.2700  data: 0.0002  max mem: 8704
Test:  [ 70/104]  eta: 0:00:09  loss: 3.3802 (3.6466)  time: 0.2726  data: 0.0002  max mem: 8704
Test:  [ 80/104]  eta: 0:00:06  loss: 3.6600 (3.6955)  time: 0.2654  data: 0.0002  max mem: 8704
Test:  [ 90/104]  eta: 0:00:03  loss: 3.8700 (3.6943)  time: 0.2728  data: 0.0002  max mem: 8704
Test:  [100/104]  eta: 0:00:01  loss: 2.9645 (3.5892)  time: 0.2398  data: 0.0001  max mem: 8704
Test:  [103/104]  eta: 0:00:00  loss: 2.8591 (3.5618)  time: 0.2295  data: 0.0001  max mem: 8704
Test: Total time: 0:00:28 (0.2755 s / it)
* BELU-4 17.531 loss 3.599
BELU-4 of the network on the 104 dev videos: 17.53
Max BELU-4: 17.99%