coqui-ai / TTS

🐸💬 - a deep learning toolkit for Text-to-Speech, battle-tested in research and production
http://coqui.ai
Mozilla Public License 2.0

[Bug] Mismatch type for resblock_type_decoder (hifigan_generator) #1068

Closed: vince62s closed this issue 2 years ago

vince62s commented 2 years ago

Hello,

While testing a notebook that uses config.json (which contains resblock_type_decoder: 1), I loaded the model downloaded from https://drive.google.com/uc?id=1sgEjHt0lbPSEw9-FSbC_mBoOPwNi87YR to /content/best_model.pth.tar, and that checkpoint contains type 2 decoder resblocks. The issue must come from the "str" type expected here: https://github.com/coqui-ai/TTS/blob/main/TTS/vocoder/models/hifigan_generator.py#L203, whereas loading config.json produces an "int".
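The suspected mismatch can be illustrated in plain Python, because == never treats an int and a str as equal. This is a minimal sketch that assumes the generator picks its block class with an equality test against the string "1"; pick_resblock and the two empty classes are hypothetical stand-ins, not the library's code.

```python
# Hypothetical stand-ins for the generator's two residual block classes.
class ResBlock1:
    pass


class ResBlock2:
    pass


def pick_resblock(resblock_type):
    # Assumed form of the check: compare against the *string* "1".
    return ResBlock1 if resblock_type == "1" else ResBlock2


print(pick_resblock("1").__name__)  # ResBlock1
print(pick_resblock(1).__name__)    # ResBlock2, because 1 == "1" is False
```

Under this assumption, a config that says 1 (an int) silently builds type 2 blocks.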

To Reproduce

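import torch

from TTS.config import load_config
from TTS.tts.models import setup_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

# model vars
MODEL_PATH = 'best_model.pth.tar'
CONFIG_PATH = 'config.json'
TTS_LANGUAGES = "language_ids.json"
TTS_SPEAKERS = "speakers.json"
USE_CUDA = torch.cuda.is_available()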

# load the config
C = load_config(CONFIG_PATH)

print("# load the audio processor")
ap = AudioProcessor(**C.audio)

speaker_embedding = None

C.model_args['d_vector_file'] = TTS_SPEAKERS
C.model_args['use_speaker_encoder_as_loss'] = False

language_manager = LanguageManager(TTS_LANGUAGES)
speaker_manager = SpeakerManager()

model = setup_model(C, language_manager=language_manager, speaker_manager=speaker_manager)
model.load_checkpoint(C, MODEL_PATH)

cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
print("# remove speaker encoder")
model_weights = cp['model'].copy()

for key in list(model_weights.keys()):
    if "speaker_encoder" in key:
        del model_weights[key]

model.load_state_dict(model_weights)
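Before calling load_state_dict, it can help to compare the two key sets directly to see whether the mismatch is a structural difference. The helper below is a hypothetical sketch in plain Python; with the real objects one would pass model.state_dict().keys() and model_weights.keys(). PyTorch's own load_state_dict(..., strict=False) also reports missing_keys and unexpected_keys, but it then loads the matching weights without raising, which can hide the problem.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, in the same sense as the
    RuntimeError raised by load_state_dict with strict=True."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)     # expected by the model, absent from the checkpoint
    unexpected = sorted(ckpt_set - model_set)  # present in the checkpoint, unknown to the model
    return missing, unexpected


# Toy example using one key from each side of the error in this thread.
missing, unexpected = diff_state_dict_keys(
    ["waveform_decoder.resblocks.0.convs1.0.bias"],
    ["waveform_decoder.resblocks.0.convs.0.bias"],
)
print(missing)
print(unexpected)
```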
erogol commented 2 years ago

Which notebook were you testing? "resblock_type" is a str in hifigan_config.py as well as in the released models. I could not figure out what exactly caused the issue.

Can you also share config.json?

vince62s commented 2 years ago

Maybe my assumption is wrong, but I get an error unless I switch resblock_type_decoder in config.json from 1 to 2.

The file is this one, downloaded from https://drive.google.com/uc?id=1-PfXD66l1ZpsZmJiC-vhL055CDSugLyP to /content/config.json.

It comes from this notebook: https://colab.research.google.com/drive/1ftI0x16iqKgiQFgTjTDgRpOM1wC1U-yS?usp=sharing#scrollTo=yICxxOSZWYJb

I just adapted a few lines to match the current code base, but if you look in config.json you will see "resblock_type_decoder": 1.

The error occurs when loading the state dict:

RuntimeError: Error(s) in loading state_dict for Vits: Missing key(s) in state_dict: "waveform_decoder.resblocks.0.convs1.0.bias", "waveform_decoder.resblocks.0.convs1.0.weight_g", "waveform_decoder.resblocks.0.convs1.0.weight_v", "waveform_decoder.resblocks.0.convs1.1.bias", "waveform_decoder.resblocks.0.convs1.1.weight_g", "waveform_decoder.resblocks.0.convs1.1.weight_v", "waveform_decoder.resblocks.0.convs1.2.bias", "waveform_decoder.resblocks.0.convs1.2.weight_g", "waveform_decoder.resblocks.0.convs1.2.weight_v", "waveform_decoder.resblocks.0.convs2.0.bias", "waveform_decoder.resblocks.0.convs2.0.weight_g", "waveform_decoder.resblocks.0.convs2.0.weight_v", "waveform_decoder.resblocks.0.convs2.1.bias", "waveform_decoder.resblocks.0.convs2.1.weight_g", "waveform_decoder.resblocks.0.convs2.1.weight_v", "waveform_decoder.resblocks.0.convs2.2.bias", "waveform_decoder.resblocks.0.convs2.2.weight_g", "waveform_decoder.resblocks.0.convs2.2.weight_v", "waveform_decoder.resblocks.1.convs1.0.bias", "waveform_decoder.resblocks.1.convs1.0.weight_g", "waveform_decoder.resblocks.1.convs1.0.weight_v", "waveform_decoder.resblocks.1.convs1.1.bias", "waveform_decoder.resblocks.1.convs1.1.weight_g", "waveform_decoder.resblocks.1.convs1.1.weight_v", "waveform_decoder.resblocks.1.convs1.2.bias", "waveform_decoder.resblocks.1.convs1.2.weight_g", "waveform_decoder.resblocks.1.convs1.2.weight_v", "waveform_decoder.resblocks.1.convs2.0.bias", "waveform_decoder.resblocks.1.convs2.0.weight_g", "waveform_decoder.resblocks.1.convs2.0.weight_v", "waveform_decoder.resblocks.1.convs2.1.bias", "waveform_decoder.resblocks.1.convs2.1.weight_g", "waveform_decoder.resblocks.1.convs2.1.weight_v", "waveform_decoder.resblocks.1.convs2.2.bias", "waveform_decoder.resblocks.1.convs2.2.weight_g", "waveform_decoder.resblocks.1.convs2.2.weight_v", "waveform_decoder.resblocks.2.convs1.0.bias", "waveform_decoder.resblocks.2.convs1.0.weight_g", "waveform_decoder.resblocks.2.convs1.0.weight_v", 
"waveform_decoder.resblocks.2.convs1.1.bias", "waveform_decoder.resblocks.2.convs1.1.weight_g", "waveform_decoder.resblocks.2.convs1.1.weight_v", "waveform_decoder.resblocks.2.convs1.2.bias", "waveform_decoder.resblocks.2.convs1.2.weight_g", "waveform_decoder.resblocks.2.convs1.2.weight_v", "waveform_decoder.resblocks.2.convs2.0.bias", "waveform_decoder.resblocks.2.convs2.0.weight_g", "waveform_decoder.resblocks.2.convs2.0.weight_v", "waveform_decoder.resblocks.2.convs2.1.bias", "waveform_decoder.resblocks.2.convs2.1.weight_g", "waveform_decoder.resblocks.2.convs2.1.weight_v", "waveform_decoder.resblocks.2.convs2.2.bias", "waveform_decoder.resblocks.2.convs2.2.weight_g", "waveform_decoder.resblocks.2.convs2.2.weight_v", "waveform_decoder.resblocks.3.convs1.0.bias", "waveform_decoder.resblocks.3.convs1.0.weight_g", "waveform_decoder.resblocks.3.convs1.0.weight_v", "waveform_decoder.resblocks.3.convs1.1.bias", "waveform_decoder.resblocks.3.convs1.1.weight_g", "waveform_decoder.resblocks.3.convs1.1.weight_v", "waveform_decoder.resblocks.3.convs1.2.bias", "waveform_decoder.resblocks.3.convs1.2.weight_g", "waveform_decoder.resblocks.3.convs1.2.weight_v", "waveform_decoder.resblocks.3.convs2.0.bias", "waveform_decoder.resblocks.3.convs2.0.weight_g", "waveform_decoder.resblocks.3.convs2.0.weight_v", "waveform_decoder.resblocks.3.convs2.1.bias", "waveform_decoder.resblocks.3.convs2.1.weight_g", "waveform_decoder.resblocks.3.convs2.1.weight_v", "waveform_decoder.resblocks.3.convs2.2.bias", "waveform_decoder.resblocks.3.convs2.2.weight_g", "waveform_decoder.resblocks.3.convs2.2.weight_v", "waveform_decoder.resblocks.4.convs1.0.bias", "waveform_decoder.resblocks.4.convs1.0.weight_g", "waveform_decoder.resblocks.4.convs1.0.weight_v", "waveform_decoder.resblocks.4.convs1.1.bias", "waveform_decoder.resblocks.4.convs1.1.weight_g", "waveform_decoder.resblocks.4.convs1.1.weight_v", "waveform_decoder.resblocks.4.convs1.2.bias", "waveform_decoder.resblocks.4.convs1.2.weight_g", 
"waveform_decoder.resblocks.4.convs1.2.weight_v", "waveform_decoder.resblocks.4.convs2.0.bias", "waveform_decoder.resblocks.4.convs2.0.weight_g", "waveform_decoder.resblocks.4.convs2.0.weight_v", "waveform_decoder.resblocks.4.convs2.1.bias", "waveform_decoder.resblocks.4.convs2.1.weight_g", "waveform_decoder.resblocks.4.convs2.1.weight_v", "waveform_decoder.resblocks.4.convs2.2.bias", "waveform_decoder.resblocks.4.convs2.2.weight_g", "waveform_decoder.resblocks.4.convs2.2.weight_v", "waveform_decoder.resblocks.5.convs1.0.bias", "waveform_decoder.resblocks.5.convs1.0.weight_g", "waveform_decoder.resblocks.5.convs1.0.weight_v", "waveform_decoder.resblocks.5.convs1.1.bias", "waveform_decoder.resblocks.5.convs1.1.weight_g", "waveform_decoder.resblocks.5.convs1.1.weight_v", "waveform_decoder.resblocks.5.convs1.2.bias", "waveform_decoder.resblocks.5.convs1.2.weight_g", "waveform_decoder.resblocks.5.convs1.2.weight_v", "waveform_decoder.resblocks.5.convs2.0.bias", "waveform_decoder.resblocks.5.convs2.0.weight_g", "waveform_decoder.resblocks.5.convs2.0.weight_v", "waveform_decoder.resblocks.5.convs2.1.bias", "waveform_decoder.resblocks.5.convs2.1.weight_g", "waveform_decoder.resblocks.5.convs2.1.weight_v", "waveform_decoder.resblocks.5.convs2.2.bias", "waveform_decoder.resblocks.5.convs2.2.weight_g", "waveform_decoder.resblocks.5.convs2.2.weight_v", "waveform_decoder.resblocks.6.convs1.0.bias", "waveform_decoder.resblocks.6.convs1.0.weight_g", "waveform_decoder.resblocks.6.convs1.0.weight_v", "waveform_decoder.resblocks.6.convs1.1.bias", "waveform_decoder.resblocks.6.convs1.1.weight_g", "waveform_decoder.resblocks.6.convs1.1.weight_v", "waveform_decoder.resblocks.6.convs1.2.bias", "waveform_decoder.resblocks.6.convs1.2.weight_g", "waveform_decoder.resblocks.6.convs1.2.weight_v", "waveform_decoder.resblocks.6.convs2.0.bias", "waveform_decoder.resblocks.6.convs2.0.weight_g", "waveform_decoder.resblocks.6.convs2.0.weight_v", "waveform_decoder.resblocks.6.convs2.1.bias", 
"waveform_decoder.resblocks.6.convs2.1.weight_g", "waveform_decoder.resblocks.6.convs2.1.weight_v", "waveform_decoder.resblocks.6.convs2.2.bias", "waveform_decoder.resblocks.6.convs2.2.weight_g", "waveform_decoder.resblocks.6.convs2.2.weight_v", "waveform_decoder.resblocks.7.convs1.0.bias", "waveform_decoder.resblocks.7.convs1.0.weight_g", "waveform_decoder.resblocks.7.convs1.0.weight_v", "waveform_decoder.resblocks.7.convs1.1.bias", "waveform_decoder.resblocks.7.convs1.1.weight_g", "waveform_decoder.resblocks.7.convs1.1.weight_v", "waveform_decoder.resblocks.7.convs1.2.bias", "waveform_decoder.resblocks.7.convs1.2.weight_g", "waveform_decoder.resblocks.7.convs1.2.weight_v", "waveform_decoder.resblocks.7.convs2.0.bias", "waveform_decoder.resblocks.7.convs2.0.weight_g", "waveform_decoder.resblocks.7.convs2.0.weight_v", "waveform_decoder.resblocks.7.convs2.1.bias", "waveform_decoder.resblocks.7.convs2.1.weight_g", "waveform_decoder.resblocks.7.convs2.1.weight_v", "waveform_decoder.resblocks.7.convs2.2.bias", "waveform_decoder.resblocks.7.convs2.2.weight_g", "waveform_decoder.resblocks.7.convs2.2.weight_v", "waveform_decoder.resblocks.8.convs1.0.bias", "waveform_decoder.resblocks.8.convs1.0.weight_g", "waveform_decoder.resblocks.8.convs1.0.weight_v", "waveform_decoder.resblocks.8.convs1.1.bias", "waveform_decoder.resblocks.8.convs1.1.weight_g", "waveform_decoder.resblocks.8.convs1.1.weight_v", "waveform_decoder.resblocks.8.convs1.2.bias", "waveform_decoder.resblocks.8.convs1.2.weight_g", "waveform_decoder.resblocks.8.convs1.2.weight_v", "waveform_decoder.resblocks.8.convs2.0.bias", "waveform_decoder.resblocks.8.convs2.0.weight_g", "waveform_decoder.resblocks.8.convs2.0.weight_v", "waveform_decoder.resblocks.8.convs2.1.bias", "waveform_decoder.resblocks.8.convs2.1.weight_g", "waveform_decoder.resblocks.8.convs2.1.weight_v", "waveform_decoder.resblocks.8.convs2.2.bias", "waveform_decoder.resblocks.8.convs2.2.weight_g", "waveform_decoder.resblocks.8.convs2.2.weight_v", 
"waveform_decoder.resblocks.9.convs1.0.bias", "waveform_decoder.resblocks.9.convs1.0.weight_g", "waveform_decoder.resblocks.9.convs1.0.weight_v", "waveform_decoder.resblocks.9.convs1.1.bias", "waveform_decoder.resblocks.9.convs1.1.weight_g", "waveform_decoder.resblocks.9.convs1.1.weight_v", "waveform_decoder.resblocks.9.convs1.2.bias", "waveform_decoder.resblocks.9.convs1.2.weight_g", "waveform_decoder.resblocks.9.convs1.2.weight_v", "waveform_decoder.resblocks.9.convs2.0.bias", "waveform_decoder.resblocks.9.convs2.0.weight_g", "waveform_decoder.resblocks.9.convs2.0.weight_v", "waveform_decoder.resblocks.9.convs2.1.bias", "waveform_decoder.resblocks.9.convs2.1.weight_g", "waveform_decoder.resblocks.9.convs2.1.weight_v", "waveform_decoder.resblocks.9.convs2.2.bias", "waveform_decoder.resblocks.9.convs2.2.weight_g", "waveform_decoder.resblocks.9.convs2.2.weight_v", "waveform_decoder.resblocks.10.convs1.0.bias", "waveform_decoder.resblocks.10.convs1.0.weight_g", "waveform_decoder.resblocks.10.convs1.0.weight_v", "waveform_decoder.resblocks.10.convs1.1.bias", "waveform_decoder.resblocks.10.convs1.1.weight_g", "waveform_decoder.resblocks.10.convs1.1.weight_v", "waveform_decoder.resblocks.10.convs1.2.bias", "waveform_decoder.resblocks.10.convs1.2.weight_g", "waveform_decoder.resblocks.10.convs1.2.weight_v", "waveform_decoder.resblocks.10.convs2.0.bias", "waveform_decoder.resblocks.10.convs2.0.weight_g", "waveform_decoder.resblocks.10.convs2.0.weight_v", "waveform_decoder.resblocks.10.convs2.1.bias", "waveform_decoder.resblocks.10.convs2.1.weight_g", "waveform_decoder.resblocks.10.convs2.1.weight_v", "waveform_decoder.resblocks.10.convs2.2.bias", "waveform_decoder.resblocks.10.convs2.2.weight_g", "waveform_decoder.resblocks.10.convs2.2.weight_v", "waveform_decoder.resblocks.11.convs1.0.bias", "waveform_decoder.resblocks.11.convs1.0.weight_g", "waveform_decoder.resblocks.11.convs1.0.weight_v", "waveform_decoder.resblocks.11.convs1.1.bias", 
"waveform_decoder.resblocks.11.convs1.1.weight_g", "waveform_decoder.resblocks.11.convs1.1.weight_v", "waveform_decoder.resblocks.11.convs1.2.bias", "waveform_decoder.resblocks.11.convs1.2.weight_g", "waveform_decoder.resblocks.11.convs1.2.weight_v", "waveform_decoder.resblocks.11.convs2.0.bias", "waveform_decoder.resblocks.11.convs2.0.weight_g", "waveform_decoder.resblocks.11.convs2.0.weight_v", "waveform_decoder.resblocks.11.convs2.1.bias", "waveform_decoder.resblocks.11.convs2.1.weight_g", "waveform_decoder.resblocks.11.convs2.1.weight_v", "waveform_decoder.resblocks.11.convs2.2.bias", "waveform_decoder.resblocks.11.convs2.2.weight_g", "waveform_decoder.resblocks.11.convs2.2.weight_v". Unexpected key(s) in state_dict: "waveform_decoder.resblocks.0.convs.0.bias", "waveform_decoder.resblocks.0.convs.0.weight_g", "waveform_decoder.resblocks.0.convs.0.weight_v", "waveform_decoder.resblocks.0.convs.1.bias", "waveform_decoder.resblocks.0.convs.1.weight_g", "waveform_decoder.resblocks.0.convs.1.weight_v", "waveform_decoder.resblocks.1.convs.0.bias", "waveform_decoder.resblocks.1.convs.0.weight_g", "waveform_decoder.resblocks.1.convs.0.weight_v", "waveform_decoder.resblocks.1.convs.1.bias", "waveform_decoder.resblocks.1.convs.1.weight_g", "waveform_decoder.resblocks.1.convs.1.weight_v", "waveform_decoder.resblocks.2.convs.0.bias", "waveform_decoder.resblocks.2.convs.0.weight_g", "waveform_decoder.resblocks.2.convs.0.weight_v", "waveform_decoder.resblocks.2.convs.1.bias", "waveform_decoder.resblocks.2.convs.1.weight_g", "waveform_decoder.resblocks.2.convs.1.weight_v", "waveform_decoder.resblocks.3.convs.0.bias", "waveform_decoder.resblocks.3.convs.0.weight_g", "waveform_decoder.resblocks.3.convs.0.weight_v", "waveform_decoder.resblocks.3.convs.1.bias", "waveform_decoder.resblocks.3.convs.1.weight_g", "waveform_decoder.resblocks.3.convs.1.weight_v", "waveform_decoder.resblocks.4.convs.0.bias", "waveform_decoder.resblocks.4.convs.0.weight_g", 
"waveform_decoder.resblocks.4.convs.0.weight_v", "waveform_decoder.resblocks.4.convs.1.bias", "waveform_decoder.resblocks.4.convs.1.weight_g", "waveform_decoder.resblocks.4.convs.1.weight_v", "waveform_decoder.resblocks.5.convs.0.bias", "waveform_decoder.resblocks.5.convs.0.weight_g", "waveform_decoder.resblocks.5.convs.0.weight_v", "waveform_decoder.resblocks.5.convs.1.bias", "waveform_decoder.resblocks.5.convs.1.weight_g", "waveform_decoder.resblocks.5.convs.1.weight_v", "waveform_decoder.resblocks.6.convs.0.bias", "waveform_decoder.resblocks.6.convs.0.weight_g", "waveform_decoder.resblocks.6.convs.0.weight_v", "waveform_decoder.resblocks.6.convs.1.bias", "waveform_decoder.resblocks.6.convs.1.weight_g", "waveform_decoder.resblocks.6.convs.1.weight_v", "waveform_decoder.resblocks.7.convs.0.bias", "waveform_decoder.resblocks.7.convs.0.weight_g", "waveform_decoder.resblocks.7.convs.0.weight_v", "waveform_decoder.resblocks.7.convs.1.bias", "waveform_decoder.resblocks.7.convs.1.weight_g", "waveform_decoder.resblocks.7.convs.1.weight_v", "waveform_decoder.resblocks.8.convs.0.bias", "waveform_decoder.resblocks.8.convs.0.weight_g", "waveform_decoder.resblocks.8.convs.0.weight_v", "waveform_decoder.resblocks.8.convs.1.bias", "waveform_decoder.resblocks.8.convs.1.weight_g", "waveform_decoder.resblocks.8.convs.1.weight_v", "waveform_decoder.resblocks.9.convs.0.bias", "waveform_decoder.resblocks.9.convs.0.weight_g", "waveform_decoder.resblocks.9.convs.0.weight_v", "waveform_decoder.resblocks.9.convs.1.bias", "waveform_decoder.resblocks.9.convs.1.weight_g", "waveform_decoder.resblocks.9.convs.1.weight_v", "waveform_decoder.resblocks.10.convs.0.bias", "waveform_decoder.resblocks.10.convs.0.weight_g", "waveform_decoder.resblocks.10.convs.0.weight_v", "waveform_decoder.resblocks.10.convs.1.bias", "waveform_decoder.resblocks.10.convs.1.weight_g", "waveform_decoder.resblocks.10.convs.1.weight_v", "waveform_decoder.resblocks.11.convs.0.bias", 
"waveform_decoder.resblocks.11.convs.0.weight_g", "waveform_decoder.resblocks.11.convs.0.weight_v", "waveform_decoder.resblocks.11.convs.1.bias", "waveform_decoder.resblocks.11.convs.1.weight_g", "waveform_decoder.resblocks.11.convs.1.weight_v".

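The key names in this error are enough to tell which block type a checkpoint was built with: the model expected convs1 and convs2 sub-modules in every resblock, while the checkpoint only has convs. A minimal sketch, assuming these naming patterns (taken from the log above, including the waveform_decoder prefix) hold for all resblocks; infer_resblock_type is a hypothetical helper, not part of TTS.

```python
import re

# Patterns read off the error message: type-1 blocks have "convs1"/"convs2",
# type-2 blocks have a single "convs" list.
_TYPE1 = re.compile(r"waveform_decoder\.resblocks\.\d+\.convs[12]\.")
_TYPE2 = re.compile(r"waveform_decoder\.resblocks\.\d+\.convs\.")


def infer_resblock_type(state_dict_keys):
    """Return "1" or "2" depending on which key pattern the checkpoint uses."""
    keys = list(state_dict_keys)
    has_type1 = any(_TYPE1.search(k) for k in keys)
    has_type2 = any(_TYPE2.search(k) for k in keys)
    if has_type1 and not has_type2:
        return "1"
    if has_type2 and not has_type1:
        return "2"
    raise ValueError("could not infer the decoder resblock type from the keys")
```

Applied to cp["model"].keys() from the reproduction script, this would be expected to return "2", since the checkpoint's keys are the "unexpected" convs ones.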
vince62s commented 2 years ago

Okay, I know what happened: the issue comes from Edresson's branch.

Even though the JSON file contains resblock type 1, it prints as resblock_type: 1 <class 'int'> when the model is built. Since it is an int, the check in hifigan_generator evaluates incorrectly, so the checkpoint he built is in fact type 2.
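One defensive option, sketched here as a hypothetical helper rather than the change the project actually made, is to normalize the value before comparing it, so that an int read from JSON and the str default take the same branch. Note that this would not rescue this particular checkpoint: its weights were built as type 2, so its config still has to say 2.

```python
def normalize_resblock_type(value):
    """Map 1, "1", 2 or "2" to the canonical strings "1" / "2"."""
    text = str(value).strip()
    if text not in ("1", "2"):
        raise ValueError(f"resblock type must be 1 or 2, got {value!r}")
    return text


# An int from JSON and the str default now compare the same way.
for raw in (1, "1", 2, "2"):
    print(repr(raw), "->", normalize_resblock_type(raw))
```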

Conclusion: config.json just needs to be changed to type 2 for the notebook to work.
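The config change can be scripted. This sketch assumes config.json is plain JSON without comments, and since the thread does not say where resblock_type_decoder sits, it sets every occurrence of that key at any nesting level. It writes the string "2", matching the str type erogol describes. set_key_recursively and patch_config are hypothetical helpers, and rewriting the file with json.dumps will reformat it.

```python
import json
from pathlib import Path


def set_key_recursively(node, key, value):
    """Set every occurrence of `key` in nested dicts/lists; return how many were set."""
    count = 0
    if isinstance(node, dict):
        for k, v in node.items():
            if k == key:
                node[k] = value  # replacing an existing key is safe during iteration
                count += 1
            else:
                count += set_key_recursively(v, key, value)
    elif isinstance(node, list):
        for item in node:
            count += set_key_recursively(item, key, value)
    return count


def patch_config(path, value="2"):
    path = Path(path)
    config = json.loads(path.read_text())
    if set_key_recursively(config, "resblock_type_decoder", value) == 0:
        raise KeyError(f"resblock_type_decoder not found in {path}")
    path.write_text(json.dumps(config, indent=4))
```

For example, patch_config("/content/config.json") before running the reproduction script.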