yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models
MIT License
4.98k stars 422 forks source link

Error Message After Using a Fine-Tuned ASR Model #252

Open GUUser91 opened 5 months ago

GUUser91 commented 5 months ago

I get this error message after using a fine-tuned ASR model.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 5
      4 try:
----> 5     model[key].load_state_dict(params[key])
      6 except:

File StyleTTS2/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
   2188 if len(error_msgs) > 0:
-> 2189     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ASRCNN:
    Missing key(s) in state_dict: "to_mfcc.dct_mat", "init_cnn.conv.weight", "init_cnn.conv.bias", "cnns.0.0.blocks.0.0.conv.weight", "cnns.0.0.blocks.0.0.conv.bias", "cnns.0.0.blocks.0.2.weight", "cnns.0.0.blocks.0.2.bias", "cnns.0.0.blocks.0.4.conv.weight", "cnns.0.0.blocks.0.4.conv.bias", "cnns.0.0.blocks.1.0.conv.weight", "cnns.0.0.blocks.1.0.conv.bias", "cnns.0.0.blocks.1.2.weight", "cnns.0.0.blocks.1.2.bias", "cnns.0.0.blocks.1.4.conv.weight", "cnns.0.0.blocks.1.4.conv.bias", "cnns.0.0.blocks.2.0.conv.weight", "cnns.0.0.blocks.2.0.conv.bias", "cnns.0.0.blocks.2.2.weight", "cnns.0.0.blocks.2.2.bias", "cnns.0.0.blocks.2.4.conv.weight", "cnns.0.0.blocks.2.4.conv.bias", "cnns.0.1.weight", "cnns.0.1.bias", "cnns.1.0.blocks.0.0.conv.weight", "cnns.1.0.blocks.0.0.conv.bias", "cnns.1.0.blocks.0.2.weight", "cnns.1.0.blocks.0.2.bias", "cnns.1.0.blocks.0.4.conv.weight", "cnns.1.0.blocks.0.4.conv.bias", "cnns.1.0.blocks.1.0.conv.weight", "cnns.1.0.blocks.1.0.conv.bias", "cnns.1.0.blocks.1.2.weight", "cnns.1.0.blocks.1.2.bias", "cnns.1.0.blocks.1.4.conv.weight", "cnns.1.0.blocks.1.4.conv.bias", "cnns.1.0.blocks.2.0.conv.weight", "cnns.1.0.blocks.2.0.conv.bias", "cnns.1.0.blocks.2.2.weight", "cnns.1.0.blocks.2.2.bias", "cnns.1.0.blocks.2.4.conv.weight", "cnns.1.0.blocks.2.4.conv.bias", "cnns.1.1.weight", "cnns.1.1.bias", "cnns.2.0.blocks.0.0.conv.weight", "cnns.2.0.blocks.0.0.conv.bias", "cnns.2.0.blocks.0.2.weight", "cnns.2.0.blocks.0.2.bias", "cnns.2.0.blocks.0.4.conv.weight", "cnns.2.0.blocks.0.4.conv.bias", "cnns.2.0.blocks.1.0.conv.weight", "cnns.2.0.blocks.1.0.conv.bias", "cnns.2.0.blocks.1.2.weight", "cnns.2.0.blocks.1.2.bias", "cnns.2.0.blocks.1.4.conv.weight", "cnns.2.0.blocks.1.4.conv.bias", "cnns.2.0.blocks.2.0.conv.weight", "cnns.2.0.blocks.2.0.conv.bias", "cnns.2.0.blocks.2.2.weight", "cnns.2.0.blocks.2.2.bias", "cnns.2.0.blocks.2.4.conv.weight", "cnns.2.0.blocks.2.4.conv.bias", "cnns.2.1.weight", "cnns.2.1.bias", "cnns.3.0.blocks.0.0.conv.weight", 
