ShampooWang / SpeechCLIP_plus

SpeechCLIP+: Self-supervised multi-task representation learning for speech via CLIP and speech-image data. Accepted to ICASSP 2024, Self-supervision in Audio, Speech, and Beyond (SASB) workshop.

size mismatch for loading model #3

Open marcos452 opened 4 months ago

marcos452 commented 4 months ago

Thanks for your great work.

I am trying to load the large model ./icassp_sasb_ckpts/SpeechCLIP+/large/flickr/cascaded/model.ckpt using example.py. (Loading the base model with the same script works without any error.) The following error occurs:

Using cache found in /home/marco/.cache/torch/hub/s3prl_cache/4a54d64fa42b41e39db994c958d8107d5785a100f38c6eba680b6a3cc79babb3 for https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt
WARNING:avssl.module.clip_official:Reduce text embedding to size of 8112
Traceback (most recent call last):
  File "/home/marco/Documents/human-gesture-generation/Bechmark/SpeechCLIP_plus/example.py", line 10, in <module>
    model = avssl.model.KWClip_GeneralTransformer.load_from_checkpoint(model_fp)
  File "/home/marco/.conda/envs/emagepy38/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 156, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/marco/.conda/envs/emagepy38/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 204, in _load_model_state
    keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/marco/.conda/envs/emagepy38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for KWClip_GeneralTransformer:
    size mismatch for criterion.eye_mat: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for criterion.neg_eye_mat: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for criterion.eye_mat_fl: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).

Any insights or suggestions you can provide would be greatly appreciated.

Thank you!

ShampooWang commented 4 months ago

Hi,

In avssl/module/losses.py, on line 126, there is a variable called MAX_EYE that must be changed manually to match the size of the model you load: MAX_EYE=256 for the base models and MAX_EYE=1024 for the large models. Thanks!
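
If it helps, here is a small sketch for checking which MAX_EYE a checkpoint expects before editing losses.py. It is not part of the repository: the helper name infer_max_eye is made up for illustration, and it only assumes what the traceback above shows, namely that the checkpoint has a "state_dict" entry containing a square criterion.eye_mat buffer.

```python
from typing import Any, Mapping

# Buffer name taken from the size-mismatch message in the traceback above.
EYE_KEY = "criterion.eye_mat"


def infer_max_eye(state_dict: Mapping[str, Any], key: str = EYE_KEY) -> int:
    """Return the side length of the square buffer stored under `key`.

    Hypothetical helper: works with any mapping whose values expose a
    `.shape` attribute, such as the tensors in a loaded state_dict.
    """
    rows, cols = tuple(state_dict[key].shape)
    if rows != cols:
        raise ValueError(f"{key} is not square: {rows}x{cols}")
    return int(rows)


# Usage sketch (requires PyTorch; the path is the one from the question):
#   import torch
#   ckpt = torch.load(
#       "./icassp_sasb_ckpts/SpeechCLIP+/large/flickr/cascaded/model.ckpt",
#       map_location="cpu",
#   )
#   print(infer_max_eye(ckpt["state_dict"]))
```

The printed value is what MAX_EYE needs to be for that checkpoint. On recent PyTorch releases, torch.load may also need weights_only=False to read a Lightning checkpoint; only do that for files you trust.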