manmay-nakhashi / tortoise-tts-fastest

Faster Tortoise inference than the Tortoise Fast fork
GNU Affero General Public License v3.0

Error encountered while creating TextToSpeech instance in tortoise_tts.ipynb #9

Closed: RahulBhalley closed this issue 10 months ago

RahulBhalley commented 11 months ago
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 14
     11 from tortoise.utils.audio import load_audio, load_voice, load_voices
     13 # This will download all the models used by Tortoise from the HuggingFace hub.
---> 14 tts = TextToSpeech()

File /workspace/tortoise-tts-fast/tortoise/api.py:271, in TextToSpeech.__init__(self, autoregressive_batch_size, models_dir, enable_redaction, device, high_vram, kv_cache, ar_checkpoint, clvp_checkpoint, diff_checkpoint, vocoder)
    254 self.autoregressive = (
    255     UnifiedVoice(
    256         max_mel_tokens=604,
   (...)
    268     .eval()
    269 )
    270 ar_path = ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
--> 271 self.autoregressive.load_state_dict(torch.load(ar_path))
    272 self.autoregressive.post_init_gpt2_config(kv_cache)
    274 diff_path = diff_checkpoint or get_model_path(
    275     "diffusion_decoder.pth", models_dir
    276 )

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for UnifiedVoice:
    Unexpected key(s) in state_dict: "gpt.h.0.attn.bias", "gpt.h.0.attn.masked_bias", "gpt.h.1.attn.bias", "gpt.h.1.attn.masked_bias", "gpt.h.2.attn.bias", "gpt.h.2.attn.masked_bias", "gpt.h.3.attn.bias", "gpt.h.3.attn.masked_bias", "gpt.h.4.attn.bias", "gpt.h.4.attn.masked_bias", "gpt.h.5.attn.bias", "gpt.h.5.attn.masked_bias", "gpt.h.6.attn.bias", "gpt.h.6.attn.masked_bias", "gpt.h.7.attn.bias", "gpt.h.7.attn.masked_bias", "gpt.h.8.attn.bias", "gpt.h.8.attn.masked_bias", "gpt.h.9.attn.bias", "gpt.h.9.attn.masked_bias", "gpt.h.10.attn.bias", "gpt.h.10.attn.masked_bias", "gpt.h.11.attn.bias", "gpt.h.11.attn.masked_bias", "gpt.h.12.attn.bias", "gpt.h.12.attn.masked_bias", "gpt.h.13.attn.bias", "gpt.h.13.attn.masked_bias", "gpt.h.14.attn.bias", "gpt.h.14.attn.masked_bias", "gpt.h.15.attn.bias", "gpt.h.15.attn.masked_bias", "gpt.h.16.attn.bias", "gpt.h.16.attn.masked_bias", "gpt.h.17.attn.bias", "gpt.h.17.attn.masked_bias", "gpt.h.18.attn.bias", "gpt.h.18.attn.masked_bias", "gpt.h.19.attn.bias", "gpt.h.19.attn.masked_bias", "gpt.h.20.attn.bias", "gpt.h.20.attn.masked_bias", "gpt.h.21.attn.bias", "gpt.h.21.attn.masked_bias", "gpt.h.22.attn.bias", "gpt.h.22.attn.masked_bias", "gpt.h.23.attn.bias", "gpt.h.23.attn.masked_bias", "gpt.h.24.attn.bias", "gpt.h.24.attn.masked_bias", "gpt.h.25.attn.bias", "gpt.h.25.attn.masked_bias", "gpt.h.26.attn.bias", "gpt.h.26.attn.masked_bias", "gpt.h.27.attn.bias", "gpt.h.27.attn.masked_bias", "gpt.h.28.attn.bias", "gpt.h.28.attn.masked_bias", "gpt.h.29.attn.bias", "gpt.h.29.attn.masked_bias". 
DrBrule commented 10 months ago

Just went through a similar issue. This has to do with the transformers version:

Running pip install transformers==4.29.2 should fix it.
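
For reference, an alternative that avoids downgrading is also possible, sketched here under the assumption that the error comes from newer transformers releases no longer registering the GPT-2 attn.bias / attn.masked_bias buffers that older checkpoints still contain: strip those keys from the checkpoint and load non-strictly. This is a hypothetical local patch around the load call shown in the traceback (tortoise/api.py:271), not a change shipped by the repo:

# In tortoise/api.py, replacing the strict load at the line shown in the traceback.
state_dict = torch.load(ar_path)
# The checkpoint was saved with an older transformers build whose GPT2Attention
# still kept these causal-mask buffers; the freshly built model no longer
# expects them, so drop them instead of failing on unexpected keys.
state_dict = {
    k: v
    for k, v in state_dict.items()
    if not k.endswith(("attn.bias", "attn.masked_bias"))
}
self.autoregressive.load_state_dict(state_dict, strict=False)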

RahulBhalley commented 10 months ago

Thanks @DrBrule.