"cnns.3.0.blocks.0.0.conv.bias", "cnns.3.0.blocks.0.2.weight", "cnns.3.0.blocks.0.2.bias", "cnns.3.0.blocks.0.4.conv.weight", "cnns.3.0.blocks.0.4.conv.bias", "cnns.3.0.blocks.1.0.conv.weight", "cnns.3.0.blocks.1.0.conv.bias", "cnns.3.0.blocks.1.2.weight", "cnns.3.0.blocks.1.2.bias", "cnns.3.0.blocks.1.4.conv.weight", "cnns.3.0.blocks.1.4.conv.bias", "cnns.3.0.blocks.2.0.conv.weight", "cnns.3.0.blocks.2.0.conv.bias", "cnns.3.0.blocks.2.2.weight", "cnns.3.0.blocks.2.2.bias", "cnns.3.0.blocks.2.4.conv.weight", "cnns.3.0.blocks.2.4.conv.bias", "cnns.3.1.weight", "cnns.3.1.bias", "cnns.4.0.blocks.0.0.conv.weight", "cnns.4.0.blocks.0.0.conv.bias", "cnns.4.0.blocks.0.2.weight", "cnns.4.0.blocks.0.2.bias", "cnns.4.0.blocks.0.4.conv.weight", "cnns.4.0.blocks.0.4.conv.bias", "cnns.4.0.blocks.1.0.conv.weight", "cnns.4.0.blocks.1.0.conv.bias", "cnns.4.0.blocks.1.2.weight", "cnns.4.0.blocks.1.2.bias", "cnns.4.0.blocks.1.4.conv.weight", "cnns.4.0.blocks.1.4.conv.bias", "cnns.4.0.blocks.2.0.conv.weight", "cnns.4.0.blocks.2.0.conv.bias", "cnns.4.0.blocks.2.2.weight", "cnns.4.0.blocks.2.2.bias", "cnns.4.0.blocks.2.4.conv.weight", "cnns.4.0.blocks.2.4.conv.bias", "cnns.4.1.weight", "cnns.4.1.bias", "cnns.5.0.blocks.0.0.conv.weight", "cnns.5.0.blocks.0.0.conv.bias", "cnns.5.0.blocks.0.2.weight", "cnns.5.0.blocks.0.2.bias", "cnns.5.0.blocks.0.4.conv.weight", "cnns.5.0.blocks.0.4.conv.bias", "cnns.5.0.blocks.1.0.conv.weight", "cnns.5.0.blocks.1.0.conv.bias", "cnns.5.0.blocks.1.2.weight", "cnns.5.0.blocks.1.2.bias", "cnns.5.0.blocks.1.4.conv.weight", "cnns.5.0.blocks.1.4.conv.bias", "cnns.5.0.blocks.2.0.conv.weight", "cnns.5.0.blocks.2.0.conv.bias", "cnns.5.0.blocks.2.2.weight", "cnns.5.0.blocks.2.2.bias", "cnns.5.0.blocks.2.4.conv.weight", "cnns.5.0.blocks.2.4.conv.bias", "cnns.5.1.weight", "cnns.5.1.bias", "projection.conv.weight", "projection.conv.bias", "ctc_linear.0.linear_layer.weight", "ctc_linear.0.linear_layer.bias", "ctc_linear.2.linear_layer.weight", 
"ctc_linear.2.linear_layer.bias", "asr_s2s.embedding.weight", "asr_s2s.project_to_n_symbols.weight", "asr_s2s.project_to_n_symbols.bias", "asr_s2s.attention_layer.query_layer.linear_layer.weight", "asr_s2s.attention_layer.memory_layer.linear_layer.weight", "asr_s2s.attention_layer.v.linear_layer.weight", "asr_s2s.attention_layer.location_layer.location_conv.conv.weight", "asr_s2s.attention_layer.location_layer.location_dense.linear_layer.weight", "asr_s2s.decoder_rnn.weight_ih", "asr_s2s.decoder_rnn.weight_hh", "asr_s2s.decoder_rnn.bias_ih", "asr_s2s.decoder_rnn.bias_hh", "asr_s2s.project_to_hidden.0.linear_layer.weight", "asr_s2s.project_to_hidden.0.linear_layer.bias". 
    Unexpected key(s) in state_dict: "module.to_mfcc.dct_mat", "module.init_cnn.conv.weight", "module.init_cnn.conv.bias", "module.cnns.0.0.blocks.0.0.conv.weight", "module.cnns.0.0.blocks.0.0.conv.bias", "module.cnns.0.0.blocks.0.2.weight", "module.cnns.0.0.blocks.0.2.bias", "module.cnns.0.0.blocks.0.4.conv.weight", "module.cnns.0.0.blocks.0.4.conv.bias", "module.cnns.0.0.blocks.1.0.conv.weight", "module.cnns.0.0.blocks.1.0.conv.bias", "module.cnns.0.0.blocks.1.2.weight", "module.cnns.0.0.blocks.1.2.bias", "module.cnns.0.0.blocks.1.4.conv.weight", "module.cnns.0.0.blocks.1.4.conv.bias", "module.cnns.0.0.blocks.2.0.conv.weight", "module.cnns.0.0.blocks.2.0.conv.bias", "module.cnns.0.0.blocks.2.2.weight", "module.cnns.0.0.blocks.2.2.bias", "module.cnns.0.0.blocks.2.4.conv.weight", "module.cnns.0.0.blocks.2.4.conv.bias", "module.cnns.0.1.weight", "module.cnns.0.1.bias", "module.cnns.1.0.blocks.0.0.conv.weight", "module.cnns.1.0.blocks.0.0.conv.bias", "module.cnns.1.0.blocks.0.2.weight", "module.cnns.1.0.blocks.0.2.bias", "module.cnns.1.0.blocks.0.4.conv.weight", "module.cnns.1.0.blocks.0.4.conv.bias", "module.cnns.1.0.blocks.1.0.conv.weight", "module.cnns.1.0.blocks.1.0.conv.bias", "module.cnns.1.0.blocks.1.2.weight", "module.cnns.1.0.blocks.1.2.bias", "module.cnns.1.0.blocks.1.4.conv.weight", "module.cnns.1.0.blocks.1.4.conv.bias", "module.cnns.1.0.blocks.2.0.conv.weight", "module.cnns.1.0.blocks.2.0.conv.bias", "module.cnns.1.0.blocks.2.2.weight", "module.cnns.1.0.blocks.2.2.bias", "module.cnns.1.0.blocks.2.4.conv.weight", "module.cnns.1.0.blocks.2.4.conv.bias", "module.cnns.1.1.weight", "module.cnns.1.1.bias", "module.cnns.2.0.blocks.0.0.conv.weight", "module.cnns.2.0.blocks.0.0.conv.bias", "module.cnns.2.0.blocks.0.2.weight", "module.cnns.2.0.blocks.0.2.bias", "module.cnns.2.0.blocks.0.4.conv.weight", "module.cnns.2.0.blocks.0.4.conv.bias", "module.cnns.2.0.blocks.1.0.conv.weight", "module.cnns.2.0.blocks.1.0.conv.bias", "module.cnns.2.0.blocks.1.2.weight", 
