FangyunWei / SLRT

Evaluation of the provided SLT checkpoints on the phoenix2014-T dataset #39

Closed hulianyuyy closed 9 months ago

hulianyuyy commented 9 months ago

Thanks for sharing the codebase. I want to evaluate the checkpoints provided in this repo.

When I evaluate the checkpoints on the phoenix2014-T dataset, it reports a weight-size mismatch error, as shown below:

[screenshot: weight-size mismatch error]

I tried to locate the origin of this error. Based on prune_embedding.ipynb, the final classifier should be torch.Size([2498, 1024]) and the gloss_embeddings size should be 1124. However, the loaded weights are torch.Size([2473, 1024]) and 1120. How can I fix this error?
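
For anyone debugging the same thing, here is a minimal sketch that prints the tensor shapes stored in a checkpoint so the mismatch can be pinpointed before loading; the file name, the 'model_state_dict' wrapper, and the key substrings are assumptions rather than this repo's exact layout:

```python
import torch

# Load the checkpoint on CPU; the file name and the 'model_state_dict'
# wrapper key are assumptions -- adjust to the checkpoint's actual layout.
ckpt = torch.load("best.ckpt", map_location="cpu")
state = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

# Print every tensor whose name hints at the classifier or gloss embeddings.
for name, value in state.items():
    if torch.is_tensor(value) and ("classifier" in name or "gloss_embedding" in name):
        print(f"{name}: {tuple(value.shape)}")
```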

Many thanks for your efforts; I look forward to your reply~

2000ZRL commented 9 months ago

It seems that a similar issue was resolved in https://github.com/FangyunWei/SLRT/issues/23

hulianyuyy commented 9 months ago

Yes, but it seems that this issue was only fixed for the CSL-DAILY dataset and still exists on the PHOENIX14-T dataset. The issue seems to be related to the pruning file.

2000ZRL commented 9 months ago

Could you please try the new gloss_embeddings.bin in https://hkustconnect-my.sharepoint.com/:u:/g/personal/rzuo_connect_ust_hk/ET3AGVeUfKtHnXdWQ2NXn3EB2cj-NKllq9cTU82SXHjTWQ?e=MP7taU
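
As a quick sanity check on the downloaded file (assuming gloss_embeddings.bin is a torch-serialized tensor or mapping; the exact structure is a guess):

```python
import torch

# Inspect the downloaded file; whether it holds one matrix or a mapping
# (e.g. gloss -> vector) is a guess about its structure.
obj = torch.load("gloss_embeddings.bin", map_location="cpu")

if torch.is_tensor(obj):
    # A single embedding matrix: rows should equal the gloss vocabulary size.
    print("embedding matrix:", tuple(obj.shape))
else:
    print("number of entries:", len(obj))
```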

hulianyuyy commented 9 months ago

Many thanks for your advice! The file can now be loaded successfully for evaluation. However, this issue was triggered when I loaded the checkpoint from phoenix-2014t_g2t, whose classifier and gloss_embeddings seem to be incompatible with the model built from this repo. Looking forward to your reply~

2000ZRL commented 9 months ago

Could you please provide your command to run the code?

hulianyuyy commented 9 months ago

I use the command 'python -m torch.distributed.launch --nproc_per_node 1 --use_env training.py --config experiments/configs/SingleStream/phoenix-2014t_s2t.yaml', and set 'load_ckpt' of the TranslationNetwork to the path of the downloaded checkpoint directory.

2000ZRL commented 9 months ago

Could you please try the latest pytorch_model.bin in https://hkustconnect-my.sharepoint.com/:u:/g/personal/rzuo_connect_ust_hk/EVDV6bZgi6dCu3eGKLogTuQBYDL_WMSLHU4E6Bu0QbzigA?e=HcqoYO. It should be placed in pretrained_models/mBart_de. Its weight shapes should be the same as the model checkpoint's.
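
To confirm that the replacement file really matches, one can diff the tensor shapes of the two state dicts; a sketch under assumed paths and nesting (checkpoint keys may also carry an extra module prefix in practice):

```python
import torch

pretrained = torch.load("pretrained_models/mBart_de/pytorch_model.bin",
                        map_location="cpu")
ckpt = torch.load("path/to/model_checkpoint", map_location="cpu")
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
    ckpt = ckpt["model_state_dict"]  # unwrap if nested (an assumption)

# Report any tensor present in both files whose shapes disagree.
for name, tensor in pretrained.items():
    if name in ckpt and torch.is_tensor(tensor) \
            and tuple(tensor.shape) != tuple(ckpt[name].shape):
        print(f"{name}: pretrained {tuple(tensor.shape)} "
              f"vs checkpoint {tuple(ckpt[name].shape)}")
```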

hulianyuyy commented 9 months ago

Many thanks for your reply! This works well. Could I ask which dataset this model is pretrained on, CC25 or Gloss2Text?

2000ZRL commented 9 months ago

This model is pretrained on CC25. The overall pretraining strategy follows the "progressive pretraining" proposed in https://openaccess.thecvf.com/content/CVPR2022/papers/Chen_A_Simple_Multi-Modality_Transfer_Learning_Baseline_for_Sign_Language_Translation_CVPR_2022_paper.pdf.

hulianyuyy commented 9 months ago

Now I want to use the pretrained models for g2t and s2g to conduct end-to-end training, but I found that loading the g2t checkpoint runs into the same shape mismatch for the classifier and gloss_embeddings. How can I fix this?

2000ZRL commented 9 months ago

Are you using the new prune_embeddings and pytorch_model?

2000ZRL commented 9 months ago

Although the g2t checkpoint will overwrite the weights in pytorch_model.bin, pytorch_model.bin is still needed to determine the model weight shapes. So please first ensure that you have downloaded both files and placed them correctly.
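
If a few keys still disagree after that, a generic workaround (not this repo's official mechanism) is to drop the mismatched entries and load the rest non-strictly; the function name and paths here are illustrative:

```python
import torch
from torch import nn

def load_compatible(model: nn.Module, ckpt_path: str) -> None:
    """Load only the checkpoint entries whose name and shape match `model`."""
    model_state = model.state_dict()
    loaded = torch.load(ckpt_path, map_location="cpu")

    # Keep only entries whose name and shape both match the built model.
    filtered = {k: v for k, v in loaded.items()
                if k in model_state and torch.is_tensor(v)
                and v.shape == model_state[k].shape}
    skipped = sorted(set(loaded) - set(filtered))
    if skipped:
        print("skipped (mismatched or unknown keys):", skipped)

    model.load_state_dict(filtered, strict=False)
```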

hulianyuyy commented 9 months ago

Many thanks for your advice. I found that besides pytorch_model.bin and gloss_embeddings.bin, num_classes should also be changed to 2473; the pruning file automatically sets it to 2498. After re-downloading the files from this directory, I finally ran the code successfully with 'load_ckpt' pointing to the g2t checkpoint, together with the latest pytorch_model.bin and gloss_embeddings.bin.