dkurzend / ClipClap-GZSL

Audio-Visual Generalized Zero-Shot Learning using Large Pre-Trained Models

Could you provide the weights of the models for extracting features from scratch? #1

Open AndreJJXu opened 4 months ago

AndreJJXu commented 4 months ago

In your section "Extracting Features from Scratch", I see that you use pre-trained models that you fine-tuned yourself. Since I want to run your whole pipeline from scratch, could you provide these weights or give more details about "Extracting Features from Scratch"? Thanks!

dkurzend commented 4 months ago

Hi @AndreJJXu, I fine-tuned the feature extraction models during my research, which is why you see `if args.finetuned_model == True:` in the code. However, for the paper I did not fine-tune them, so you can ignore this branch.

To extract e.g. the UCF features yourself, you would have to run `python clip_feature_extraction/get_clip_features_ucf.py --finetuned_model False`. You would also have to adjust the paths (for the dataset, save_path, WavCaps checkpoint, etc.) in the script. I hope that helps.
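For orientation, here is a sketch of the kind of path adjustments the script expects. The variable names below are hypothetical placeholders, not necessarily the exact names used in `get_clip_features_ucf.py`; check the script itself for the real ones:

```python
# Hypothetical placeholder names; check clip_feature_extraction/get_clip_features_ucf.py
# for the actual variables before editing.
dataset_path = '/path/to/UCF101'           # root directory of the UCF dataset
save_path = '/path/to/output/features'     # where the extracted features get written
wavcaps_cp_path = '/path/to/WavCaps/retrieval/pretrained_models/audio_encoders/HTSAT_BERT_zero_shot.pt'  # WavCaps checkpoint
```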

AndreJJXu commented 4 months ago

When I try to load the weights from the files downloaded from https://github.com/XinhaoMei/WavCaps, specifically "WavCaps/retrieval/pretrained_models/audio_encoders/HTSAT_BERT_zero_shot.pt", I always get the error `RuntimeError: Error(s) in loading state_dict for ASE: Unexpected key(s) in state_dict: "text_encoder.text_encoder.embeddings.position_ids".` I rebuilt my conda environment, but I still get this problem. It is driving me crazy; could you tell me how to get rid of it?

dkurzend commented 3 months ago

Hi, did you use the right conda environment? I created a separate conda environment for the feature extraction: `conda env create -f clipclap_feature_extraction.yml`.

Also, you have to adjust the model path in the scripts where the features are created. For UCF that would be `clip_feature_extraction/get_clip_features_ucf.py`, around line 121:

```python
else:
    cp_path = '/home/aoq234/dev/CLIP-GZSL/WavCaps/retrieval/pretrained_models/audio_encoders/HTSAT_BERT_zero_shot.pt' # <- adjust this path
    state_dict_key = 'model'

cp = torch.load(cp_path)
wavcaps_model.load_state_dict(cp[state_dict_key])
wavcaps_model.eval()
print("Model weights loaded from {}".format(cp_path))
```
carankt commented 1 month ago

@AndreJJXu I was facing a similar problem. I used https://github.com/XinhaoMei/WavCaps/blob/master/retrieval/work.yaml to create a new env, and that solved the issue for me.