"module.cnns.2.0.blocks.1.2.bias", "module.cnns.2.0.blocks.1.4.conv.weight", "module.cnns.2.0.blocks.1.4.conv.bias", "module.cnns.2.0.blocks.2.0.conv.weight", "module.cnns.2.0.blocks.2.0.conv.bias", "module.cnns.2.0.blocks.2.2.weight", "module.cnns.2.0.blocks.2.2.bias", "module.cnns.2.0.blocks.2.4.conv.weight", "module.cnns.2.0.blocks.2.4.conv.bias", "module.cnns.2.1.weight", "module.cnns.2.1.bias", "module.cnns.3.0.blocks.0.0.conv.weight", "module.cnns.3.0.blocks.0.0.conv.bias", "module.cnns.3.0.blocks.0.2.weight", "module.cnns.3.0.blocks.0.2.bias", "module.cnns.3.0.blocks.0.4.conv.weight", "module.cnns.3.0.blocks.0.4.conv.bias", "module.cnns.3.0.blocks.1.0.conv.weight", "module.cnns.3.0.blocks.1.0.conv.bias", "module.cnns.3.0.blocks.1.2.weight", "module.cnns.3.0.blocks.1.2.bias", "module.cnns.3.0.blocks.1.4.conv.weight", "module.cnns.3.0.blocks.1.4.conv.bias", "module.cnns.3.0.blocks.2.0.conv.weight", "module.cnns.3.0.blocks.2.0.conv.bias", "module.cnns.3.0.blocks.2.2.weight", "module.cnns.3.0.blocks.2.2.bias", "module.cnns.3.0.blocks.2.4.conv.weight", "module.cnns.3.0.blocks.2.4.conv.bias", "module.cnns.3.1.weight", "module.cnns.3.1.bias", "module.cnns.4.0.blocks.0.0.conv.weight", "module.cnns.4.0.blocks.0.0.conv.bias", "module.cnns.4.0.blocks.0.2.weight", "module.cnns.4.0.blocks.0.2.bias", "module.cnns.4.0.blocks.0.4.conv.weight", "module.cnns.4.0.blocks.0.4.conv.bias", "module.cnns.4.0.blocks.1.0.conv.weight", "module.cnns.4.0.blocks.1.0.conv.bias", "module.cnns.4.0.blocks.1.2.weight", "module.cnns.4.0.blocks.1.2.bias", "module.cnns.4.0.blocks.1.4.conv.weight", "module.cnns.4.0.blocks.1.4.conv.bias", "module.cnns.4.0.blocks.2.0.conv.weight", "module.cnns.4.0.blocks.2.0.conv.bias", "module.cnns.4.0.blocks.2.2.weight", "module.cnns.4.0.blocks.2.2.bias", "module.cnns.4.0.blocks.2.4.conv.weight", "module.cnns.4.0.blocks.2.4.conv.bias", "module.cnns.4.1.weight", "module.cnns.4.1.bias", "module.cnns.5.0.blocks.0.0.conv.weight", 
"module.cnns.5.0.blocks.0.0.conv.bias", "module.cnns.5.0.blocks.0.2.weight", "module.cnns.5.0.blocks.0.2.bias", "module.cnns.5.0.blocks.0.4.conv.weight", "module.cnns.5.0.blocks.0.4.conv.bias", "module.cnns.5.0.blocks.1.0.conv.weight", "module.cnns.5.0.blocks.1.0.conv.bias", "module.cnns.5.0.blocks.1.2.weight", "module.cnns.5.0.blocks.1.2.bias", "module.cnns.5.0.blocks.1.4.conv.weight", "module.cnns.5.0.blocks.1.4.conv.bias", "module.cnns.5.0.blocks.2.0.conv.weight", "module.cnns.5.0.blocks.2.0.conv.bias", "module.cnns.5.0.blocks.2.2.weight", "module.cnns.5.0.blocks.2.2.bias", "module.cnns.5.0.blocks.2.4.conv.weight", "module.cnns.5.0.blocks.2.4.conv.bias", "module.cnns.5.1.weight", "module.cnns.5.1.bias", "module.projection.conv.weight", "module.projection.conv.bias", "module.ctc_linear.0.linear_layer.weight", "module.ctc_linear.0.linear_layer.bias", "module.ctc_linear.2.linear_layer.weight", "module.ctc_linear.2.linear_layer.bias", "module.asr_s2s.embedding.weight", "module.asr_s2s.project_to_n_symbols.weight", "module.asr_s2s.project_to_n_symbols.bias", "module.asr_s2s.attention_layer.query_layer.linear_layer.weight", "module.asr_s2s.attention_layer.memory_layer.linear_layer.weight", "module.asr_s2s.attention_layer.v.linear_layer.weight", "module.asr_s2s.attention_layer.location_layer.location_conv.conv.weight", "module.asr_s2s.attention_layer.location_layer.location_dense.linear_layer.weight", "module.asr_s2s.decoder_rnn.weight_ih", "module.asr_s2s.decoder_rnn.weight_hh", "module.asr_s2s.decoder_rnn.bias_ih", "module.asr_s2s.decoder_rnn.bias_hh", "module.asr_s2s.project_to_hidden.0.linear_layer.weight", "module.asr_s2s.project_to_hidden.0.linear_layer.bias". 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[14], line 14
     12                 new_state_dict[name] = v
     13             # load params
