microsoft / Pengi

An Audio Language model for Audio Tasks
https://arxiv.org/abs/2305.11834
MIT License

Not able to execute pengi.generate and pengi.describe #9

Closed · radhavishnu closed this 8 months ago

radhavishnu commented 9 months ago

After a successful pip install, I run:

from wrapper import PengiWrapper as Pengi

pengi = Pengi(config="base")

generated_response = pengi.generate(audio_paths='/content/Robin.mp3',
                                    text_prompts=["generate metadata"], 
                                    add_texts=[""], 
                                    max_len=30, 
                                    beam_size=3, 
                                    temperature=1.0, 
                                    stop_token=' <|endoftext|>'
                                    )

and get the following error:

RuntimeError                              Traceback (most recent call last)
/content/Pengi/wrapper.py in get_model_and_tokenizer(self, config_path)
     93         try:
---> 94             model.load_state_dict(model_state_dict)
     95         except:

RuntimeError: Error(s) in loading state_dict for PENGI:
    Unexpected key(s) in state_dict: "caption_encoder.base.embeddings.position_ids", "caption_decoder.gpt.transformer.h.0.attn.bias", "caption_decoder.gpt.transformer.h.0.attn.masked_bias", "caption_decoder.gpt.transformer.h.1.attn.bias", "caption_decoder.gpt.transformer.h.1.attn.masked_bias", "caption_decoder.gpt.transformer.h.2.attn.bias", "caption_decoder.gpt.transformer.h.2.attn.masked_bias", "caption_decoder.gpt.transformer.h.3.attn.bias", "caption_decoder.gpt.transformer.h.3.attn.masked_bias", "caption_decoder.gpt.transformer.h.4.attn.bias", "caption_decoder.gpt.transformer.h.4.attn.masked_bias", "caption_decoder.gpt.transformer.h.5.attn.bias", "caption_decoder.gpt.transformer.h.5.attn.masked_bias", "caption_decoder.gpt.transformer.h.6.attn.bias", "caption_decoder.gpt.transformer.h.6.attn.masked_bias", "caption_decoder.gpt.transformer.h.7.attn.bias", "caption_decoder.gpt.transformer.h.7.attn.masked_bias", "caption_decoder.gpt.transformer.h.8.attn.bias", "caption_decoder.gpt.transformer.h.8.attn.masked_bias", "caption_decoder.gpt.transformer.h.9.attn.bias", "caption_decoder.gpt.transformer.h.9.attn.masked_bias", "caption_decoder.gpt.transformer.h.10.attn.bias", "caption_decoder.gpt.transformer.h.10.attn.masked_bias", "caption_decoder.gpt.transformer.h.11.attn.bias", "caption_decoder.gpt.transformer.h.11.attn.masked_bias". 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
   2150 
   2151         if len(error_msgs) > 0:
-> 2152             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2154         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for PENGI:
    Missing key(s) in state_dict: "audio_encoder.base.htsat.spectrogram_extractor.stft.conv_real.weight", "audio_encoder.base.htsat.spectrogram_extractor.stft.conv_imag.weight", "audio_encoder.base.htsat.logmel_extractor.melW", "audio_encoder.base.htsat.bn0.weight", "audio_encoder.base.htsat.bn0.bias", "audio_encoder.base.htsat.bn0.running_mean", "audio_encoder.base.htsat.bn0.running_var", "audio_encoder.base.htsat.patch_embed.proj.weight", "audio_encoder.base.htsat.patch_embed.proj.bias", "audio_encoder.base.htsat.patch_embed.norm.weight", "audio_encoder.base.htsat.patch_embed.norm.bias", "audio_encoder.base.htsat.layers.0.blocks.0.norm1.weight", "audio_encoder.base.htsat.layers.0.blocks.0.norm1.bias", "audio_encoder.base.htsat.layers.0.blocks.0.attn.relative_position_bias_table", "audio_encoder.base.htsat.layers.0.blocks.0.attn.relative_position_index", "audio_encoder.base.htsat.layers.0.blocks.0.attn.qkv.weight", "audio_encoder.base.htsat.layers.0.blocks.0.attn.qkv.bias", "audio_encoder.base.htsat.layers.0.blocks.0.attn.proj.weight", "audio_encoder.base.htsat.layers.0.blocks.0.attn.proj.bias", "audio_encoder.base.htsat.layers.0.blocks.0.norm2.weight", "audio_encoder.base.htsat.layers.0.blocks.0.norm2.bias", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc1.weight", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc1.bias", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc2.weight", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc2.bias", "audio_encoder.base.htsat.laye...
    Unexpected key(s) in state_dict: "ncoder.base.htsat.spectrogram_extractor.stft.conv_real.weight", "ncoder.base.htsat.spectrogram_extractor.stft.conv_imag.weight", "ncoder.base.htsat.logmel_extractor.melW", "ncoder.base.htsat.bn0.weight", "ncoder.base.htsat.bn0.bias", "ncoder.base.htsat.bn0.running_mean", "ncoder.base.htsat.bn0.running_var", "ncoder.base.htsat.bn0.num_batches_tracked", "ncoder.base.htsat.patch_embed.proj.weight", "ncoder.base.htsat.patch_embed.proj.bias", "ncoder.base.htsat.patch_embed.norm.weight", "ncoder.base.htsat.patch_embed.norm.bias", "ncoder.base.htsat.layers.0.blocks.0.norm1.weight", "ncoder.base.htsat.layers.0.blocks.0.norm1.bias", "ncoder.base.htsat.layers.0.blocks.0.attn.relative_position_bias_table", "ncoder.base.htsat.layers.0.blocks.0.attn.relative_position_index", "ncoder.base.htsat.layers.0.blocks.0.attn.qkv.weight", "ncoder.base.htsat.layers.0.blocks.0.attn.qkv.bias", "ncoder.base.htsat.layers.0.blocks.0.attn.proj.weight", "ncoder.base.htsat.layers.0.blocks.0.attn.proj.bias", "ncoder.base.htsat.layers.0.blocks.0.norm2.weight", "ncoder.base.htsat.layers.0.blocks.0.norm2.bias", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc1.weight", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc1.bias", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc2.weight", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc2.bias", "ncoder.base.htsat.layers.0.blocks.1.attn_mask", "ncoder.base.htsat.layers.0.blocks.1.norm1.weight", "ncoder.base.htsat.layers.0.blocks.1.norm1.bias", "ncode...
soham97 commented 9 months ago

Have you downloaded the checkpoints and moved them to the configs folder?
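A quick sanity check (a minimal sketch, assuming the default layout where both checkpoint files sit in a configs folder next to wrapper.py) is to confirm the files exist and have a plausible size:

import os

# Minimal sanity check: confirm both checkpoint files are present in the
# configs folder and report their sizes.
for name in ("base.pth", "base_no_text_enc.pth"):
    path = os.path.join("configs", name)
    if os.path.isfile(path):
        print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")
    else:
        print(f"{path}: MISSING")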

radhavishnu commented 9 months ago

Yes, I downloaded them and saved them as base.pth and base_no_text_enc.pth in the configs folder.

soham97 commented 9 months ago

Two issues:

  1. There might be a file name mismatch on your end. I can reproduce your error if I swap the weight file names, i.e. rename base_no_text_enc.pth to base.pth. Perhaps the files were switched when you copied them over?
  2. Unrelated, but the audio file has to be passed as a list: audio_paths=['/content/Robin.mp3']. See the corrected call below.
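Putting both fixes together, a corrected version of the call from the report would look roughly like this (a sketch reusing the original arguments; it assumes base.pth really is the full-model checkpoint and was not renamed from base_no_text_enc.pth):

from wrapper import PengiWrapper as Pengi

# base.pth must be the full-model checkpoint; do not rename
# base_no_text_enc.pth to base.pth.
pengi = Pengi(config="base")

generated_response = pengi.generate(
    audio_paths=["/content/Robin.mp3"],  # must be a list, even for one file
    text_prompts=["generate metadata"],
    add_texts=[""],
    max_len=30,
    beam_size=3,
    temperature=1.0,
    stop_token=" <|endoftext|>",
)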