LAION-AI / CLAP

Contrastive Language-Audio Pretraining
https://arxiv.org/abs/2211.06687
Creative Commons Zero v1.0 Universal

can't load the pretrained checkpoints #137

Closed: naarkhoo closed this issue 10 months ago

naarkhoo commented 10 months ago

When I try to fine-tune CLAP on the GTZAN dataset following the Hugging Face (HF) audio-classification pipeline, I get:

"ValueError: Unrecognized configuration class <class 'transformers.models.clap.configuration_clap.ClapConfig'> for this kind of AutoModel: AutoModelForAudioClassification.
Model type should be one of ASTConfig, Data2VecAudioConfig, HubertConfig, SEWConfig, SEWDConfig, UniSpeechConfig, UniSpeechSatConfig, Wav2Vec2Config, Wav2Vec2ConformerConfig, WavLMConfig, WhisperConfig."

I am using the laion/larger_clap_music_and_speech checkpoint with this code:


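from transformers import AutoModelForAudioClassification

model_id = "laion/larger_clap_music_and_speech"
# id2label / label2id map GTZAN genre indices to names (built from the dataset, not shown)
num_labels = len(id2label)

# This call raises the ValueError above: ClapConfig is not registered for AutoModelForAudioClassification
model = AutoModelForAudioClassification.from_pretrained(
    model_id,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)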
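The ValueError says that ClapConfig is not among the configurations AutoModelForAudioClassification can build, so this route cannot attach a classification head to a CLAP checkpoint. CLAP is trained to align audio and text embeddings rather than to emit class logits, so one alternative (zero-shot classification, not the fine-tuning asked about here) is to embed the clip once, embed a text prompt per genre, and pick the label whose text embedding is most similar. The sketch below shows only that scoring step in plain Python. The toy vectors and helper names are hypothetical; in practice the embeddings would come from something like ClapModel.get_audio_features and ClapModel.get_text_features in transformers.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_classify(audio_emb, label_embs):
    """Compare one audio embedding with one text embedding per label.

    Returns (best_label, {label: probability}).
    """
    labels = list(label_embs)
    sims = [cosine(audio_emb, label_embs[label]) for label in labels]
    probs = softmax(sims)
    best = labels[probs.index(max(probs))]
    return best, dict(zip(labels, probs))


if __name__ == "__main__":
    # Hypothetical toy embeddings; real CLAP embeddings have hundreds of dimensions.
    label_embs = {
        "blues": [1.0, 0.0, 0.0],
        "jazz": [0.0, 1.0, 0.0],
        "metal": [0.0, 0.0, 1.0],
    }
    best, probs = zero_shot_classify([0.1, 0.9, 0.2], label_embs)
    print(best)  # jazz
```

Real CLAP models multiply the similarities by a learned logit scale before the softmax. This sketch leaves it out, which changes the probabilities but not which label ranks first. If zero-shot accuracy on GTZAN is not enough, a small linear layer trained on the frozen audio embeddings could stand in for the missing classification head.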

When I instead try to load the pretrained checkpoint directly with the laion_clap package:

import laion_clap

# Build the HTSAT-base audio encoder without feature fusion, then load the downloaded checkpoint
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base')
model.load_ckpt('/Users/alka/Downloads/music_audioset_epoch_15_esc_90.14.pt')

I get:

RuntimeError                              Traceback (most recent call last)
Cell In[27], line 3
      1 import laion_clap
      2 model = laion_clap.CLAP_Module(enable_fusion=False, amodel= 'HTSAT-base')
----> 3 model.load_ckpt('/Users/alka/Downloads/music_audioset_epoch_15_esc_90.14.pt')

File ~/Library/Caches/pypoetry/virtualenvs/musicbench-PHwhvKNx-py3.10/lib/python3.10/site-packages/laion_clap/hook.py:114, in CLAP_Module.load_ckpt(self, ckpt, model_id)
    112 print('Load Checkpoint...')
    113 ckpt = load_state_dict(ckpt, skip_params=True)
--> 114 self.model.load_state_dict(ckpt)
    115 param_names = [n for n, p in self.model.named_parameters()]
    116 for n in param_names:

File ~/Library/Caches/pypoetry/virtualenvs/musicbench-PHwhvKNx-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:2152, in Module.load_state_dict(self, state_dict, strict, assign)
   2147         error_msgs.insert(
   2148             0, 'Missing key(s) in state_dict: {}. '.format(
   2149                 ', '.join(f'"{k}"' for k in missing_keys)))
   2151 if len(error_msgs) > 0:
-> 2152     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2154 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for CLAP:
    Unexpected key(s) in state_dict: "text_branch.embeddings.position_ids".
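The RuntimeError means the checkpoint contains an entry, text_branch.embeddings.position_ids, that the freshly built model does not declare, and strict loading rejects it. Given that downgrading transformers made the problem go away in this thread, a likely cause is that newer transformers releases no longer keep position_ids in the text encoder's state dict, while the checkpoint was saved with an older release that did. A minimal sketch of one workaround, assuming you load the state dict yourself instead of through load_ckpt: drop every key the model does not declare, and keep the list of dropped names so you can inspect them. The helper name drop_unexpected_keys is hypothetical, and plain values stand in for tensors.

```python
def drop_unexpected_keys(checkpoint_state, model_keys):
    """Split a checkpoint state dict into the entries the model declares
    and the names of the entries it does not.

    checkpoint_state: mapping of parameter/buffer name -> value (tensors in practice)
    model_keys: iterable of names the target model expects,
                e.g. model.state_dict().keys()
    Returns (kept, dropped): a new dict of accepted entries and a sorted list
    of dropped key names. The input mapping is not modified.
    """
    expected = set(model_keys)
    kept = {name: value for name, value in checkpoint_state.items() if name in expected}
    dropped = sorted(name for name in checkpoint_state if name not in expected)
    return kept, dropped


if __name__ == "__main__":
    # Toy stand-in for a real checkpoint; "audio_branch.weight" is a made-up name.
    checkpoint = {
        "text_branch.embeddings.position_ids": [0, 1, 2],
        "audio_branch.weight": 1.0,
    }
    kept, dropped = drop_unexpected_keys(checkpoint, ["audio_branch.weight"])
    print(dropped)  # ['text_branch.embeddings.position_ids']

    # With PyTorch and the CLAP_Module from the report (not run here), the
    # filtered dict would go to the wrapped model, e.g.:
    #   kept, dropped = drop_unexpected_keys(state_dict, model.model.state_dict().keys())
    #   model.model.load_state_dict(kept)  # strict loading still reports missing keys
```

Passing strict=False to load_state_dict would also silence the error, but it ignores missing keys as well; filtering only the unexpected ones keeps that check in place. Assuming the default RoBERTa-style text branch, position_ids holds position indices rather than learned weights, so dropping it should not change the model, but that is an assumption worth verifying against the model's outputs.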
naarkhoo commented 10 months ago

I had to downgrade the transformers library to version 4.30.2.
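The pin itself is done with the package manager, for example pip install "transformers==4.30.2". A minimal sketch of a guard that flags a newer installed version before the checkpoint is loaded; treating every release above 4.30.2 as affected is an assumption based only on this thread, and the helper names are hypothetical. The parser only reads the leading major.minor.patch numbers and ignores suffixes such as .dev0; packaging.version would be the more robust choice if that package is available.

```python
from importlib import metadata

# Version reported working in this thread; newer ones are assumed affected.
PINNED = (4, 30, 2)


def parse_version(text):
    """Parse the leading numeric major.minor.patch of a version string."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def newer_than_pinned(version_text, pinned=PINNED):
    """True if version_text is a higher release than the pinned version."""
    return parse_version(version_text) > pinned


def check_installed_transformers():
    """Print a warning if the installed transformers is newer than the pin.

    Returns the installed version string, or None if transformers is absent.
    """
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None
    if newer_than_pinned(installed):
        print(f"transformers {installed} is newer than 4.30.2; "
              "loading older CLAP checkpoints may fail with unexpected keys.")
    return installed


if __name__ == "__main__":
    print(check_installed_transformers())
```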