microsoft / Pengi

An Audio Language model for Audio Tasks
https://arxiv.org/abs/2305.11834
MIT License

Unable to load pre-trained weights #11

Closed awkrail closed 7 months ago

awkrail commented 8 months ago

Hi, @soham97

Related to #9 (but it did not solve my error).

Thank you for sharing your great work! I downloaded the weights from Zenodo and put them in configs/. Then I ran the following code.

from wrapper import PengiWrapper as Pengi
pengi = Pengi(config='base')

This raised the error below about unexpected keys in the state_dict. I checked the md5 checksum and the values match. How can I solve this error? Thanks.
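(For reference, I verified the checksum with something like the snippet below; the weights path is just an example of where I put the downloaded file.)

import hashlib

# Hypothetical path; adjust to wherever the Zenodo weights were saved.
with open("configs/base.pth", "rb") as f:
    print(hashlib.md5(f.read()).hexdigest())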

In [1]: from wrapper import PengiWrapper as Pengi

In [2]: pengi = Pengi(config='base')
/Users/tainishi/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:3527.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.10k/4.10k [00:00<00:00, 2.56MB/s]
pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 599M/599M [00:58<00:00, 10.2MB/s]
config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 665/665 [00:00<00:00, 3.50MB/s]
model.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 548M/548M [00:50<00:00, 10.9MB/s]
generation_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 124/124 [00:00<00:00, 271kB/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/Pengi/wrapper.py:94, in PengiWrapper.get_model_and_tokenizer(self, config_path)
     93 try:
---> 94     model.load_state_dict(model_state_dict)
     95 except:

File ~/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/nn/modules/module.py:2152, in Module.load_state_dict(self, state_dict, strict, assign)
   2151 if len(error_msgs) > 0:
-> 2152     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2154 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for PENGI:
    Unexpected key(s) in state_dict: "caption_encoder.base.embeddings.position_ids", "caption_decoder.gpt.transformer.h.0.attn.bias", "caption_decoder.gpt.transformer.h.0.attn.masked_bias", "caption_decoder.gpt.transformer.h.1.attn.bias", "caption_decoder.gpt.transformer.h.1.attn.masked_bias", "caption_decoder.gpt.transformer.h.2.attn.bias", "caption_decoder.gpt.transformer.h.2.attn.masked_bias", "caption_decoder.gpt.transformer.h.3.attn.bias", "caption_decoder.gpt.transformer.h.3.attn.masked_bias", "caption_decoder.gpt.transformer.h.4.attn.bias", "caption_decoder.gpt.transformer.h.4.attn.masked_bias", "caption_decoder.gpt.transformer.h.5.attn.bias", "caption_decoder.gpt.transformer.h.5.attn.masked_bias", "caption_decoder.gpt.transformer.h.6.attn.bias", "caption_decoder.gpt.transformer.h.6.attn.masked_bias", "caption_decoder.gpt.transformer.h.7.attn.bias", "caption_decoder.gpt.transformer.h.7.attn.masked_bias", "caption_decoder.gpt.transformer.h.8.attn.bias", "caption_decoder.gpt.transformer.h.8.attn.masked_bias", "caption_decoder.gpt.transformer.h.9.attn.bias", "caption_decoder.gpt.transformer.h.9.attn.masked_bias", "caption_decoder.gpt.transformer.h.10.attn.bias", "caption_decoder.gpt.transformer.h.10.attn.masked_bias", "caption_decoder.gpt.transformer.h.11.attn.bias", "caption_decoder.gpt.transformer.h.11.attn.masked_bias".
soham97 commented 8 months ago

Hi @awkrail thanks for the feedback!

I downloaded the weights again from Zenodo and git-cloned the repo. I can run inference on both Windows and Linux. Therefore, the issue does not seem to be in the weights (as verified by your md5 checksum) or in the inference code.

Can you share the Transformers library version you are using? Maybe the GPT2 model class has changed in a newer version of Transformers.
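For reference, something like this prints the installed version (requirements.txt pins 4.28.1). As a rough sketch only, not a supported fix, the unexpected buffer keys could also be filtered out before loading; model and model_state_dict below are the names used at the failing line in wrapper.py's get_model_and_tokenizer:

import transformers
print(transformers.__version__)

# Workaround sketch, assuming the unexpected keys are only the non-learned
# buffers ("attn.bias", "attn.masked_bias", "position_ids") that newer
# Transformers releases no longer keep in the state_dict:
filtered = {k: v for k, v in model_state_dict.items()
            if not k.endswith((".attn.bias", ".attn.masked_bias", ".position_ids"))}
model.load_state_dict(filtered, strict=False)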

awkrail commented 7 months ago

@soham97 Thank you for your advice. I downgraded transformers to 4.28.1 as specified in requirements.txt, and that fixed the issue. I'm closing this issue.
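For anyone who hits the same error, the fix was simply pinning the version from requirements.txt:

pip install transformers==4.28.1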