---> 14             model[key].load_state_dict(new_state_dict, strict=False)
     15 #             except:
     16 #                 _load(params[key], model[key])
     17 _ = [model[key].eval() for key in model]

File StyleTTS2/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
   2184         error_msgs.insert(
   2185             0, 'Missing key(s) in state_dict: {}. '.format(
   2186                 ', '.join(f'"{k}"' for k in missing_keys)))
   2188 if len(error_msgs) > 0:
-> 2189     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ASRCNN:
    size mismatch for ctc_linear.2.linear_layer.weight: copying a param with shape torch.Size([178, 256]) from checkpoint, the shape in current model is torch.Size([80, 256]).
    size mismatch for ctc_linear.2.linear_layer.bias: copying a param with shape torch.Size([178]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for asr_s2s.embedding.weight: copying a param with shape torch.Size([178, 512]) from checkpoint, the shape in current model is torch.Size([80, 256]).
    size mismatch for asr_s2s.project_to_n_symbols.weight: copying a param with shape torch.Size([178, 128]) from checkpoint, the shape in current model is torch.Size([80, 128]).
    size mismatch for asr_s2s.project_to_n_symbols.bias: copying a param with shape torch.Size([178]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for asr_s2s.decoder_rnn.weight_ih: copying a param with shape torch.Size([512, 640]) from checkpoint, the shape in current model is torch.Size([512, 384]).