FangyunWei / SLRT

Evaluation of the provided SLT checkpoints on the phoenix2014-T dataset #39

Closed: hulianyuyy closed this issue 11 months ago

hulianyuyy commented 11 months ago

Thanks for sharing the codebase. I want to evaluate the checkpoints provided in this repo.

When I evaluate the checkpoints on the phoenix2014-T dataset, loading fails with a weight-size mismatch error:

[screenshot of the error]

I tried to locate the origin of this error. Based on prune_embedding.ipynb, the final classifier should be torch.Size([2498, 1024]) and the gloss_embeddings should have size 1124. However, the loaded weights are ([2473, 1024]) and 1120. How can I fix this error?
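To see every parameter whose shape disagrees, rather than only the first error PyTorch reports, one option is to compare the model's state dict with the checkpoint key by key. The sketch below works on plain `{name: shape}` mappings so it runs without PyTorch; assuming a file holds a flat state dict, such a mapping could be built with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`. The function name and the key names in the usage example are hypothetical, not SLRT's actual parameter names; only the classifier shapes come from this issue.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, model_shape, ckpt_shape) for every shared key whose shapes differ."""
    mismatches = []
    # Only keys present in both mappings can collide on shape.
    for name in sorted(model_shapes.keys() & ckpt_shapes.keys()):
        if model_shapes[name] != ckpt_shapes[name]:
            mismatches.append((name, model_shapes[name], ckpt_shapes[name]))
    return mismatches


# Hypothetical key names; the classifier shapes are the ones reported above.
model = {"classifier.weight": (2498, 1024), "hypothetical.other.weight": (8, 8)}
ckpt = {"classifier.weight": (2473, 1024), "hypothetical.other.weight": (8, 8)}

for name, want, got in find_shape_mismatches(model, ckpt):
    print(f"{name}: model expects {want}, checkpoint has {got}")
```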

Many thanks for your efforts, and I look forward to your reply~

2000ZRL commented 11 months ago

It seems that a similar issue was resolved in https://github.com/FangyunWei/SLRT/issues/23

hulianyuyy commented 11 months ago

Yes, but it seems that this issue was only fixed for the CSL-DAILY dataset and still exists on the PHOENIX14-T dataset. It seems to be related to the pruning file.

2000ZRL commented 11 months ago

Could you please try the new gloss_embeddings.bin in https://hkustconnect-my.sharepoint.com/:u:/g/personal/rzuo_connect_ust_hk/ET3AGVeUfKtHnXdWQ2NXn3EB2cj-NKllq9cTU82SXHjTWQ?e=MP7taU

hulianyuyy commented 11 months ago

Many thanks for your advice! With this file the model loads successfully for evaluation. However, the issue was triggered when I loaded the checkpoint from phoenix-2014t_g2t, whose classifier and gloss_embeddings seem to be incompatible with the model built by this repo. Looking forward to your reply~

2000ZRL commented 11 months ago

Could you please provide your command to run the code?

hulianyuyy commented 11 months ago

I use the command 'python -m torch.distributed.launch --nproc_per_node 1 --use_env training.py --config experiments/configs/SingleStream/phoenix-2014t_s2t.yaml' and set 'load_ckpt' for TranslationNetwork to the path of the downloaded checkpoint directory.

2000ZRL commented 11 months ago

Could you please try the latest pytorch_model.bin from https://hkustconnect-my.sharepoint.com/:u:/g/personal/rzuo_connect_ust_hk/EVDV6bZgi6dCu3eGKLogTuQBYDL_WMSLHU4E6Bu0QbzigA?e=HcqoYO It should be placed in pretrained_models/mBart_de, and its weight shapes should match those of the model checkpoint.

hulianyuyy commented 11 months ago

Many thanks for your reply! This works well. May I ask which dataset this model was pretrained on: CC25 or Gloss2Text?

2000ZRL commented 11 months ago

This model is pretrained on CC25. For the overall pretraining strategy, please refer to the "progressive pretraining" proposed in https://openaccess.thecvf.com/content/CVPR2022/papers/Chen_A_Simple_Multi-Modality_Transfer_Learning_Baseline_for_Sign_Language_Translation_CVPR_2022_paper.pdf.

hulianyuyy commented 11 months ago

Now I want to use the pretrained g2t and s2g models for end-to-end training. However, I found that loading the g2t checkpoint runs into the same shape mismatch for the classifier and gloss_embeddings. How can I fix this?

2000ZRL commented 11 months ago

Do you use the new prune_embeddings and pytorch_model?

2000ZRL commented 11 months ago

Although the g2t checkpoint will overwrite the weights in pytorch_model.bin, pytorch_model.bin is still essential because it determines the model's weight shapes. So please first make sure that you have downloaded both files and placed them correctly.
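The point that the base files fix the shapes while the g2t checkpoint only supplies the values can be illustrated with a toy, PyTorch-free model of the loading order. This is a sketch under assumptions: each parameter is represented by a `(shape, source)` pair instead of a tensor, the function name `build_then_load` is hypothetical, keys absent from the model are simply skipped (unlike a strict `load_state_dict`), and it does not reproduce SLRT's actual loading code.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]
# Each "tensor" is stood in for by (shape, source_label); no real weights are involved.
Param = Tuple[Shape, str]


def build_then_load(base: Dict[str, Param], ckpt: Dict[str, Param]) -> Dict[str, Param]:
    """Build the model from `base` (fixing every shape), then overwrite it with `ckpt` values."""
    model = dict(base)  # shapes are decided here, e.g. by pytorch_model.bin / gloss_embeddings.bin
    for name, (shape, source) in ckpt.items():
        if name not in model:
            continue  # this sketch ignores keys the model does not have
        if model[name][0] != shape:
            raise ValueError(
                f"size mismatch for {name}: model {model[name][0]}, checkpoint {shape}"
            )
        model[name] = (shape, source)  # same shape: checkpoint weights replace the base ones
    return model


# Matching base file: the g2t values are loaded on top of the base-defined shape.
print(build_then_load({"classifier.weight": ((2473, 1024), "mbart")},
                      {"classifier.weight": ((2473, 1024), "g2t")}))
```

With an outdated base file (e.g. a classifier built as (2498, 1024)), the same call raises the size-mismatch error instead, which is why the base files must be updated even though their values end up overwritten.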

hulianyuyy commented 11 months ago

Many thanks for your advice. I found that, besides pytorch_model.bin and gloss_embeddings.bin, num_classes also has to be changed to 2473 instead of the 2498 automatically generated by the pruning file. By redownloading the files from this directory, I finally ran the code successfully with 'load_ckpt' set to the g2t checkpoint and the latest pytorch_model.bin and gloss_embeddings.bin.
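The num_classes fix amounts to keeping the config value equal to the first dimension of the checkpoint's classifier weight. A small pre-flight check could catch this before loading; the sketch below uses a plain `{name: shape}` mapping and assumes the classifier stores its weight as (num_classes, hidden_size), as `torch.nn.Linear` stores (out_features, in_features). The key name `classifier.weight` and the function `check_num_classes` are hypothetical; SLRT's actual config and parameter names may differ.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def check_num_classes(
    ckpt_shapes: Dict[str, Shape], classifier_key: str, num_classes: int
) -> Optional[str]:
    """Return an error message if num_classes disagrees with the checkpoint's classifier, else None."""
    rows = ckpt_shapes[classifier_key][0]  # assumed layout: (num_classes, hidden_size)
    if rows != num_classes:
        return (
            f"num_classes={num_classes}, but {classifier_key} has {rows} rows; "
            f"set num_classes to {rows}"
        )
    return None


# Shapes from this issue: the g2t checkpoint's classifier is [2473, 1024].
shapes = {"classifier.weight": (2473, 1024)}
print(check_num_classes(shapes, "classifier.weight", 2498))
```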