atosystem / SpeechCLIP

SpeechCLIP: Integrating Speech with Pre-Trained Vision and Language Model, Accepted to IEEE SLT 2022
https://atosystem.github.io/blogs/speechclip
BSD 3-Clause "New" or "Revised" License

Can derive embeddings with base but not large #7

Closed lokesh12345678910 closed 5 months ago

lokesh12345678910 commented 5 months ago

In `example.py`, I get the following error when I substitute `model_fp = "slt_ckpts/SpeechCLIP/base/flickr/parallel/epoch_131-step_15443-val_recall_mean_1_36.0100.ckpt"` with `model_fp = "slt_ckpts/SpeechCLIP/large/flickr/parallel/epoch_56-step_6668-val_recall_mean_10_89.0000.ckpt"`.

I also get the same error for `model_fp = "slt_ckpts/SpeechCLIP/large/coco/parallel/epoch_14-step_33224-val_recall_mean_10_84.0128.ckpt"`.
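For reference, the call that triggers the error looks roughly like this, reconstructed from the traceback below (the import and device setup are my approximation, not the exact contents of `example.py`):

```python
# Reconstructed from the traceback; import and device setup are approximate.
import torch
import avssl.model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Swapping in the large checkpoint path is what triggers the size mismatch.
largeFlickrParallelModelPath = (
    "slt_ckpts/SpeechCLIP/large/flickr/parallel/"
    "epoch_56-step_6668-val_recall_mean_10_89.0000.ckpt"
)
largeFlickrParallelModel = avssl.model.KWClip_GeneralTransformer.load_from_checkpoint(
    largeFlickrParallelModelPath
).to(device)
```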

```
Traceback (most recent call last):
  File "/work/07469/lpugalen/ls6/SpeechCLIP/example.py", line 37, in <module>
    largeFlickrParallelModel = avssl.model.KWClip_GeneralTransformer.load_from_checkpoint(largeFlickrParallelModelPath).to(device)
  File "/work/07469/lpugalen/ls6/SpeechCLIP/pytorch_lightning/core/saving.py", line 156, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/work/07469/lpugalen/ls6/SpeechCLIP/pytorch_lightning/core/saving.py", line 204, in _load_model_state
    keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/work/07469/lpugalen/ls6/SpeechCLIP/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for KWClip_GeneralTransformer:
	size mismatch for criterion.eye_mat: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for criterion.neg_eye_mat: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for criterion.eye_mat_fl: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([256, 256]).
```

atosystem commented 5 months ago

Hi @lokesh12345678910, I'm sorry for the bug. You may have to make a small change here: https://github.com/atosystem/SpeechCLIP/blob/839cd2bb38ab0485bb0c1209dd84e97e3f960a36/avssl/module/losses.py#L126

Change it to `MAX_EYE = 1024` when this error occurs.
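Concretely, the fix is a one-line edit in `avssl/module/losses.py` (a sketch; the current value of 256 is inferred from the 256×256 buffer shapes in the error message, and the surrounding comments are my own):

```python
# avssl/module/losses.py, line 126 (per the link above).
# The loss registers identity-matrix buffers (criterion.eye_mat,
# criterion.neg_eye_mat, criterion.eye_mat_fl) of shape MAX_EYE x MAX_EYE.
# The large checkpoints were saved with 1024 x 1024 buffers, so the
# constant must match before load_from_checkpoint can restore them.
MAX_EYE = 1024  # was 256 (inferred from the 256x256 shapes in the error)
```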

lokesh12345678910 commented 5 months ago

Thank you, this worked!